diff --git a/pygadmin/csv_importer.py b/pygadmin/csv_importer.py index aebff16..b70041d 100644 --- a/pygadmin/csv_importer.py +++ b/pygadmin/csv_importer.py @@ -14,28 +14,23 @@ class CSVImporter: assumed data types and inserting all the data. """ - def __init__(self, database_connection, csv_file, delimiter=",", null_type="NULL", create_table=True, - table_name=None): - - # Use the given database connection for further execution of database queries on the given database. - self.database_connection = database_connection + def __init__(self, database_connection, csv_file, delimiter=",", null_type="NULL", table_name=None): # Use the csv file for loading the relevant data. self.csv_file = csv_file # Use a delimiter for the csv file, not necessarily ",". self.delimiter = delimiter # Define the null type of the csv file. self.null_type = null_type - # Use a new table for saving the data and not only appending to an existing one. - self.create_table = create_table # Get the name of the table. self.table_name = table_name # Save the csv data in a list. self.csv_data = [] # Save the data types in a list. self.data_types = [] - # TODO: Create functions for connecting the result signal and the error signal # Use the database query executor for executing the create table and insert queries. self.database_query_executor = DatabaseQueryExecutor() + # Use the given database connection for further execution of database queries on the given database. + self.database_query_executor.database_connection = database_connection self.database_query_executor.error.connect(self.print_error) self.database_query_executor.result_data.connect(self.print_result) @@ -113,81 +108,120 @@ class CSVImporter: self.data_types[check_column] = data_type def get_data_type(self, value): + """ + Get the data type of a specific value in a readable format as Postgres data type. Every value can be a text in + the end. + """ + + # If the value is the predefined null value/type, return a NULL value. Every other data type is still possible. if value == self.null_type: return "NULL" + # Try to cast the value to a float or an integer to check for a number. try: + # Try to float the value and check for an integer. if float(value).is_integer(): return "INT" - float(value) + # Only a float/REAL is a possible option now. return "REAL" + # Ignore the value error, because it is not relevant in this case and allowed to go silently. except ValueError: pass + # Return TEXT, if a match could not be made. return "TEXT" def create_table_for_csv_data(self): - if self.create_table is not True: - return + """ + Create the table to store the csv data in the database. + """ + # Get the create statement of the table. create_statement = self.get_create_statement() - with self.database_connection.cursor() as database_cursor: - database_cursor.execute(sql.SQL(create_statement)) + # Assign the create statement as query to the table. + self.database_query_executor.database_query = create_statement + # Execute! + self.database_query_executor.submit_and_execute_query() - if database_cursor.description: - print(database_cursor.description) + def get_create_statement(self, check_ddl=True): + """ + Build the CREATE statement for the table for inserting the csv data. The option "check_ddl" secures the column + name against possible dangerous names. The default is True, because such checks after "blind" imports are a + security feature. + """ - def get_create_statement(self): + # Get the table name, so the table name can be used in the create statement. self.get_table_name() + # Add the table name to the query. create_table_query = "CREATE TABLE {} (".format(self.table_name) + # Get the header as start of the csv data, because the columns are defined here. header = self.csv_data[0] + # Iterate over the header. The column count is necessary to set the comma value at the correct point, because + # the column to create does not need a comma at the end. for column_count in range(len(header)): + # Define the comma to set after the definition of the column. if column_count != len(header)-1: comma_value = "," else: comma_value = "" - create_column = "{} {}{}\n".format(header[column_count], self.data_types[column_count], comma_value) + # Get the current column. + current_column = header[column_count] + + # If the name of the column should be checked, check it. + if check_ddl: + current_column = self.check_ddl_parameter(current_column) + + # Define the current column with its name, its data type and its comma value. + create_column = "{} {}{}\n".format(current_column, self.data_types[column_count], comma_value) + # Build the table query by adding the new column. create_table_query = "{}{}".format(create_table_query, create_column) + # End the query, so the query is in an executable format. create_table_query = "{});".format(create_table_query) - print(create_table_query) - - # TODO: Mechanism for the user to check the create table statement - return create_table_query def get_table_name(self): + """ + Get the name of the table based on the name of the csv file, if the name is not specified by the user. + """ + if self.table_name is None: + # Split the csv file in the different parts of the path. slash_split_list = self.csv_file.split("/") + # Get the last part of the list as file name, because the last part of the path is the file identifier. self.table_name = slash_split_list[len(slash_split_list) - 1] + # Get the .csv out of the name by splitting. csv_split_list = self.table_name.split(".csv") + # Use the part without csv as table name. self.table_name = csv_split_list[0] + # Check the name of the table. self.table_name = self.check_ddl_parameter(self.table_name) @staticmethod def check_ddl_parameter(parameter): - parameter = re.sub(r"[^a-zA-Z0-9 _\.]", "", parameter) + """ + Check the given data definition language parameter for potentially malicious characters. Those malicious + characters could cause an SQL injection, so nearly all special characters are kicked out. + """ + # Define a matching regular expression for allowed characters. Kick every other character. Allowed special + # characters are _, whitespace, . and some special (german) characters, because they can not do any harm. + parameter = re.sub(r"[^a-zA-ZäöüÄÖÜß0-9 _\.]", "", parameter) + + # If a whitespace is part of the parameter, use " " around the parameter to make a valid column ddl parameter. if " " in parameter: parameter = '"{}"'.format(parameter) return parameter - def print_error(self, error): - print(error) - - def print_result(self, result): - print(result) - - def do_all_the_stuff(self): """ Normal persons would call this function "main". This function is only a placeholder to remember me, that I'm @@ -209,3 +243,4 @@ if __name__ == "__main__": csv_importer.parse_csv_file() csv_importer.assume_data_types() csv_importer.get_create_statement() +