From 319895db642bb120209b73572981ca364cd6e527 Mon Sep 17 00:00:00 2001 From: blue Date: Fri, 8 Dec 2023 19:26:16 -0300 Subject: [PATCH] some primitive database stuff --- database/CMakeLists.txt | 1 + database/dbinterface.h | 5 +++ database/migrations/CMakeLists.txt | 1 + database/migrations/m0.sql | 6 +++ database/mysql/CMakeLists.txt | 2 + database/mysql/mysql.cpp | 61 ++++++++++++++++++++++++++++++ database/mysql/mysql.h | 5 +++ database/mysql/statement.cpp | 47 +++++++++++++++++++++++ database/mysql/statement.h | 24 ++++++++++++ server/server.cpp | 14 +++++++ 10 files changed, 166 insertions(+) create mode 100644 database/migrations/CMakeLists.txt create mode 100644 database/migrations/m0.sql create mode 100644 database/mysql/statement.cpp create mode 100644 database/mysql/statement.h diff --git a/database/CMakeLists.txt b/database/CMakeLists.txt index cd024ca..8f5b864 100644 --- a/database/CMakeLists.txt +++ b/database/CMakeLists.txt @@ -9,3 +9,4 @@ set(SOURCES target_sources(pica PRIVATE ${SOURCES}) add_subdirectory(mysql) +add_subdirectory(migrations) diff --git a/database/dbinterface.h b/database/dbinterface.h index 80334f7..25f035e 100644 --- a/database/dbinterface.h +++ b/database/dbinterface.h @@ -3,6 +3,7 @@ #include #include #include +#include class DBInterface { public: @@ -28,6 +29,10 @@ public: virtual void setDatabase(const std::string& newDatabase) = 0; virtual void setCredentials(const std::string& login, const std::string& password) = 0; + virtual void executeFile(const std::string& path) = 0; + virtual uint8_t getVersion() = 0; + virtual void setVersion(uint8_t version) = 0; + protected: DBInterface(Type type); diff --git a/database/migrations/CMakeLists.txt b/database/migrations/CMakeLists.txt new file mode 100644 index 0000000..50e2762 --- /dev/null +++ b/database/migrations/CMakeLists.txt @@ -0,0 +1 @@ +configure_file(m0.sql m0.sql COPYONLY) diff --git a/database/migrations/m0.sql b/database/migrations/m0.sql new file mode 100644 index 0000000..951fa04 --- /dev/null +++ b/database/migrations/m0.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS system ( + `key` VARCHAR(32) PRIMARY KEY, + `value` TEXT +); + +INSERT INTO system (`key`, `value`) VALUES ('version', '0'); diff --git a/database/mysql/CMakeLists.txt b/database/mysql/CMakeLists.txt index c64280d..75dd5e1 100644 --- a/database/mysql/CMakeLists.txt +++ b/database/mysql/CMakeLists.txt @@ -1,9 +1,11 @@ set(HEADERS mysql.h + statement.h ) set(SOURCES mysql.cpp + statement.cpp ) find_package(MariaDB REQUIRED) diff --git a/database/mysql/mysql.cpp b/database/mysql/mysql.cpp index c9da2a3..c5616bc 100644 --- a/database/mysql/mysql.cpp +++ b/database/mysql/mysql.cpp @@ -1,5 +1,19 @@ #include "mysql.h" +#include + +#include "mysqld_error.h" + +#include "statement.h" + +constexpr const char* updateQuery = "UPDATE system SET `value` = ? WHERE `key` = 'version'"; + +struct ResDeleter { + void operator () (MYSQL_RES* res) { + mysql_free_result(res); + } +}; + MySQL::MySQL(): DBInterface(Type::mysql), connection(), @@ -83,3 +97,50 @@ void MySQL::disconnect() { mysql_close(con); mysql_init(con); //this is ridiculous! } + +void MySQL::executeFile(const std::string& path) { + MYSQL* con = &connection; + std::ifstream inputFile(path); + std::string query; + while (std::getline(inputFile, query, ';')) { + int result = mysql_query(con, query.c_str()); + if (result != 0) { + int errcode = mysql_errno(con); + if (errcode == ER_EMPTY_QUERY) + continue; + + throw std::runtime_error("Error executing file " + path + ": " + mysql_error(con)); + } + + } +} + +uint8_t MySQL::getVersion() { + MYSQL* con = &connection; + int result = mysql_query(con, "SELECT value FROM system WHERE `key` = 'version'"); + + if (result != 0) { + unsigned int errcode = mysql_errno(con); + if (errcode == ER_NO_SUCH_TABLE) + return 0; + + throw std::runtime_error(std::string("Error executing retreiving version: ") + mysql_error(con)); + } + + std::unique_ptr res(mysql_store_result(con)); + if (!res) + throw std::runtime_error(std::string("Querying version returned no result: ") + mysql_error(con)); + + MYSQL_ROW row = mysql_fetch_row(res.get()); + if (row) + return std::stoi(row[0]); + else + return 0; +} + +void MySQL::setVersion(uint8_t version) { + std::string strVersion = std::to_string(version); + Statement statement(&connection, updateQuery); + statement.bind(strVersion.data(), MYSQL_TYPE_VAR_STRING); + statement.execute(); +} diff --git a/database/mysql/mysql.h b/database/mysql/mysql.h index 35b8243..67bfaaa 100644 --- a/database/mysql/mysql.h +++ b/database/mysql/mysql.h @@ -7,6 +7,7 @@ #include "database/dbinterface.h" class MySQL : public DBInterface { + class Statement; public: MySQL(); ~MySQL() override; @@ -16,6 +17,10 @@ public: void setCredentials(const std::string& login, const std::string& password) override; void setDatabase(const std::string& database) override; + void executeFile(const std::string& path) override; + uint8_t getVersion() override; + void setVersion(uint8_t version) override; + protected: MYSQL connection; std::string login; diff --git a/database/mysql/statement.cpp b/database/mysql/statement.cpp new file mode 100644 index 0000000..920e27b --- /dev/null +++ b/database/mysql/statement.cpp @@ -0,0 +1,47 @@ +#include "statement.h" + +#include + +static uint64_t TIME_LENGTH = sizeof(MYSQL_TIME); + +MySQL::Statement::Statement(MYSQL* connection, const char* statement): + stmt(mysql_stmt_init(connection)), + param(), + lengths() +{ + int result = mysql_stmt_prepare(stmt.get(), statement, strlen(statement)); + if (result != 0) + throw std::runtime_error(std::string("Error preparing statement: ") + mysql_stmt_error(stmt.get())); +} + +void MySQL::Statement::bind(void* value, enum_field_types type) { + MYSQL_BIND& result = param.emplace_back(); + std::memset(&result, 0, sizeof(result)); + + result.buffer_type = type; + result.buffer = value; + + switch (type) { + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_VAR_STRING: + result.length = &lengths.emplace_back(strlen(static_cast(value))); + break; + case MYSQL_TYPE_DATE: + result.length = &TIME_LENGTH; + break; + default: + lengths.pop_back(); + throw std::runtime_error("Type: " + std::to_string(type) + " is not yet supported in bind"); + break; + } +} + +void MySQL::Statement::execute() { + int result = mysql_stmt_bind_param(stmt.get(), param.data()); + if (result != 0) + throw std::runtime_error(std::string("Error binding statement: ") + mysql_stmt_error(stmt.get())); + + result = mysql_stmt_execute(stmt.get()); + if (result != 0) + throw std::runtime_error(std::string("Error executing statement: ") + mysql_stmt_error(stmt.get())); +} diff --git a/database/mysql/statement.h b/database/mysql/statement.h new file mode 100644 index 0000000..4c44c0e --- /dev/null +++ b/database/mysql/statement.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "mysql.h" + + +class MySQL::Statement { + struct STMTDeleter { + void operator () (MYSQL_STMT* stmt) { + mysql_stmt_close(stmt); + }; + }; +public: + Statement(MYSQL* connection, const char* statement); + + void bind(void* value, enum_field_types type); + void execute(); + +private: + std::unique_ptr stmt; + std::vector param; + std::vector lengths; +}; diff --git a/server/server.cpp b/server/server.cpp index e2f13d0..baff091 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -15,13 +15,27 @@ Server::Server(): db->setCredentials("pica", "pica"); db->setDatabase("pica"); + bool connected = false; try { db->connect("/run/mysqld/mysqld.sock"); + connected = true; std::cout << "Successfully connected to the database" << std::endl; + } catch (const std::runtime_error& e) { std::cerr << "Couldn't connect to the database: " << e.what() << std::endl; } + if (connected) { + uint8_t version = db->getVersion(); + std::cout << "Database version is " << std::to_string(version) << std::endl; + if (version == 0) { + db->executeFile("database/migrations/m0.sql"); + std::cout << "Successfully migrated to version 1" << std::endl; + db->setVersion(1); + std::cout << "Database version is " << std::to_string(db->getVersion()) << " now" << std::endl; + } + } + router.addRoute("info", Server::info); router.addRoute("env", Server::printEnvironment); }