From 4b87b560acfe24717e5d5b28eccb2b88d10e4a08 Mon Sep 17 00:00:00 2001 From: blue Date: Sat, 23 Dec 2023 17:23:38 -0300 Subject: [PATCH] session creation --- database/dbinterface.h | 1 + database/migrations/m0.sql | 4 ++- database/mysql/mysql.cpp | 22 +++++++++++-- database/mysql/mysql.h | 1 + database/mysql/statement.cpp | 61 +++++++++++++++++++++++++++++------- database/mysql/statement.h | 5 ++- handler/login.cpp | 22 +++++++++---- server/CMakeLists.txt | 2 ++ server/server.cpp | 16 ++++++++-- server/server.h | 6 ++++ server/session.cpp | 18 +++++++++++ server/session.h | 19 +++++++++++ 12 files changed, 152 insertions(+), 25 deletions(-) create mode 100644 server/session.cpp create mode 100644 server/session.h diff --git a/database/dbinterface.h b/database/dbinterface.h index 5f018e8..e2c0ebe 100644 --- a/database/dbinterface.h +++ b/database/dbinterface.h @@ -43,6 +43,7 @@ public: virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0; virtual std::string getAccountHash(const std::string& login) = 0; + virtual unsigned int createSession(const std::string& login, const std::string& access, const std::string& renew) = 0; protected: DBInterface(Type type); diff --git a/database/migrations/m0.sql b/database/migrations/m0.sql index e52e9ef..b6e9316 100644 --- a/database/migrations/m0.sql +++ b/database/migrations/m0.sql @@ -37,8 +37,10 @@ CREATE TABLE IF NOT EXISTS sessions ( `owner` INTEGER UNSIGNED NOT NULL, `started` TIMESTAMP DEFAULT UTC_TIMESTAMP(), `latest` TIMESTAMP DEFAULT UTC_TIMESTAMP(), - `salt` CHAR(16), + `access` CHAR(32), + `renew` CHAR(32), `persist` BOOLEAN NOT NULL, + `device` TEXT, FOREIGN KEY (owner) REFERENCES accounts(id) ); diff --git a/database/mysql/mysql.cpp b/database/mysql/mysql.cpp index 1e40cf4..c101220 100644 --- a/database/mysql/mysql.cpp +++ b/database/mysql/mysql.cpp @@ -18,6 +18,8 @@ constexpr const char* registerQuery = "INSERT INTO accounts (`login`, `type`, `p constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id"; constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?"; constexpr const char* selectHash = "SELECT password FROM accounts where login = ?"; +constexpr const char* createSessionQuery = "INSERT INTO sessions (`owner`, `access`, `renew`, `persist`, `device`)" + " SELECT accounts.id, ?, ?, true, ? FROM accounts WHERE accounts.login = ?"; static const std::filesystem::path buildSQLPath = "database"; @@ -253,14 +255,30 @@ std::string MySQL::getAccountHash(const std::string& login) { getHash.bind(l.data(), MYSQL_TYPE_STRING); getHash.execute(); - std::vector> result = getHash.fetchResult(); + std::vector> result = getHash.fetchResult(); if (result.empty()) throw NoLogin("Couldn't find login " + l); if (result[0].empty()) throw std::runtime_error("Error with the query \"selectHash\""); - return result[0][0]; + return std::any_cast(result[0][0]); +} + +unsigned int MySQL::createSession(const std::string& login, const std::string& access, const std::string& renew) { + std::string l = login, a = access, r = renew; + static std::string testingDevice("Testing..."); + + MYSQL* con = &connection; + + Statement session(con, createSessionQuery); + session.bind(a.data(), MYSQL_TYPE_STRING); + session.bind(r.data(), MYSQL_TYPE_STRING); + session.bind(testingDevice.data(), MYSQL_TYPE_STRING); + session.bind(l.data(), MYSQL_TYPE_STRING); + session.execute(); + + return lastInsertedId(); } unsigned int MySQL::lastInsertedId() { diff --git a/database/mysql/mysql.h b/database/mysql/mysql.h index 75e865d..8344ce2 100644 --- a/database/mysql/mysql.h +++ b/database/mysql/mysql.h @@ -32,6 +32,7 @@ public: unsigned int registerAccount(const std::string& login, const std::string& hash) override; std::string getAccountHash(const std::string& login) override; + unsigned int createSession(const std::string& login, const std::string& access, const std::string& renew) override; private: void executeFile(const std::filesystem::path& relativePath); diff --git a/database/mysql/statement.cpp b/database/mysql/statement.cpp index acfb0e6..9a1fee9 100644 --- a/database/mysql/statement.cpp +++ b/database/mysql/statement.cpp @@ -3,8 +3,6 @@ #include "statement.h" -#include - #include "mysqld_error.h" #include "database/exceptions.h" @@ -66,7 +64,7 @@ void MySQL::Statement::execute() { } } -std::vector> MySQL::Statement::fetchResult() { +std::vector> MySQL::Statement::fetchResult() { MYSQL_STMT* raw = stmt.get(); if (mysql_stmt_store_result(raw) != 0) throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here @@ -78,9 +76,9 @@ std::vector> MySQL::Statement::fetchResult() { std::unique_ptr mt(meta); unsigned int numColumns = mysql_num_fields(meta); MYSQL_BIND bind[numColumns]; - memset(bind, 0, sizeof(bind)); + std::memset(bind, 0, sizeof(bind)); - std::vector line(numColumns); + std::vector line(numColumns); std::vector lengths(numColumns); for (unsigned int i = 0; i < numColumns; ++i) { MYSQL_FIELD *field = mysql_fetch_field_direct(meta, i); @@ -88,14 +86,32 @@ std::vector> MySQL::Statement::fetchResult() { switch (field->type) { case MYSQL_TYPE_STRING: case MYSQL_TYPE_VAR_STRING: - case MYSQL_TYPE_VARCHAR: + case MYSQL_TYPE_VARCHAR: { + line[i] = std::string(); + std::string& str = std::any_cast(line[i]); + str.resize(field->length); + bind[i].buffer = str.data(); + } break; + case MYSQL_TYPE_TINY: + line[i] = uint8_t{0}; + bind[i].buffer = &std::any_cast(line[i]); + break; + case MYSQL_TYPE_SHORT: + line[i] = uint16_t{0}; + bind[i].buffer = &std::any_cast(line[i]); + break; + case MYSQL_TYPE_LONG: + line[i] = uint32_t{0}; + bind[i].buffer = &std::any_cast(line[i]); + break; + case MYSQL_TYPE_LONGLONG: + line[i] = uint64_t{0}; + bind[i].buffer = &std::any_cast(line[i]); break; default: throw std::runtime_error("Unsupported data fetching statement result " + std::to_string(field->type)); } - line[i].resize(field->length); bind[i].buffer_type = field->type; - bind[i].buffer = line[i].data(); bind[i].buffer_length = field->length; bind[i].length = &lengths[i]; } @@ -103,11 +119,32 @@ std::vector> MySQL::Statement::fetchResult() { if (mysql_stmt_bind_result(raw, bind) != 0) throw std::runtime_error(std::string("Error binding on fetching statement result: ") + mysql_stmt_error(raw)); - std::vector> result; + std::vector> result; while (mysql_stmt_fetch(raw) == 0) { - std::vector& row = result.emplace_back(numColumns); - for (unsigned int i = 0; i < numColumns; ++i) - row[i] = std::string(line[i].data(), lengths[i]); + std::vector& row = result.emplace_back(numColumns); + for (unsigned int i = 0; i < numColumns; ++i) { + switch (bind[i].buffer_type) { + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_VARCHAR: { + row[i] = std::string(std::any_cast(line[i]).data(), lengths[i]); + } break; + case MYSQL_TYPE_TINY: + row[i] = std::any_cast(line[i]); + break; + case MYSQL_TYPE_SHORT: + row[i] = std::any_cast(line[i]); + break; + case MYSQL_TYPE_LONG: + row[i] = std::any_cast(line[i]); + break; + case MYSQL_TYPE_LONGLONG: + row[i] = std::any_cast(line[i]); + break; + default: + throw std::runtime_error("Unsupported data fetching statement result " + std::to_string(bind[i].buffer_type)); + } + } } return result; diff --git a/database/mysql/statement.h b/database/mysql/statement.h index 3209155..369f0fa 100644 --- a/database/mysql/statement.h +++ b/database/mysql/statement.h @@ -3,7 +3,10 @@ #pragma once +#include #include +#include +#include #include "mysql.h" @@ -18,7 +21,7 @@ public: void bind(void* value, enum_field_types type, bool usigned = false); void execute(); - std::vector> fetchResult(); + std::vector> fetchResult(); private: std::unique_ptr stmt; diff --git a/handler/login.cpp b/handler/login.cpp index f09ecec..2937a77 100644 --- a/handler/login.cpp +++ b/handler/login.cpp @@ -33,24 +33,34 @@ void Handler::Login::handle(Request& request) { try { success = server->validatePassword(login, password); } catch (const DBInterface::NoLogin& e) { - std::cerr << "Exception on registration:\n\t" << e.what() << std::endl; + std::cerr << "Exception on logging in:\n\t" << e.what() << std::endl; return error(request, Result::noLogin, Response::Status::badRequest); //can send unauthed instead, to exclude login spoofing } catch (const std::exception& e) { - std::cerr << "Exception on registration:\n\t" << e.what() << std::endl; + std::cerr << "Exception on logging in:\n\t" << e.what() << std::endl; return error(request, Result::unknownError, Response::Status::internalError); } catch (...) { - std::cerr << "Unknown exception on registration" << std::endl; + std::cerr << "Unknown exception on ogging in" << std::endl; return error(request, Result::unknownError, Response::Status::internalError); } if (!success) return error(request, Result::noLogin, Response::Status::badRequest); - //TODO opening the session - - Response& res = request.createResponse(); nlohmann::json body = nlohmann::json::object(); body["result"] = Result::success; + try { + Session& session = server->openSession(login); + body["accessToken"] = session.getAccessToken(); + body["renewToken"] = session.getRenewToken(); + } catch (const std::exception& e) { + std::cerr << "Exception on opening a session:\n\t" << e.what() << std::endl; + return error(request, Result::unknownError, Response::Status::internalError); + } catch (...) { + std::cerr << "Unknown exception on opening a session" << std::endl; + return error(request, Result::unknownError, Response::Status::internalError); + } + + Response& res = request.createResponse(); res.setBody(body); res.send(); } diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index 0e38cbe..9a1b2d8 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -1,11 +1,13 @@ set(HEADERS server.h router.h + session.h ) set(SOURCES server.cpp router.cpp + session.cpp ) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) diff --git a/server/server.cpp b/server/server.cpp index a605012..cfb65ea 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -12,7 +12,7 @@ constexpr const char* pepper = "well, not much of a secret, huh?"; constexpr uint8_t currentDbVesion = 1; -constexpr const char* randomChars = "0123456789abcdef"; +constexpr const char* randomChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; constexpr uint8_t saltSize = 16; constexpr uint8_t hashSize = 32; constexpr uint8_t hashParallel = 1; @@ -24,7 +24,8 @@ Server::Server(): requestCount(0), serverName(std::nullopt), router(), - db() + db(), + sessions() { std::cout << "Startig pica..." << std::endl; @@ -78,7 +79,7 @@ void Server::handleRequest(std::unique_ptr request) { std::string Server::generateRandomString(std::size_t length) { std::random_device rd; std::mt19937 gen(rd()); - std::uniform_int_distribution distribution(0, std::strlen(randomChars)); + std::uniform_int_distribution distribution(0, std::strlen(randomChars) - 1); std::string result(length, 0); for (size_t i = 0; i < length; ++i) @@ -125,3 +126,12 @@ bool Server::validatePassword(const std::string& login, const std::string& passw throw std::runtime_error(std::string("Failed to verify password: ") + argon2_error_message(result)); } } + +Session& Server::openSession(const std::string& login) { + std::string accessToken = generateRandomString(32); + std::string renewToken = generateRandomString(32); + unsigned int sessionId = db->createSession(login, accessToken, renewToken); + + std::unique_ptr& session = sessions[accessToken] = std::make_unique(sessionId, accessToken, renewToken); + return *session.get(); +} diff --git a/server/server.h b/server/server.h index 4291ee9..5e5b321 100644 --- a/server/server.h +++ b/server/server.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -20,6 +21,7 @@ #include "request/request.h" #include "response/response.h" #include "router.h" +#include "session.h" #include "database/dbinterface.h" #include "utils/helpers.h" #include "config.h" @@ -33,15 +35,19 @@ public: unsigned int registerAccount(const std::string& login, const std::string& password); bool validatePassword(const std::string& login, const std::string& password); + Session& openSession(const std::string& login); private: void handleRequest(std::unique_ptr request); static std::string generateRandomString(std::size_t length); private: + using Sessions = std::map>; + bool terminating; uint64_t requestCount; std::optional serverName; Router router; std::unique_ptr db; + Sessions sessions; }; diff --git a/server/session.cpp b/server/session.cpp new file mode 100644 index 0000000..2e203be --- /dev/null +++ b/server/session.cpp @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "session.h" + +Session::Session(unsigned int id, const std::string& access, const std::string& renew): + id(id), + access(access), + renew(renew) +{} + +std::string Session::getAccessToken() const { + return access; +} + +std::string Session::getRenewToken() const { + return renew; +} diff --git a/server/session.h b/server/session.h new file mode 100644 index 0000000..6d23636 --- /dev/null +++ b/server/session.h @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + +#pragma once + +#include + +class Session { +public: + Session(unsigned int id, const std::string& access, const std::string& renew); + + std::string getAccessToken() const; + std::string getRenewToken() const; + +private: + unsigned int id; + std::string access; + std::string renew; +};