Refactoring and documentation of the csv importer

Current state: CREATE TABLE with a check of the given columns for
malicious content
This commit is contained in:
Lea Laux 2021-02-16 10:06:04 +01:00 committed by KDV Admin
parent 7e37ac06b7
commit ec53065798

View File

@ -14,28 +14,23 @@ class CSVImporter:
assumed data types and inserting all the data. assumed data types and inserting all the data.
""" """
def __init__(self, database_connection, csv_file, delimiter=",", null_type="NULL", create_table=True, def __init__(self, database_connection, csv_file, delimiter=",", null_type="NULL", table_name=None):
table_name=None):
# Use the given database connection for further execution of database queries on the given database.
self.database_connection = database_connection
# Use the csv file for loading the relevant data. # Use the csv file for loading the relevant data.
self.csv_file = csv_file self.csv_file = csv_file
# Use a delimiter for the csv file, not necessarily ",". # Use a delimiter for the csv file, not necessarily ",".
self.delimiter = delimiter self.delimiter = delimiter
# Define the null type of the csv file. # Define the null type of the csv file.
self.null_type = null_type self.null_type = null_type
# Use a new table for saving the data and not only appending to an existing one.
self.create_table = create_table
# Get the name of the table. # Get the name of the table.
self.table_name = table_name self.table_name = table_name
# Save the csv data in a list. # Save the csv data in a list.
self.csv_data = [] self.csv_data = []
# Save the data types in a list. # Save the data types in a list.
self.data_types = [] self.data_types = []
# TODO: Create functions for connecting the result signal and the error signal
# Use the database query executor for executing the create table and insert queries. # Use the database query executor for executing the create table and insert queries.
self.database_query_executor = DatabaseQueryExecutor() self.database_query_executor = DatabaseQueryExecutor()
# Use the given database connection for further execution of database queries on the given database.
self.database_query_executor.database_connection = database_connection
self.database_query_executor.error.connect(self.print_error) self.database_query_executor.error.connect(self.print_error)
self.database_query_executor.result_data.connect(self.print_result) self.database_query_executor.result_data.connect(self.print_result)
@ -113,81 +108,120 @@ class CSVImporter:
self.data_types[check_column] = data_type self.data_types[check_column] = data_type
def get_data_type(self, value): def get_data_type(self, value):
"""
Get the data type of a specific value in a readable format as Postgres data type. Every value can be a text in
the end.
"""
# If the value is the predefined null value/type, return a NULL value. Every other data type is still possible.
if value == self.null_type: if value == self.null_type:
return "NULL" return "NULL"
# Try to cast the value to a float or an integer to check for a number.
try: try:
# Try to float the value and check for an integer.
if float(value).is_integer(): if float(value).is_integer():
return "INT" return "INT"
float(value) # Only a float/REAL is a possible option now.
return "REAL" return "REAL"
# Ignore the value error, because it is not relevant in this case and allowed to go silently.
except ValueError: except ValueError:
pass pass
# Return TEXT, if a match could not be made.
return "TEXT" return "TEXT"
def create_table_for_csv_data(self): def create_table_for_csv_data(self):
if self.create_table is not True: """
return Create the table to store the csv data in the database.
"""
# Get the create statement of the table.
create_statement = self.get_create_statement() create_statement = self.get_create_statement()
with self.database_connection.cursor() as database_cursor: # Assign the create statement as query to the table.
database_cursor.execute(sql.SQL(create_statement)) self.database_query_executor.database_query = create_statement
# Execute!
self.database_query_executor.submit_and_execute_query()
if database_cursor.description: def get_create_statement(self, check_ddl=True):
print(database_cursor.description) """
Build the CREATE statement for the table for inserting the csv data. The option "check_ddl" secures the column
name against possible dangerous names. The default is True, because such checks after "blind" imports are a
security feature.
"""
def get_create_statement(self): # Get the table name, so the table name can be used in the create statement.
self.get_table_name() self.get_table_name()
# Add the table name to the query.
create_table_query = "CREATE TABLE {} (".format(self.table_name) create_table_query = "CREATE TABLE {} (".format(self.table_name)
# Get the header as start of the csv data, because the columns are defined here.
header = self.csv_data[0] header = self.csv_data[0]
# Iterate over the header. The column count is necessary to set the comma value at the correct point, because
# the column to create does not need a comma at the end.
for column_count in range(len(header)): for column_count in range(len(header)):
# Define the comma to set after the definition of the column.
if column_count != len(header)-1: if column_count != len(header)-1:
comma_value = "," comma_value = ","
else: else:
comma_value = "" comma_value = ""
create_column = "{} {}{}\n".format(header[column_count], self.data_types[column_count], comma_value) # Get the current column.
current_column = header[column_count]
# If the name of the column should be checked, check it.
if check_ddl:
current_column = self.check_ddl_parameter(current_column)
# Define the current column with its name, its data type and its comma value.
create_column = "{} {}{}\n".format(current_column, self.data_types[column_count], comma_value)
# Build the table query by adding the new column.
create_table_query = "{}{}".format(create_table_query, create_column) create_table_query = "{}{}".format(create_table_query, create_column)
# End the query, so the query is in an executable format.
create_table_query = "{});".format(create_table_query) create_table_query = "{});".format(create_table_query)
print(create_table_query)
# TODO: Mechanism for the user to check the create table statement
return create_table_query return create_table_query
def get_table_name(self): def get_table_name(self):
"""
Get the name of the table based on the name of the csv file, if the name is not specified by the user.
"""
if self.table_name is None: if self.table_name is None:
# Split the csv file in the different parts of the path.
slash_split_list = self.csv_file.split("/") slash_split_list = self.csv_file.split("/")
# Get the last part of the list as file name, because the last part of the path is the file identifier.
self.table_name = slash_split_list[len(slash_split_list) - 1] self.table_name = slash_split_list[len(slash_split_list) - 1]
# Get the .csv out of the name by splitting.
csv_split_list = self.table_name.split(".csv") csv_split_list = self.table_name.split(".csv")
# Use the part without csv as table name.
self.table_name = csv_split_list[0] self.table_name = csv_split_list[0]
# Check the name of the table.
self.table_name = self.check_ddl_parameter(self.table_name) self.table_name = self.check_ddl_parameter(self.table_name)
@staticmethod @staticmethod
def check_ddl_parameter(parameter): def check_ddl_parameter(parameter):
parameter = re.sub(r"[^a-zA-Z0-9 _\.]", "", parameter) """
Check the given data definition language parameter for potentially malicious characters. Those malicious
characters could cause an SQL injection, so nearly all special characters are kicked out.
"""
# Define a matching regular expression for allowed characters. Kick every other character. Allowed special
# characters are _, whitespace, . and some special (german) characters, because they can not do any harm.
parameter = re.sub(r"[^a-zA-ZäöüÄÖÜß0-9 _\.]", "", parameter)
# If a whitespace is part of the parameter, use " " around the parameter to make a valid column ddl parameter.
if " " in parameter: if " " in parameter:
parameter = '"{}"'.format(parameter) parameter = '"{}"'.format(parameter)
return parameter return parameter
def print_error(self, error):
print(error)
def print_result(self, result):
print(result)
def do_all_the_stuff(self): def do_all_the_stuff(self):
""" """
Normal persons would call this function "main". This function is only a placeholder to remember me, that I'm Normal persons would call this function "main". This function is only a placeholder to remember me, that I'm
@ -209,3 +243,4 @@ if __name__ == "__main__":
csv_importer.parse_csv_file() csv_importer.parse_csv_file()
csv_importer.assume_data_types() csv_importer.assume_data_types()
csv_importer.get_create_statement() csv_importer.get_create_statement()