Refactoring and a little bit of documentation

This commit is contained in:
Lea Laux 2021-02-15 14:58:01 +01:00 committed by KDV Admin
parent 71f64dc70e
commit 7e37ac06b7

View File

@ -1,25 +1,49 @@
import os
import csv
import re
from psycopg2 import sql
from pygadmin.database_query_executor import DatabaseQueryExecutor
from pygadmin.connectionfactory import global_connection_factory
class CSVImporter:
"""
Create a class for importing (small) .csv files. The process involves parsing a csv file, creating the table with
assumed data types and inserting all the data.
"""
def __init__(self, database_connection, csv_file, delimiter=",", null_type="NULL", create_table=True,
             table_name=None):
    """
    Store the settings of the import and prepare the executor for the database queries.

    database_connection: connection for executing the database queries on the given database.
    csv_file: csv file, which is used for loading the relevant data.
    delimiter: delimiter of the csv file, not necessarily ",".
    null_type: null type of the csv file.
    create_table: use a new table for saving the data instead of only appending to an existing one.
    table_name: name of the table, None if no name is given.
    """
    # Containers for the rows of the csv file and for the data types, both are filled later.
    self.csv_data, self.data_types = [], []

    # Settings describing the csv file itself.
    self.csv_file = csv_file
    self.delimiter = delimiter
    self.null_type = null_type

    # Settings describing the target of the import.
    self.database_connection = database_connection
    self.table_name = table_name
    self.create_table = create_table

    # TODO: Create functions for connecting the result signal and the error signal
    # The executor is used for the create table and insert queries. Until dedicated functions exist, errors and
    # results are only printed.
    executor = DatabaseQueryExecutor()
    executor.error.connect(self.print_error)
    executor.result_data.connect(self.print_result)
    self.database_query_executor = executor
def check_existence_csv_file(self):
"""
Check for the existence of the given csv file, because only if the file exists, data can be read from the file.
"""
if os.path.exists(self.csv_file):
return True
@ -27,37 +51,64 @@ class CSVImporter:
return False
def parse_csv_file(self):
    """
    Parse the content of the csv file and add every row to the csv data list.

    Returns True for a successfully read file. If the file cannot be opened or read, the exception is returned
    (not raised) and the csv data list stays unchanged.
    """

    # Use a try in case of invalid permissions or a broken file.
    try:
        # Open the file with newline="", as recommended for the csv module, so line breaks inside of quoted fields
        # are read correctly.
        with open(self.csv_file, newline="") as csv_file:
            # Read the content of the file with the given delimiter.
            reader = csv.reader(csv_file, delimiter=self.delimiter)
            # Collect the rows first, so a failure in the middle of the file does not leave partial data behind.
            rows = list(reader)

    except Exception as file_error:
        # Return the error.
        return file_error

    # Add every row to the data list.
    self.csv_data.extend(rows)

    # Return the success.
    return True
def assume_data_types(self):
    """
    Assume the data types of the columns in the csv file based on the given values. Check the first 100 data rows,
    so the overhead is small, but the check data is large enough to get a correct assumption. The first row of the
    csv data is used as header and is not checked. The data types are taken from get_data_type; the most general
    type found in a column is kept, with the order NULL < INT < REAL < TEXT.

    The result is stored in self.data_types with one entry per header column. A column with only NULL values in
    the checked rows keeps None as entry.
    """

    # Without any parsed row, there is no header and there are no columns to type.
    if not self.csv_data:
        self.data_types = []
        return

    # Create a list for the data types with one entry for every column of the header.
    column_count = len(self.csv_data[0])
    self.data_types = [None] * column_count

    # Rank the data types from most specific to most general, so a later, more specific value cannot narrow a
    # column, which already needs a more general type. Unknown types are treated as general as TEXT.
    type_rank = {None: 0, "NULL": 0, "INT": 1, "REAL": 2, "TEXT": 3}
    most_general_rank = type_rank["TEXT"]

    # Check up to 100 rows after the header.
    for current_row in self.csv_data[1:101]:
        # Ignore values without a header column, so a row longer than the header does not cause an index error.
        for check_column, value in enumerate(current_row[:column_count]):
            known_type = self.data_types[check_column]

            # If the data type is TEXT, there is nothing to change for this column. Continue with the next column
            # instead of stopping the check of the whole row.
            if known_type == "TEXT":
                continue

            # Get the data type of the current value.
            data_type = self.get_data_type(value)

            # Write the data type in the data type list, if it is more general than the known one. NULL never
            # replaces an entry.
            if type_rank.get(data_type, most_general_rank) > type_rank.get(known_type, most_general_rank):
                self.data_types[check_column] = data_type
