From f506bcbafe67addff696861a88d87846cc18c962 Mon Sep 17 00:00:00 2001 From: Lea Laux Date: Tue, 16 Feb 2021 14:15:51 +0100 Subject: [PATCH] Insert function for inserting the data --- pygadmin/csv_importer.py | 108 ++++++++++++++++++++++++++------- pygadmin/widgets/csv_import.py | 83 +++++++++++++++++++++++++ 2 files changed, 169 insertions(+), 22 deletions(-) create mode 100644 pygadmin/widgets/csv_import.py diff --git a/pygadmin/csv_importer.py b/pygadmin/csv_importer.py index b70041d..db2f634 100644 --- a/pygadmin/csv_importer.py +++ b/pygadmin/csv_importer.py @@ -1,9 +1,8 @@ +import copy import os import csv import re -from psycopg2 import sql - from pygadmin.database_query_executor import DatabaseQueryExecutor from pygadmin.connectionfactory import global_connection_factory @@ -31,8 +30,6 @@ class CSVImporter: self.database_query_executor = DatabaseQueryExecutor() # Use the given database connection for further execution of database queries on the given database. self.database_query_executor.database_connection = database_connection - self.database_query_executor.error.connect(self.print_error) - self.database_query_executor.result_data.connect(self.print_result) def check_existence_csv_file(self): """ @@ -75,13 +72,7 @@ class CSVImporter: assuming are NULL, INT, DECIMAL and TEXT. """ - # If the data is larger than 100 rows, define the check limit for the first 100 rows. - if len(self.csv_data)-2 > 100: - check_limit = 100 - - # Define the limit based on the file length. - else: - check_limit = len(self.csv_data) - 2 + check_limit = len(self.csv_data) - 2 # Create a list for the data types. self.data_types = [None] * len(self.csv_data[0]) @@ -95,17 +86,16 @@ class CSVImporter: for check_column in range(len(current_row)): # If the data type is TEXT, break, because there is nothing to change. This data type works in every # case. - if self.data_types[check_column] == "TEXT": - break + if self.data_types[check_column] != "TEXT": + # Get the current value. + value = current_row[check_column] + # Get the data type of the current value. + data_type = self.get_data_type(value) - # Get the current value. - value = current_row[check_column] - # Get the data type of the current value. - data_type = self.get_data_type(value) - - # If the data type is not null, write the data type in the data type list. - if data_type != "NULL": - self.data_types[check_column] = data_type + # If the data type is not null, write the data type in the data type list. # TODO: Debug this data + # type shit + if data_type != "NULL" or (self.data_types[check_column] != "REAL" and data_type == "INT"): + self.data_types[check_column] = data_type def get_data_type(self, value): """ @@ -177,6 +167,8 @@ class CSVImporter: # If the name of the column should be checked, check it. if check_ddl: current_column = self.check_ddl_parameter(current_column) + # Write the corrected name back in the csv data. + self.csv_data[0][column_count] = current_column # Define the current column with its name, its data type and its comma value. create_column = "{} {}{}\n".format(current_column, self.data_types[column_count], comma_value) @@ -205,6 +197,77 @@ class CSVImporter: # Check the name of the table. self.table_name = self.check_ddl_parameter(self.table_name) + def create_insert_queries(self): + # TODO: docu + work_data_list = copy.copy(self.csv_data) + del work_data_list[0] + + chunk_size = 5000 + + work_data_list = [work_data_list[i * chunk_size:(i+1) * chunk_size] + for i in range((len(work_data_list) + chunk_size - 1) // chunk_size)] + + for sub_data_list in work_data_list: + insert_query = self.create_insert_query_begin() + parameter_list = [] + + for row_count in range(len(sub_data_list)): + value_query = "(" + row = sub_data_list[row_count] + for value_count in range(len(row)): + if value_count != len(row)-1: + comma_value = ", " + + else: + comma_value = "" + + value_query = "{}%s{}".format(value_query, comma_value) + value = row[value_count] + + if value == self.null_type: + value = None + + parameter_list.append(value) + + if row_count != len(sub_data_list)-1: + comma_value = ", " + + else: + comma_value = ";" + + value_query = "{}){}".format(value_query, comma_value) + + insert_query = "{}{}".format(insert_query, value_query) + + self.execute_insert_query(insert_query, parameter_list) + + def execute_insert_query(self, insert_query, insert_parameters): + # TODO: docu + self.database_query_executor.database_query = insert_query + self.database_query_executor.database_query_parameter = insert_parameters + self.database_query_executor.submit_and_execute_query() + + def create_insert_query_begin(self): + # TODO: docu + insert_query = "INSERT INTO {} (".format(self.table_name) + + header = self.csv_data[0] + + for column_count in range(len(header)): + # Define the comma to set after the definition of the column. + if column_count != len(header)-1: + comma_value = ", " + + else: + comma_value = "" + + insert_column = "{}{}".format(header[column_count], comma_value) + insert_query = "{}{}".format(insert_query, insert_column) + + insert_query = "{}) VALUES ".format(insert_query) + + return insert_query + @staticmethod def check_ddl_parameter(parameter): """ @@ -238,9 +301,10 @@ class CSVImporter: if __name__ == "__main__": csv_importer = CSVImporter(global_connection_factory.get_database_connection("localhost", "testuser", "testdb"), - "/home/sqlea/test.csv", delimiter=";", table_name="new_test_table") + "/home/sqlea/fl.csv", delimiter=";", table_name="new_test_table") if csv_importer.check_existence_csv_file() is True: csv_importer.parse_csv_file() csv_importer.assume_data_types() csv_importer.get_create_statement() + csv_importer.create_insert_queries() diff --git a/pygadmin/widgets/csv_import.py b/pygadmin/widgets/csv_import.py new file mode 100644 index 0000000..36c3892 --- /dev/null +++ b/pygadmin/widgets/csv_import.py @@ -0,0 +1,83 @@ +import sys +import time + +from PyQt5.QtWidgets import QDialog, QGridLayout, QLabel, QApplication, QPushButton, QMessageBox + +from pygadmin.connectionfactory import global_connection_factory +from pygadmin.csv_importer import CSVImporter +from pygadmin.widgets.widget_icon_adder import IconAdder + + +class CSVImportDialog(QDialog): + # TODO: docu + def __init__(self, database_connection, csv_file, delimiter): + super().__init__() + # Add the pygadmin icon as window icon. + icon_adder = IconAdder() + icon_adder.add_icon_to_widget(self) + + self.csv_importer = CSVImporter(database_connection, csv_file, delimiter, table_name="new_test_table", + null_type="") + + if self.csv_importer.check_existence_csv_file(): + self.init_ui() + self.init_grid() + self.csv_importer.database_query_executor.result_data.connect(self.show_success) + self.csv_importer.database_query_executor.error.connect(self.show_error) + + else: + self.init_error_ui(csv_file) + + def init_ui(self): + self.csv_importer.parse_csv_file() + self.csv_importer.assume_data_types() + self.create_statement_label = QLabel(self.csv_importer.get_create_statement()) + self.insert_button = QPushButton("Insert") + self.create_button = QPushButton("Create") + self.insert_button.clicked.connect(self.insert_data) + self.create_button.clicked.connect(self.create_table) + self.show() + + def init_grid(self): + grid_layout = QGridLayout(self) + grid_layout.addWidget(self.create_statement_label, 0, 0) + grid_layout.addWidget(self.insert_button, 0, 1) + grid_layout.addWidget(self.create_button, 0, 2) + grid_layout.setSpacing(10) + self.setLayout(grid_layout) + + def init_error_ui(self, csv_file): + # Get the layout as grid layout. + grid_layout = QGridLayout(self) + # Add a label with an error. + grid_layout.addWidget(QLabel("The given csv file {} is invalid".format(csv_file)), 0, 0) + self.setLayout(grid_layout) + self.setMaximumSize(10, 100) + self.showMaximized() + # Set the title to an error title. + self.setWindowTitle("File Error") + self.show() + + def insert_data(self): + begin = time.time() + self.csv_importer.create_insert_queries() + end = time.time() + print("Runtime: {}".format(end-begin)) + + def create_table(self): + self.csv_importer.create_table_for_csv_data() + + def show_success(self, result): + QMessageBox.information(self, "Success", "The result is {}".format(result)) + + def show_error(self, error_message): + QMessageBox.critical(self, "Error", "{}".format(error_message)) + print(error_message) + + +if __name__ == "__main__": + app = QApplication(sys.argv) + csv_import = CSVImportDialog(global_connection_factory.get_database_connection("localhost", "testuser", "testdb"), + "/home/sqlea/fl.csv", ",") + sys.exit(app.exec()) +