diff --git a/CMakeLists.txt b/CMakeLists.txt index dba39c4..2eac80e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ message("Compile options: " ${COMPILE_OPTIONS_STRING}) find_package(nlohmann_json REQUIRED) find_package(FCGI REQUIRED) find_package(Argon2 REQUIRED) +find_package(Threads REQUIRED) add_executable(${PROJECT_NAME} main.cpp) target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) @@ -63,6 +64,7 @@ target_link_libraries(${PROJECT_NAME} PRIVATE FCGI::FCGI++ nlohmann_json::nlohmann_json Argon2::Argon2 + Threads::Threads ) install(TARGETS ${PROJECT_NAME} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/database/CMakeLists.txt b/database/CMakeLists.txt index a0adbc9..d903715 100644 --- a/database/CMakeLists.txt +++ b/database/CMakeLists.txt @@ -1,11 +1,15 @@ set(HEADERS - dbinterface.h + interface.h exceptions.h + pool.h + resource.h ) set(SOURCES - dbinterface.cpp + interface.cpp exceptions.cpp + pool.cpp + resource.cpp ) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) diff --git a/database/exceptions.cpp b/database/exceptions.cpp index 5d555e2..d76ff93 100644 --- a/database/exceptions.cpp +++ b/database/exceptions.cpp @@ -3,18 +3,18 @@ #include "exceptions.h" -DBInterface::Duplicate::Duplicate(const std::string& text): +DB::Duplicate::Duplicate(const std::string& text): std::runtime_error(text) {} -DBInterface::DuplicateLogin::DuplicateLogin(const std::string& text): +DB::DuplicateLogin::DuplicateLogin(const std::string& text): Duplicate(text) {} -DBInterface::EmptyResult::EmptyResult(const std::string& text): +DB::EmptyResult::EmptyResult(const std::string& text): std::runtime_error(text) {} -DBInterface::NoLogin::NoLogin(const std::string& text): +DB::NoLogin::NoLogin(const std::string& text): EmptyResult(text) {} diff --git a/database/exceptions.h b/database/exceptions.h index 8d447e2..8b677d8 100644 --- a/database/exceptions.h +++ b/database/exceptions.h @@ -3,24 +3,27 @@ #pragma once -#include "dbinterface.h" +#include +#include -class DBInterface::Duplicate : public std::runtime_error { +namespace DB { +class Duplicate : public std::runtime_error { public: explicit Duplicate(const std::string& text); }; -class DBInterface::DuplicateLogin : public DBInterface::Duplicate { +class DuplicateLogin : public Duplicate { public: explicit DuplicateLogin(const std::string& text); }; -class DBInterface::EmptyResult : public std::runtime_error { +class EmptyResult : public std::runtime_error { public: explicit EmptyResult(const std::string& text); }; -class DBInterface::NoLogin : public DBInterface::EmptyResult { +class NoLogin : public EmptyResult { public: explicit NoLogin(const std::string& text); }; +} diff --git a/database/dbinterface.cpp b/database/interface.cpp similarity index 64% rename from database/dbinterface.cpp rename to database/interface.cpp index de98e39..563296c 100644 --- a/database/dbinterface.cpp +++ b/database/interface.cpp @@ -1,18 +1,18 @@ // SPDX-FileCopyrightText: 2023 Yury Gubich // SPDX-License-Identifier: GPL-3.0-or-later -#include "dbinterface.h" +#include "interface.h" #include "mysql/mysql.h" -DBInterface::DBInterface(Type type): +DB::Interface::Interface(Type type): type(type), state(State::disconnected) {} -DBInterface::~DBInterface() {} +DB::Interface::~Interface() {} -std::unique_ptr DBInterface::create(Type type) { +std::unique_ptr DB::Interface::create(Type type) { switch (type) { case Type::mysql: return std::make_unique(); @@ -21,6 +21,6 @@ std::unique_ptr DBInterface::create(Type type) { throw std::runtime_error("Unexpected database type: " + std::to_string((uint8_t)type)); } -DBInterface::State DBInterface::currentState() const { +DB::Interface::State DB::Interface::currentState() const { return state; } diff --git a/database/dbinterface.h b/database/interface.h similarity index 83% rename from database/dbinterface.h rename to database/interface.h index e2c0ebe..7b6e204 100644 --- a/database/dbinterface.h +++ b/database/interface.h @@ -8,7 +8,8 @@ #include #include -class DBInterface { +namespace DB { +class Interface { public: enum class Type { mysql @@ -18,19 +19,14 @@ public: connecting, connected }; - static std::unique_ptr create(Type type); + static std::unique_ptr create(Type type); - virtual ~DBInterface(); + virtual ~Interface(); State currentState() const; const Type type; - class Duplicate; - class DuplicateLogin; - class EmptyResult; - class NoLogin; - public: virtual void connect(const std::string& path) = 0; virtual void disconnect() = 0; @@ -46,8 +42,9 @@ public: virtual unsigned int createSession(const std::string& login, const std::string& access, const std::string& renew) = 0; protected: - DBInterface(Type type); + Interface(Type type); protected: State state; }; +} diff --git a/database/mysql/mysql.cpp b/database/mysql/mysql.cpp index c101220..b436b03 100644 --- a/database/mysql/mysql.cpp +++ b/database/mysql/mysql.cpp @@ -23,8 +23,8 @@ constexpr const char* createSessionQuery = "INSERT INTO sessions (`owner`, `acce static const std::filesystem::path buildSQLPath = "database"; -MySQL::MySQL(): - DBInterface(Type::mysql), +DB::MySQL::MySQL(): + Interface(Type::mysql), connection(), login(), password(), @@ -33,11 +33,11 @@ MySQL::MySQL(): mysql_init(&connection); } -MySQL::~MySQL() { +DB::MySQL::~MySQL() { mysql_close(&connection); } -void MySQL::connect(const std::string& path) { +void DB::MySQL::connect(const std::string& path) { if (state != State::disconnected) return; @@ -59,7 +59,7 @@ void MySQL::connect(const std::string& path) { state = State::connected; } -void MySQL::setCredentials(const std::string& login, const std::string& password) { +void DB::MySQL::setCredentials(const std::string& login, const std::string& password) { if (MySQL::login == login && MySQL::password == password) return; @@ -81,7 +81,7 @@ void MySQL::setCredentials(const std::string& login, const std::string& password throw std::runtime_error(std::string("Error changing credetials: ") + mysql_error(con)); } -void MySQL::setDatabase(const std::string& database) { +void DB::MySQL::setDatabase(const std::string& database) { if (MySQL::database == database) return; @@ -97,7 +97,7 @@ void MySQL::setDatabase(const std::string& database) { throw std::runtime_error(std::string("Error changing db: ") + mysql_error(con)); } -void MySQL::disconnect() { +void DB::MySQL::disconnect() { if (state == State::disconnected) return; @@ -106,7 +106,7 @@ void MySQL::disconnect() { mysql_init(con); //this is ridiculous! } -void MySQL::executeFile(const std::filesystem::path& relativePath) { +void DB::MySQL::executeFile(const std::filesystem::path& relativePath) { MYSQL* con = &connection; std::filesystem::path path = sharedPath() / relativePath; if (!std::filesystem::exists(path)) @@ -138,7 +138,7 @@ void MySQL::executeFile(const std::filesystem::path& relativePath) { } } -uint8_t MySQL::getVersion() { +uint8_t DB::MySQL::getVersion() { MYSQL* con = &connection; int result = mysql_query(con, versionQuery); @@ -161,14 +161,14 @@ uint8_t MySQL::getVersion() { return 0; } -void MySQL::setVersion(uint8_t version) { +void DB::MySQL::setVersion(uint8_t version) { std::string strVersion = std::to_string(version); Statement statement(&connection, updateQuery); statement.bind(strVersion.data(), MYSQL_TYPE_VAR_STRING); statement.execute(); } -void MySQL::migrate(uint8_t targetVersion) { +void DB::MySQL::migrate(uint8_t targetVersion) { uint8_t currentVersion = getVersion(); while (currentVersion < targetVersion) { @@ -190,7 +190,7 @@ void MySQL::migrate(uint8_t targetVersion) { std::cout << "Database is now on actual version " << std::to_string(targetVersion) << std::endl; } -std::optional MySQL::getComment(std::string& string) { +std::optional DB::MySQL::getComment(std::string& string) { ltrim(string); if (string.length() < 2) return std::nullopt; @@ -218,7 +218,7 @@ std::optional MySQL::getComment(std::string& string) { return std::nullopt; } -unsigned int MySQL::registerAccount(const std::string& login, const std::string& hash) { +unsigned int DB::MySQL::registerAccount(const std::string& login, const std::string& hash) { //TODO validate filed lengths! MYSQL* con = &connection; MySQL::Transaction txn(con); @@ -247,7 +247,7 @@ unsigned int MySQL::registerAccount(const std::string& login, const std::string& return id; } -std::string MySQL::getAccountHash(const std::string& login) { +std::string DB::MySQL::getAccountHash(const std::string& login) { std::string l = login; MYSQL* con = &connection; @@ -265,7 +265,7 @@ std::string MySQL::getAccountHash(const std::string& login) { return std::any_cast(result[0][0]); } -unsigned int MySQL::createSession(const std::string& login, const std::string& access, const std::string& renew) { +unsigned int DB::MySQL::createSession(const std::string& login, const std::string& access, const std::string& renew) { std::string l = login, a = access, r = renew; static std::string testingDevice("Testing..."); @@ -281,7 +281,7 @@ unsigned int MySQL::createSession(const std::string& login, const std::string& a return lastInsertedId(); } -unsigned int MySQL::lastInsertedId() { +unsigned int DB::MySQL::lastInsertedId() { MYSQL* con = &connection; int result = mysql_query(con, lastIdQuery); diff --git a/database/mysql/mysql.h b/database/mysql/mysql.h index 8344ce2..7edc2ae 100644 --- a/database/mysql/mysql.h +++ b/database/mysql/mysql.h @@ -9,10 +9,11 @@ #include -#include "database/dbinterface.h" +#include "database/interface.h" #include "utils/helpers.h" -class MySQL : public DBInterface { +namespace DB { +class MySQL : public Interface { class Statement; class Transaction; @@ -51,3 +52,4 @@ struct ResDeleter { } }; }; +} diff --git a/database/mysql/statement.cpp b/database/mysql/statement.cpp index 9a1fee9..3e8a841 100644 --- a/database/mysql/statement.cpp +++ b/database/mysql/statement.cpp @@ -9,7 +9,7 @@ static uint64_t TIME_LENGTH = sizeof(MYSQL_TIME); -MySQL::Statement::Statement(MYSQL* connection, const char* statement): +DB::MySQL::Statement::Statement(MYSQL* connection, const char* statement): stmt(mysql_stmt_init(connection)), param() { @@ -18,7 +18,7 @@ MySQL::Statement::Statement(MYSQL* connection, const char* statement): throw std::runtime_error(std::string("Error preparing statement: ") + mysql_stmt_error(stmt.get())); } -void MySQL::Statement::bind(void* value, enum_field_types type, bool usigned) { +void DB::MySQL::Statement::bind(void* value, enum_field_types type, bool usigned) { MYSQL_BIND& result = param.emplace_back(); std::memset(&result, 0, sizeof(result)); @@ -45,7 +45,7 @@ void MySQL::Statement::bind(void* value, enum_field_types type, bool usigned) { } } -void MySQL::Statement::execute() { +void DB::MySQL::Statement::execute() { MYSQL_STMT* raw = stmt.get(); int result = mysql_stmt_bind_param(raw, param.data()); if (result != 0) @@ -64,7 +64,7 @@ void MySQL::Statement::execute() { } } -std::vector> MySQL::Statement::fetchResult() { +std::vector> DB::MySQL::Statement::fetchResult() { MYSQL_STMT* raw = stmt.get(); if (mysql_stmt_store_result(raw) != 0) throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here diff --git a/database/mysql/statement.h b/database/mysql/statement.h index 369f0fa..68d9563 100644 --- a/database/mysql/statement.h +++ b/database/mysql/statement.h @@ -10,6 +10,7 @@ #include "mysql.h" +namespace DB { class MySQL::Statement { struct STMTDeleter { void operator () (MYSQL_STMT* stmt) { @@ -27,3 +28,4 @@ private: std::unique_ptr stmt; std::vector param; }; +} diff --git a/database/mysql/transaction.cpp b/database/mysql/transaction.cpp index 1a9b92c..24cda9f 100644 --- a/database/mysql/transaction.cpp +++ b/database/mysql/transaction.cpp @@ -3,7 +3,7 @@ #include "transaction.h" -MySQL::Transaction::Transaction(MYSQL* connection): +DB::MySQL::Transaction::Transaction(MYSQL* connection): con(connection), opened(false) { @@ -13,12 +13,12 @@ MySQL::Transaction::Transaction(MYSQL* connection): opened = true; } -MySQL::Transaction::~Transaction() { +DB::MySQL::Transaction::~Transaction() { if (opened) abort(); } -void MySQL::Transaction::commit() { +void DB::MySQL::Transaction::commit() { if (mysql_commit(con) != 0) throw std::runtime_error(std::string("Failed to commit transaction") + mysql_error(con)); @@ -27,7 +27,7 @@ void MySQL::Transaction::commit() { throw std::runtime_error(std::string("Failed to return autocommit") + mysql_error(con)); } -void MySQL::Transaction::abort() { +void DB::MySQL::Transaction::abort() { opened = false; if (mysql_rollback(con) != 0) throw std::runtime_error(std::string("Failed to rollback transaction") + mysql_error(con)); diff --git a/database/mysql/transaction.h b/database/mysql/transaction.h index 5a6a15c..b4b97d3 100644 --- a/database/mysql/transaction.h +++ b/database/mysql/transaction.h @@ -5,6 +5,7 @@ #include "mysql.h" +namespace DB { class MySQL::Transaction { public: Transaction(MYSQL* connection); @@ -17,3 +18,4 @@ private: MYSQL* con; bool opened; }; +} diff --git a/database/pool.cpp b/database/pool.cpp new file mode 100644 index 0000000..53c08f7 --- /dev/null +++ b/database/pool.cpp @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "pool.h" + +DB::Pool::Pool (Private): + std::enable_shared_from_this(), + mutex(), + conditional(), + interfaces() +{} + +DB::Pool::~Pool () { +} + +std::shared_ptr DB::Pool::create () { + return std::make_shared(Private()); +} + +void DB::Pool::addInterfaces ( + Interface::Type type, + std::size_t amount, + const std::string & login, + const std::string & password, + const std::string & database, + const std::string& path +) { + std::unique_lock lock(mutex); + for (std::size_t i = 0; i < amount; ++i) { + const std::unique_ptr& ref = interfaces.emplace(Interface::create(type)); + ref->setCredentials(login, password); + ref->setDatabase(database); + ref->connect(path); + } + + lock.unlock(); + conditional.notify_all(); +} + +DB::Resource DB::Pool::request () { + std::unique_lock lock(mutex); + while (interfaces.empty()) + conditional.wait(lock); + + std::unique_ptr interface = std::move(interfaces.front()); + interfaces.pop(); + return Resource(std::move(interface), shared_from_this()); +} + +void DB::Pool::free (std::unique_ptr interface) { + std::unique_lock lock(mutex); + + interfaces.push(std::move(interface)); + + lock.unlock(); + conditional.notify_one(); +} diff --git a/database/pool.h b/database/pool.h new file mode 100644 index 0000000..1040781 --- /dev/null +++ b/database/pool.h @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + +#pragma once + +#include +#include +#include +#include +#include + +#include "interface.h" +#include "resource.h" + +namespace DB { +class Pool : public std::enable_shared_from_this { + struct Private {}; + friend class Resource; + + void free(std::unique_ptr interface); + +public: + Pool(Private); + Pool(const Pool&) = delete; + Pool(Pool&&) = delete; + ~Pool(); + Pool& operator = (const Pool&) = delete; + Pool& operator = (Pool&&) = delete; + + static std::shared_ptr create(); + Resource request(); + void addInterfaces( + Interface::Type type, + std::size_t amount, + const std::string& login, + const std::string& password, + const std::string& database, + const std::string& path + ); + +private: + std::mutex mutex; + std::condition_variable conditional; + std::queue> interfaces; + +}; +} diff --git a/database/resource.cpp b/database/resource.cpp new file mode 100644 index 0000000..4021a51 --- /dev/null +++ b/database/resource.cpp @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "resource.h" + +#include "pool.h" + +DB::Resource::Resource ( + std::unique_ptr interface, + std::weak_ptr parent +): + parent(parent), + interface(std::move(interface)) +{} + +DB::Resource::Resource(Resource&& other): + parent(other.parent), + interface(std::move(other.interface)) +{} + +DB::Resource::~Resource() { + if (!interface) + return; + + if (std::shared_ptr p = parent.lock()) + p->free(std::move(interface)); +} + +DB::Resource& DB::Resource::operator = (Resource&& other) { + parent = other.parent; + interface = std::move(other.interface); + + return *this; +} + +DB::Interface* DB::Resource::operator -> () { + return interface.get(); +} diff --git a/database/resource.h b/database/resource.h new file mode 100644 index 0000000..d437efd --- /dev/null +++ b/database/resource.h @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + +#pragma once + +#include + +#include "interface.h" + +namespace DB { +class Pool; + +class Resource { + friend class Pool; + Resource(std::unique_ptr interface, std::weak_ptr parent); + +public: + Resource(const Resource&) = delete; + Resource(Resource&& other); + ~Resource(); + + Resource& operator = (const Resource&) = delete; + Resource& operator = (Resource&& other); + + Interface* operator -> (); + +private: + std::weak_ptr parent; + std::unique_ptr interface; +}; +} diff --git a/handler/login.cpp b/handler/login.cpp index c72cdd2..4cd1ef9 100644 --- a/handler/login.cpp +++ b/handler/login.cpp @@ -32,7 +32,7 @@ void Handler::Login::handle(Request& request) { bool success = false; try { success = server->validatePassword(login, password); - } catch (const DBInterface::NoLogin& e) { + } catch (const DB::NoLogin& e) { std::cerr << "Exception on logging in:\n\t" << e.what() << std::endl; return error(request, Result::wrongCredentials, Response::Status::badRequest); } catch (const std::exception& e) { diff --git a/handler/register.cpp b/handler/register.cpp index 57fefb9..ee485c8 100644 --- a/handler/register.cpp +++ b/handler/register.cpp @@ -35,7 +35,7 @@ void Handler::Register::handle(Request& request) { try { server->registerAccount(login, password); - } catch (const DBInterface::DuplicateLogin& e) { + } catch (const DB::DuplicateLogin& e) { std::cerr << "Exception on registration:\n\t" << e.what() << std::endl; return error(request, Result::loginExists, Response::Status::conflict); } catch (const std::exception& e) { diff --git a/request/CMakeLists.txt b/request/CMakeLists.txt index bca1871..19e567c 100644 --- a/request/CMakeLists.txt +++ b/request/CMakeLists.txt @@ -1,10 +1,13 @@ set(HEADERS request.h + redirect.h + redirectable.h ) set(SOURCES request.cpp + redirect.cpp ) target_sources(pica PRIVATE ${SOURCES}) diff --git a/request/accepting.h b/request/accepting.h index d1b28b8..de7b626 100644 --- a/request/accepting.h +++ b/request/accepting.h @@ -9,5 +9,6 @@ class Accepting { public: + virtual ~Accepting() {}; virtual void accept(std::unique_ptr request) = 0; -}; \ No newline at end of file +}; diff --git a/server/server.cpp b/server/server.cpp index 4bc8ad6..4bf05d4 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -13,6 +13,11 @@ #include "handler/login.h" constexpr const char* pepper = "well, not much of a secret, huh?"; +constexpr const char* dbLogin = "pica"; +constexpr const char* dbPassword = "pica"; +constexpr const char* dbName = "pica"; +constexpr const char* dbPath = "/run/mysqld/mysqld.sock"; +constexpr uint8_t dbConnectionsCount = 4; constexpr const char* randomChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; constexpr uint8_t saltSize = 16; @@ -28,18 +33,22 @@ Server::Server(): requestCount(0), serverName(std::nullopt), router(), - db(), + pool(DB::Pool::create()), sessions() { std::cout << "Startig pica..." << std::endl; - - db = DBInterface::create(DBInterface::Type::mysql); std::cout << "Database type: MySQL" << std::endl; + pool->addInterfaces( + DB::Interface::Type::mysql, + dbConnectionsCount, + dbLogin, + dbPassword, + dbName, + dbPath + ); - db->setCredentials("pica", "pica"); - db->setDatabase("pica"); + DB::Resource db = pool->request(); - db->connect("/run/mysqld/mysqld.sock"); db->migrate(currentDbVesion); router.addRoute(std::make_unique()); @@ -112,10 +121,12 @@ unsigned int Server::registerAccount(const std::string& login, const std::string if (result != ARGON2_OK) throw std::runtime_error(std::string("Hashing failed: ") + argon2_error_message(result)); + DB::Resource db = pool->request(); return db->registerAccount(login, hash); } bool Server::validatePassword(const std::string& login, const std::string& password) { + DB::Resource db = pool->request(); std::string hash = db->getAccountHash(login); std::string spiced = password + pepper; @@ -139,9 +150,10 @@ Session& Server::openSession(const std::string& login) { try { accessToken = generateRandomString(32); renewToken = generateRandomString(32); + DB::Resource db = pool->request(); sessionId = db->createSession(login, accessToken, renewToken); break; - } catch (const DBInterface::Duplicate& e) { + } catch (const DB::Duplicate& e) { std::cout << "Duplicate on creating session, trying again with different tokens"; } } while (--counter != 0); diff --git a/server/server.h b/server/server.h index 9c3e187..9d3319c 100644 --- a/server/server.h +++ b/server/server.h @@ -22,7 +22,7 @@ #include "response/response.h" #include "router.h" #include "session.h" -#include "database/dbinterface.h" +#include "database/pool.h" #include "utils/helpers.h" #include "config.h" @@ -49,6 +49,6 @@ private: uint64_t requestCount; std::optional serverName; Router router; - std::unique_ptr db; + std::shared_ptr pool; Sessions sessions; }; diff --git a/server/session.h b/server/session.h index b5b459e..910070f 100644 --- a/server/session.h +++ b/server/session.h @@ -10,6 +10,10 @@ class Session : public Accepting { public: Session(unsigned int id, const std::string& access, const std::string& renew); + Session(const Session&) = delete; + Session(Session&& other); + Session& operator = (const Session&) = delete; + Session& operator = (Session&& other); std::string getAccessToken() const; std::string getRenewToken() const;