@ -66,19 +117,15 @@ class CSVImporter:
return "NULL"
try:
if float(value).is_integer():
return "INT"
float(value)
return "REAL"
except ValueError:
pass
try:
int(value)
return "INT"
except ValueError:
pass
return "TEXT"
def create_table_for_csv_data(self):
@ -87,8 +134,15 @@ class CSVImporter:
create_statement = self.get_create_statement()
with self.database_connection.cursor() as database_cursor:
database_cursor.execute(sql.SQL(create_statement))
if database_cursor.description:
print(database_cursor.description)
def get_create_statement(self):
create_table_query = "CREATE TABLE %s ("
self.get_table_name()
create_table_query = "CREATE TABLE {} (".format(self.table_name)
header = self.csv_data[0]
@ -99,31 +153,40 @@ class CSVImporter:
else:
comma_value = ""
create_column = '"%s" %s{}\n'.format(comma_value)
create_column = "{} {}{}\n".format(header[column_count], self.data_types[column_count], comma_value)
create_table_query = "{}{}".format(create_table_query, create_column)
create_table_query = "{});".format(create_table_query)
print(create_table_query)
# TODO: Mechanism for the user to check the create table statement
return create_table_query
def get_create_parameter(self):
    """
    Get the parameters for a create table statement as list: the table name, followed by the name and the data
    type of every column. If no table name is given, the name is derived from the name of the csv file.
    """

    # Derive the table name from the file name, if there is no name.
    if self.table_name is None:
        self.table_name = os.path.basename(self.csv_file).split(".csv")[0]

    parameter_list = [self.table_name]

    # The first row of the csv data is used as header with the column names.
    for column_name, data_type in zip(self.csv_data[0], self.data_types):
        parameter_list.append(column_name)
        parameter_list.append(data_type)

    return parameter_list

def get_table_name(self):
    """
    Set the name of the table. If no table name is given, the name of the csv file without the path and without
    ".csv" is used. The name is cleaned with check_ddl_parameter in every case.
    """

    if self.table_name is None:
        self.table_name = os.path.basename(self.csv_file).split(".csv")[0]

    self.table_name = self.check_ddl_parameter(self.table_name)

@staticmethod
def check_ddl_parameter(parameter):
    """
    Clean a parameter for the usage in a DDL statement: Remove every character except letters, digits, spaces,
    underscores and dots. A parameter with a space is wrapped in double quotes. The cleaned parameter is returned.
    """

    parameter = re.sub(r"[^a-zA-Z0-9 _\.]", "", parameter)

    # A name with a space is only usable as quoted name.
    if " " in parameter:
        parameter = '"{}"'.format(parameter)

    return parameter

def print_error(self, error):
    """
    Print the given error. This function is connected to the error signal of the database query executor.
    """

    print(error)

def print_result(self, result):
    """
    Print the given result. This function is connected to the result data signal of the database query executor.
    """

    print(result)
def do_all_the_stuff(self):
"""
@ -140,10 +203,9 @@ class CSVImporter:
# create table or insert in existing table, (create table and) insert
if __name__ == "__main__":
    # Manual test run: Build the create table statement for a local test file. The connection parameters and the
    # path of the file are only placeholders of a development setup.
    csv_importer = CSVImporter(global_connection_factory.get_database_connection("localhost", "testuser", "testdb"),
                               "/home/sqlea/test.csv", delimiter=";", table_name="new_test_table")

    if csv_importer.check_existence_csv_file():
        parse_result = csv_importer.parse_csv_file()

        # parse_csv_file returns the exception instead of raising it, so the result is compared with True
        # explicitly, because an exception object is truthy as well.
        if parse_result is True:
            csv_importer.assume_data_types()
            csv_importer.get_create_statement()

        else:
            print(parse_result)