session creation

This commit is contained in:
Blue 2023-12-23 17:23:38 -03:00
parent 534c282226
commit 4b87b560ac
Signed by: blue
GPG Key ID: 9B203B252A63EE38
12 changed files with 152 additions and 25 deletions

View File

@ -43,6 +43,7 @@ public:
virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0; virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0;
virtual std::string getAccountHash(const std::string& login) = 0; virtual std::string getAccountHash(const std::string& login) = 0;
virtual unsigned int createSession(const std::string& login, const std::string& access, const std::string& renew) = 0;
protected: protected:
DBInterface(Type type); DBInterface(Type type);

View File

@ -37,8 +37,10 @@ CREATE TABLE IF NOT EXISTS sessions (
`owner` INTEGER UNSIGNED NOT NULL, `owner` INTEGER UNSIGNED NOT NULL,
`started` TIMESTAMP DEFAULT UTC_TIMESTAMP(), `started` TIMESTAMP DEFAULT UTC_TIMESTAMP(),
`latest` TIMESTAMP DEFAULT UTC_TIMESTAMP(), `latest` TIMESTAMP DEFAULT UTC_TIMESTAMP(),
`salt` CHAR(16), `access` CHAR(32),
`renew` CHAR(32),
`persist` BOOLEAN NOT NULL, `persist` BOOLEAN NOT NULL,
`device` TEXT,
FOREIGN KEY (owner) REFERENCES accounts(id) FOREIGN KEY (owner) REFERENCES accounts(id)
); );

View File

