Refactoring and a little bit of documentation

Lea Laux 2021-02-15 14:58:01 +01:00 committed by KDV Admin
parent 71f64dc70e
commit 7e37ac06b7


@@ -1,25 +1,49 @@
 import os
 import csv
+import re
 
+from psycopg2 import sql
 from pygadmin.database_query_executor import DatabaseQueryExecutor
+from pygadmin.connectionfactory import global_connection_factory
 
 
 class CSVImporter:
+    """
+    Create a class for importing (small) .csv files. The process involves parsing a csv file, creating the table with
+    assumed data types and inserting all the data.
+    """
+
     def __init__(self, database_connection, csv_file, delimiter=",", null_type="NULL", create_table=True,
                  table_name=None):
+        # Use the given database connection for further execution of database queries on the given database.
         self.database_connection = database_connection
+        # Use the csv file for loading the relevant data.
         self.csv_file = csv_file
+        # Use a delimiter for the csv file, not necessarily ",".
         self.delimiter = delimiter
+        # Define the null type of the csv file.
         self.null_type = null_type
+        # Use a new table for saving the data and not only appending to an existing one.
         self.create_table = create_table
+        # Get the name of the table.
         self.table_name = table_name
+        # Save the csv data in a list.
         self.csv_data = []
+        # Save the data types in a list.
         self.data_types = []
         # TODO: Create functions for connecting the result signal and the error signal
+        # Use the database query executor for executing the create table and insert queries.
         self.database_query_executor = DatabaseQueryExecutor()
+        self.database_query_executor.error.connect(self.print_error)
+        self.database_query_executor.result_data.connect(self.print_result)
 
     def check_existence_csv_file(self):
+        """
+        Check for the existence of the given csv file, because only if the file exists, data can be read from the file.
+        """
         if os.path.exists(self.csv_file):
             return True
@@ -27,37 +51,64 @@ class CSVImporter:
         return False
 
     def parse_csv_file(self):
-        try:
-            with open(self.csv_file) as csv_file:
-                reader = csv.reader(csv_file, delimiter=",")
+        """
+        Parse the content of the csv file.
+        """
+        # Use a try in case of invalid permissions or a broken file.
+        try:
+            # Open the file.
+            with open(self.csv_file) as csv_file:
+                # Read the content of the file with the given delimiter.
+                reader = csv.reader(csv_file, delimiter=self.delimiter)
+                # Add every row to a data list.
                 for row in reader:
                     self.csv_data.append(row)
+            # Return the success.
             return True
         except Exception as file_error:
+            # Return the error.
             return file_error
 
     def assume_data_types(self):
+        """
+        Assume the data types of the rows in the csv file based on the given values. Check the first 100 rows, so the
+        overhead is small, but the check data is large enough to get a correct assumption. The supported data types for
+        assuming are NULL, INT, REAL and TEXT.
+        """
+        # If the data is larger than 100 rows, define the check limit for the first 100 rows.
         if len(self.csv_data)-2 > 100:
             check_limit = 100
+        # Define the limit based on the file length.
         else:
             check_limit = len(self.csv_data) - 2
 
+        # Create a list for the data types.
         self.data_types = [None] * len(self.csv_data[0])
 
+        # Check every row within the check limit.
         for check_row in range(1, check_limit):
+            # Get the row.
             current_row = self.csv_data[check_row]
 
+            # Check the data type of the current column.
             for check_column in range(len(current_row)):
+                # If the data type is TEXT, break, because there is nothing to change. This data type works in every
+                # case.
                 if self.data_types[check_column] == "TEXT":
                     break
 
+                # Get the current value.
                 value = current_row[check_column]
+                # Get the data type of the current value.
                 data_type = self.get_data_type(value)
 
+                # If the data type is not null, write the data type in the data type list.
                 if data_type != "NULL":
                     self.data_types[check_column] = data_type
@@ -66,19 +117,15 @@ class CSVImporter:
             return "NULL"
 
         try:
+            if float(value).is_integer():
+                return "INT"
             float(value)
             return "REAL"
         except ValueError:
             pass
 
-        try:
-            int(value)
-            return "INT"
-        except ValueError:
-            pass
-
         return "TEXT"
 
     def create_table_for_csv_data(self):
@@ -87,8 +134,15 @@
         create_statement = self.get_create_statement()
 
+        with self.database_connection.cursor() as database_cursor:
+            database_cursor.execute(sql.SQL(create_statement))
+
+            if database_cursor.description:
+                print(database_cursor.description)
+
     def get_create_statement(self):
-        create_table_query = "CREATE TABLE %s ("
+        self.get_table_name()
+        create_table_query = "CREATE TABLE {} (".format(self.table_name)
 
         header = self.csv_data[0]
@@ -99,31 +153,40 @@ class CSVImporter:
             else:
                 comma_value = ""
 
-            create_column = '"%s" %s{}\n'.format(comma_value)
+            create_column = "{} {}{}\n".format(header[column_count], self.data_types[column_count], comma_value)
             create_table_query = "{}{}".format(create_table_query, create_column)
 
         create_table_query = "{});".format(create_table_query)
-        print(create_table_query)
 
+        # TODO: Mechanism for the user to check the create table statement
         return create_table_query
 
-    def get_create_parameter(self):
-        parameter_list = []
-
+    def get_table_name(self):
         if self.table_name is None:
             slash_split_list = self.csv_file.split("/")
             self.table_name = slash_split_list[len(slash_split_list) - 1]
             csv_split_list = self.table_name.split(".csv")
             self.table_name = csv_split_list[0]
 
-        parameter_list.append(self.table_name)
-
-        header = self.csv_data[0]
-
-        for column_count in range(len(header)):
-            parameter_list.append(header[column_count])
-            parameter_list.append(self.data_types[column_count])
-
-        return parameter_list
+        self.table_name = self.check_ddl_parameter(self.table_name)
+
+    @staticmethod
+    def check_ddl_parameter(parameter):
+        parameter = re.sub(r"[^a-zA-Z0-9 _\.]", "", parameter)
+
+        if " " in parameter:
+            parameter = '"{}"'.format(parameter)
+
+        return parameter
+
+    def print_error(self, error):
+        print(error)
+
+    def print_result(self, result):
+        print(result)
 
     def do_all_the_stuff(self):
         """
@@ -140,10 +203,9 @@ class CSVImporter:
         # create table or insert in existing table, (create table and) insert
 
 
 if __name__ == "__main__":
-    csv_importer = CSVImporter(None, "/home/lal45210/test.csv")
-    csv_importer.check_existence_csv_file()
-    csv_importer.parse_csv_file()
-    csv_importer.assume_data_types()
-    print(csv_importer.get_create_parameter())
+    csv_importer = CSVImporter(global_connection_factory.get_database_connection("localhost", "testuser", "testdb"),
+                               "/home/sqlea/test.csv", delimiter=";", table_name="new_test_table")
+
+    if csv_importer.check_existence_csv_file() is True:
+        csv_importer.parse_csv_file()
+        csv_importer.assume_data_types()
+        csv_importer.get_create_statement()
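
Note on the type assumption change: get_data_type() now classifies whole numbers as INT via float(value).is_integer() inside the single float try block, instead of a separate int() check that was never reached once float() succeeded. A minimal, self-contained sketch of the resulting logic (illustrative only, not part of pygadmin; the function name and sample values are made up):

# Sketch of the assumption logic after this commit: empty values and the
# configured null marker map to NULL, whole numbers to INT, other parsable
# numbers to REAL, everything else to TEXT.
def assume_sql_type(value, null_type="NULL"):
    if value == null_type or value == "":
        return "NULL"
    try:
        # float() accepts both "42" and "3.14"; is_integer() separates the two cases.
        if float(value).is_integer():
            return "INT"
        return "REAL"
    except ValueError:
        # Anything that cannot be parsed as a number falls back to TEXT.
        return "TEXT"


if __name__ == "__main__":
    for sample in ["NULL", "", "42", "3.14", "hello"]:
        print(repr(sample), "->", assume_sql_type(sample))
    # 'NULL' -> NULL, '' -> NULL, '42' -> INT, '3.14' -> REAL, 'hello' -> TEXT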
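
Note on the create statement assembly: get_table_name() derives the table name from the file name when none is given, and check_ddl_parameter() strips unexpected characters and quotes identifiers that contain spaces before they are formatted into the CREATE TABLE query. A small standalone sketch of that assembly under the same assumptions (the builder function, column names and types are made up for illustration, not the class methods themselves):

import re


def sanitize_identifier(identifier):
    # Keep only letters, digits, spaces, underscores and dots, mirroring check_ddl_parameter().
    identifier = re.sub(r"[^a-zA-Z0-9 _\.]", "", identifier)
    # Quote identifiers containing a space so PostgreSQL accepts them.
    if " " in identifier:
        identifier = '"{}"'.format(identifier)
    return identifier


def build_create_statement(table_name, header, data_types):
    # Pair every sanitized column name with its assumed data type.
    columns = ", ".join(
        "{} {}".format(sanitize_identifier(name), data_type)
        for name, data_type in zip(header, data_types)
    )
    return "CREATE TABLE {} ({});".format(sanitize_identifier(table_name), columns)


if __name__ == "__main__":
    print(build_create_statement("test_table", ["id", "first name", "price"], ["INT", "TEXT", "REAL"]))
    # CREATE TABLE test_table (id INT, "first name" TEXT, price REAL);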