@ -18,6 +18,8 @@ constexpr const char* registerQuery = "INSERT INTO accounts (`login`, `type`, `p
constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id"; constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id";
constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?"; constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?";
constexpr const char* selectHash = "SELECT password FROM accounts where login = ?"; constexpr const char* selectHash = "SELECT password FROM accounts where login = ?";
constexpr const char* createSessionQuery = "INSERT INTO sessions (`owner`, `access`, `renew`, `persist`, `device`)"
" SELECT accounts.id, ?, ?, true, ? FROM accounts WHERE accounts.login = ?";
static const std::filesystem::path buildSQLPath = "database"; static const std::filesystem::path buildSQLPath = "database";
@ -253,14 +255,30 @@ std::string MySQL::getAccountHash(const std::string& login) {
getHash.bind(l.data(), MYSQL_TYPE_STRING); getHash.bind(l.data(), MYSQL_TYPE_STRING);
getHash.execute(); getHash.execute();
std::vector<std::vector<std::string>> result = getHash.fetchResult(); std::vector<std::vector<std::any>> result = getHash.fetchResult();
if (result.empty()) if (result.empty())
throw NoLogin("Couldn't find login " + l); throw NoLogin("Couldn't find login " + l);
if (result[0].empty()) if (result[0].empty())
throw std::runtime_error("Error with the query \"selectHash\""); throw std::runtime_error("Error with the query \"selectHash\"");
return result[0][0]; return std::any_cast<const std::string&>(result[0][0]);
}
unsigned int MySQL::createSession(const std::string& login, const std::string& access, const std::string& renew) {
std::string l = login, a = access, r = renew;
static std::string testingDevice("Testing...");
MYSQL* con = &connection;
Statement session(con, createSessionQuery);
session.bind(a.data(), MYSQL_TYPE_STRING);
session.bind(r.data(), MYSQL_TYPE_STRING);
session.bind(testingDevice.data(), MYSQL_TYPE_STRING);
session.bind(l.data(), MYSQL_TYPE_STRING);
session.execute();
return lastInsertedId();
} }
unsigned int MySQL::lastInsertedId() { unsigned int MySQL::lastInsertedId() {

View File

@ -32,6 +32,7 @@ public:
unsigned int registerAccount(const std::string& login, const std::string& hash) override; unsigned int registerAccount(const std::string& login, const std::string& hash) override;
std::string getAccountHash(const std::string& login) override; std::string getAccountHash(const std::string& login) override;
unsigned int createSession(const std::string& login, const std::string& access, const std::string& renew) override;
private: private:
void executeFile(const std::filesystem::path& relativePath); void executeFile(const std::filesystem::path& relativePath);

View File

@ -3,8 +3,6 @@
#include "statement.h" #include "statement.h"
#include <cstring>
#include "mysqld_error.h" #include "mysqld_error.h"
#include "database/exceptions.h" #include "database/exceptions.h"
@ -66,7 +64,7 @@ void MySQL::Statement::execute() {
} }
} }
std::vector<std::vector<std::string>> MySQL::Statement::fetchResult() { std::vector<std::vector<std::any>> MySQL::Statement::fetchResult() {
MYSQL_STMT* raw = stmt.get(); MYSQL_STMT* raw = stmt.get();
if (mysql_stmt_store_result(raw) != 0) if (mysql_stmt_store_result(raw) != 0)
throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here
@ -78,9 +76,9 @@ std::vector<std::vector<std::string>> MySQL::Statement::fetchResult() {
std::unique_ptr<MYSQL_RES, ResDeleter> mt(meta); std::unique_ptr<MYSQL_RES, ResDeleter> mt(meta);
unsigned int numColumns = mysql_num_fields(meta); unsigned int numColumns = mysql_num_fields(meta);
MYSQL_BIND bind[numColumns]; MYSQL_BIND bind[numColumns];
memset(bind, 0, sizeof(bind)); std::memset(bind, 0, sizeof(bind));
std::vector<std::string> line(numColumns); std::vector<std::any> line(numColumns);
std::vector<long unsigned int> lengths(numColumns); std::vector<long unsigned int> lengths(numColumns);
for (unsigned int i = 0; i < numColumns; ++i) { for (unsigned int i = 0; i < numColumns; ++i) {
MYSQL_FIELD *field = mysql_fetch_field_direct(meta, i); MYSQL_FIELD *field = mysql_fetch_field_direct(meta, i);
@ -88,14 +86,32 @@ std::vector<std::vector<std::string>> MySQL::Statement::fetchResult() {
switch (field->type) { switch (field->type) {
case MYSQL_TYPE_STRING: case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING: case MYSQL_TYPE_VAR_STRING:
case MYSQL_TYPE_VARCHAR: case MYSQL_TYPE_VARCHAR: {
line[i] = std::string();
std::string& str = std::any_cast<std::string&>(line[i]);
str.resize(field->length);
bind[i].buffer = str.data();
} break;
case MYSQL_TYPE_TINY:
line[i] = uint8_t{0};
bind[i].buffer = &std::any_cast<std::string&>(line[i]);
break;
case MYSQL_TYPE_SHORT:
line[i] = uint16_t{0};
bind[i].buffer = &std::any_cast<std::string&>(line[i]);
break;
case MYSQL_TYPE_LONG:
line[i] = uint32_t{0};
bind[i].buffer = &std::any_cast<std::string&>(line[i]);
break;
case MYSQL_TYPE_LONGLONG:
line[i] = uint64_t{0};
bind[i].buffer = &std::any_cast<std::string&>(line[i]);
break; break;
default: default:
throw std::runtime_error("Unsupported data fetching statement result " + std::to_string(field->type)); throw std::runtime_error("Unsupported data fetching statement result " + std::to_string(field->type));
} }
line[i].resize(field->length);
bind[i].buffer_type = field->type; bind[i].buffer_type = field->type;
bind[i].buffer = line[i].data();
bind[i].buffer_length = field->length; bind[i].buffer_length = field->length;
bind[i].length = &lengths[i]; bind[i].length = &lengths[i];
} }
@ -103,11 +119,32 @@ std::vector<std::vector<std::string>> MySQL::Statement::fetchResult() {
if (mysql_stmt_bind_result(raw, bind) != 0) if (mysql_stmt_bind_result(raw, bind) != 0)
throw std::runtime_error(std::string("Error binding on fetching statement result: ") + mysql_stmt_error(raw)); throw std::runtime_error(std::string("Error binding on fetching statement result: ") + mysql_stmt_error(raw));
std::vector<std::vector<std::string>> result; std::vector<std::vector<std::any>> result;
while (mysql_stmt_fetch(raw) == 0) { while (mysql_stmt_fetch(raw) == 0) {
std::vector<std::string>& row = result.emplace_back(numColumns); std::vector<std::any>& row = result.emplace_back(numColumns);
for (unsigned int i = 0; i < numColumns; ++i) for (unsigned int i = 0; i < numColumns; ++i) {
row[i] = std::string(line[i].data(), lengths[i]); switch (bind[i].buffer_type) {
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
case MYSQL_TYPE_VARCHAR: {
row[i] = std::string(std::any_cast<const std::string&>(line[i]).data(), lengths[i]);
} break;
case MYSQL_TYPE_TINY:
row[i] = std::any_cast<uint8_t>(line[i]);
break;
case MYSQL_TYPE_SHORT:
row[i] = std::any_cast<uint16_t>(line[i]);
break;
case MYSQL_TYPE_LONG:
row[i] = std::any_cast<uint32_t>(line[i]);
break;
case MYSQL_TYPE_LONGLONG:
row[i] = std::any_cast<uint64_t>(line[i]);
break;
default:
throw std::runtime_error("Unsupported data fetching statement result " + std::to_string(bind[i].buffer_type));
}
}
} }
return result; return result;

View File

@ -3,7 +3,10 @@
#pragma once #pragma once
#include <cstring>
#include <vector> #include <vector>
#include <tuple>
#include <any>
#include "mysql.h" #include "mysql.h"
@ -18,7 +21,7 @@ public:
void bind(void* value, enum_field_types type, bool usigned = false); void bind(void* value, enum_field_types type, bool usigned = false);
void execute(); void execute();
std::vector<std::vector<std::string>> fetchResult(); std::vector<std::vector<std::any>> fetchResult();
private: private:
std::unique_ptr<MYSQL_STMT, STMTDeleter> stmt; std::unique_ptr<MYSQL_STMT, STMTDeleter> stmt;

View File

@ -33,24 +33,34 @@ void Handler::Login::handle(Request& request) {
try { try {
success = server->validatePassword(login, password); success = server->validatePassword(login, password);
} catch (const DBInterface::NoLogin& e) { } catch (const DBInterface::NoLogin& e) {
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl; std::cerr << "Exception on logging in:\n\t" << e.what() << std::endl;
return error(request, Result::noLogin, Response::Status::badRequest); //can send unauthed instead, to exclude login spoofing return error(request, Result::noLogin, Response::Status::badRequest); //can send unauthed instead, to exclude login spoofing
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl; std::cerr << "Exception on logging in:\n\t" << e.what() << std::endl;
return error(request, Result::unknownError, Response::Status::internalError); return error(request, Result::unknownError, Response::Status::internalError);
} catch (...) { } catch (...) {
std::cerr << "Unknown exception on registration" << std::endl; std::cerr << "Unknown exception on ogging in" << std::endl;
return error(request, Result::unknownError, Response::Status::internalError); return error(request, Result::unknownError, Response::Status::internalError);
} }
if (!success) if (!success)
return error(request, Result::noLogin, Response::Status::badRequest); return error(request, Result::noLogin, Response::Status::badRequest);
//TODO opening the session
Response& res = request.createResponse();
nlohmann::json body = nlohmann::json::object(); nlohmann::json body = nlohmann::json::object();
body["result"] = Result::success; body["result"] = Result::success;
try {
Session& session = server->openSession(login);
body["accessToken"] = session.getAccessToken();
body["renewToken"] = session.getRenewToken();
} catch (const std::exception& e) {
std::cerr << "Exception on opening a session:\n\t" << e.what() << std::endl;
return error(request, Result::unknownError, Response::Status::internalError);
} catch (...) {
std::cerr << "Unknown exception on opening a session" << std::endl;
return error(request, Result::unknownError, Response::Status::internalError);
}
Response& res = request.createResponse();
res.setBody(body); res.setBody(body);
res.send(); res.send();
} }

View File

@ -1,11 +1,13 @@
set(HEADERS set(HEADERS
server.h server.h
router.h router.h
session.h
) )
set(SOURCES set(SOURCES
server.cpp server.cpp
router.cpp router.cpp
session.cpp
) )
target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES})

View File

@ -12,7 +12,7 @@
constexpr const char* pepper = "well, not much of a secret, huh?"; constexpr const char* pepper = "well, not much of a secret, huh?";
constexpr uint8_t currentDbVesion = 1; constexpr uint8_t currentDbVesion = 1;
constexpr const char* randomChars = "0123456789abcdef"; constexpr const char* randomChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
constexpr uint8_t saltSize = 16; constexpr uint8_t saltSize = 16;
constexpr uint8_t hashSize = 32; constexpr uint8_t hashSize = 32;
constexpr uint8_t hashParallel = 1; constexpr uint8_t hashParallel = 1;
@ -24,7 +24,8 @@ Server::Server():
requestCount(0), requestCount(0),
serverName(std::nullopt), serverName(std::nullopt),
router(), router(),
db() db(),
sessions()
{ {
std::cout << "Startig pica..." << std::endl; std::cout << "Startig pica..." << std::endl;
@ -78,7 +79,7 @@ void Server::handleRequest(std::unique_ptr<Request> request) {
std::string Server::generateRandomString(std::size_t length) { std::string Server::generateRandomString(std::size_t length) {
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::uniform_int_distribution<uint8_t> distribution(0, std::strlen(randomChars)); std::uniform_int_distribution<uint8_t> distribution(0, std::strlen(randomChars) - 1);
std::string result(length, 0); std::string result(length, 0);
for (size_t i = 0; i < length; ++i) for (size_t i = 0; i < length; ++i)
@ -125,3 +126,12 @@ bool Server::validatePassword(const std::string& login, const std::string& passw
throw std::runtime_error(std::string("Failed to verify password: ") + argon2_error_message(result)); throw std::runtime_error(std::string("Failed to verify password: ") + argon2_error_message(result));
} }
} }
Session& Server::openSession(const std::string& login) {
std::string accessToken = generateRandomString(32);
std::string renewToken = generateRandomString(32);
unsigned int sessionId = db->createSession(login, accessToken, renewToken);
std::unique_ptr<Session>& session = sessions[accessToken] = std::make_unique<Session>(sessionId, accessToken, renewToken);
return *session.get();
}

View File

@ -10,6 +10,7 @@
#include <string_view> #include <string_view>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include <fcgiapp.h> #include <fcgiapp.h>
#include <fcgio.h> #include <fcgio.h>
@ -20,6 +21,7 @@
#include "request/request.h" #include "request/request.h"
#include "response/response.h" #include "response/response.h"
#include "router.h" #include "router.h"
#include "session.h"
#include "database/dbinterface.h" #include "database/dbinterface.h"
#include "utils/helpers.h" #include "utils/helpers.h"
#include "config.h" #include "config.h"
@ -33,15 +35,19 @@ public:
unsigned int registerAccount(const std::string& login, const std::string& password); unsigned int registerAccount(const std::string& login, const std::string& password);
bool validatePassword(const std::string& login, const std::string& password); bool validatePassword(const std::string& login, const std::string& password);
Session& openSession(const std::string& login);
private: private:
void handleRequest(std::unique_ptr<Request> request); void handleRequest(std::unique_ptr<Request> request);
static std::string generateRandomString(std::size_t length); static std::string generateRandomString(std::size_t length);
private: private:
using Sessions = std::map<std::string, std::unique_ptr<Session>>;
bool terminating; bool terminating;
uint64_t requestCount; uint64_t requestCount;
std::optional<std::string> serverName; std::optional<std::string> serverName;
Router router; Router router;
std::unique_ptr<DBInterface> db; std::unique_ptr<DBInterface> db;
Sessions sessions;
}; };

18
server/session.cpp Normal file
View File

@ -0,0 +1,18 @@
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
// SPDX-License-Identifier: GPL-3.0-or-later
#include "session.h"
Session::Session(unsigned int id, const std::string& access, const std::string& renew):
id(id),
access(access),
renew(renew)
{}
std::string Session::getAccessToken() const {
return access;
}
std::string Session::getRenewToken() const {
return renew;
}

19
server/session.h Normal file
View File

@ -0,0 +1,19 @@
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
// SPDX-License-Identifier: GPL-3.0-or-later
#pragma once
#include <string>
class Session {
public:
Session(unsigned int id, const std::string& access, const std::string& renew);
std::string getAccessToken() const;
std::string getRenewToken() const;
private:
unsigned int id;
std::string access;
std::string renew;
};