From 534c28222655ac653796bbc3253b2b3f98cacea9 Mon Sep 17 00:00:00 2001 From: blue Date: Fri, 22 Dec 2023 20:25:20 -0300 Subject: [PATCH] password hash cheching --- database/CMakeLists.txt | 2 + database/dbinterface.h | 6 +++ database/exceptions.cpp | 20 ++++++++++ database/exceptions.h | 26 +++++++++++++ database/mysql/mysql.cpp | 32 +++++++++++---- database/mysql/mysql.h | 9 +++++ database/mysql/statement.cpp | 71 +++++++++++++++++++++++++++++++--- database/mysql/statement.h | 1 + database/mysql/transaction.cpp | 3 ++ database/mysql/transaction.h | 3 ++ handler/CMakeLists.txt | 2 + handler/env.cpp | 2 +- handler/env.h | 4 +- handler/info.cpp | 2 +- handler/login.cpp | 65 +++++++++++++++++++++++++++++++ handler/login.h | 32 +++++++++++++++ handler/register.cpp | 24 +++++++----- handler/register.h | 6 +-- request/request.cpp | 69 +++++++++++++++++---------------- request/request.h | 18 +++++---- response/response.cpp | 27 +++++++++++-- response/response.h | 15 +++++-- server/router.cpp | 14 +++---- server/server.cpp | 20 +++++++++- server/server.h | 1 + 25 files changed, 390 insertions(+), 84 deletions(-) create mode 100644 database/exceptions.cpp create mode 100644 database/exceptions.h create mode 100644 handler/login.cpp create mode 100644 handler/login.h diff --git a/database/CMakeLists.txt b/database/CMakeLists.txt index 42941d6..a0adbc9 100644 --- a/database/CMakeLists.txt +++ b/database/CMakeLists.txt @@ -1,9 +1,11 @@ set(HEADERS dbinterface.h + exceptions.h ) set(SOURCES dbinterface.cpp + exceptions.cpp ) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) diff --git a/database/dbinterface.h b/database/dbinterface.h index ea03191..5f018e8 100644 --- a/database/dbinterface.h +++ b/database/dbinterface.h @@ -26,6 +26,11 @@ public: const Type type; + class Duplicate; + class DuplicateLogin; + class EmptyResult; + class NoLogin; + public: virtual void connect(const std::string& path) = 0; virtual void disconnect() = 0; @@ -37,6 +42,7 @@ public: virtual void setVersion(uint8_t version) = 0; virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0; + virtual std::string getAccountHash(const std::string& login) = 0; protected: DBInterface(Type type); diff --git a/database/exceptions.cpp b/database/exceptions.cpp new file mode 100644 index 0000000..5d555e2 --- /dev/null +++ b/database/exceptions.cpp @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "exceptions.h" + +DBInterface::Duplicate::Duplicate(const std::string& text): + std::runtime_error(text) +{} + +DBInterface::DuplicateLogin::DuplicateLogin(const std::string& text): + Duplicate(text) +{} + +DBInterface::EmptyResult::EmptyResult(const std::string& text): + std::runtime_error(text) +{} + +DBInterface::NoLogin::NoLogin(const std::string& text): + EmptyResult(text) +{} diff --git a/database/exceptions.h b/database/exceptions.h new file mode 100644 index 0000000..8d447e2 --- /dev/null +++ b/database/exceptions.h @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + +#pragma once + +#include "dbinterface.h" + +class DBInterface::Duplicate : public std::runtime_error { +public: + explicit Duplicate(const std::string& text); +}; + +class DBInterface::DuplicateLogin : public DBInterface::Duplicate { +public: + explicit DuplicateLogin(const std::string& text); +}; + +class DBInterface::EmptyResult : public std::runtime_error { +public: + explicit EmptyResult(const std::string& text); +}; + +class DBInterface::NoLogin : public DBInterface::EmptyResult { +public: + explicit NoLogin(const std::string& text); +}; diff --git a/database/mysql/mysql.cpp b/database/mysql/mysql.cpp index 5d781f7..1e40cf4 100644 --- a/database/mysql/mysql.cpp +++ b/database/mysql/mysql.cpp @@ -10,21 +10,17 @@ #include "statement.h" #include "transaction.h" +#include "database/exceptions.h" constexpr const char* versionQuery = "SELECT value FROM system WHERE `key` = 'version'"; constexpr const char* updateQuery = "UPDATE system SET `value` = ? WHERE `key` = 'version'"; constexpr const char* registerQuery = "INSERT INTO accounts (`login`, `type`, `password`) VALUES (?, 1, ?)"; constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id"; constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?"; +constexpr const char* selectHash = "SELECT password FROM accounts where login = ?"; static const std::filesystem::path buildSQLPath = "database"; -struct ResDeleter { - void operator () (MYSQL_RES* res) { - mysql_free_result(res); - } -}; - MySQL::MySQL(): DBInterface(Type::mysql), connection(), @@ -231,7 +227,11 @@ unsigned int MySQL::registerAccount(const std::string& login, const std::string& std::string h = hash; addAcc.bind(l.data(), MYSQL_TYPE_STRING); addAcc.bind(h.data(), MYSQL_TYPE_STRING); - addAcc.execute(); + try { + addAcc.execute(); + } catch (const Duplicate& dup) { + throw DuplicateLogin(dup.what()); + } unsigned int id = lastInsertedId(); static std::string defaultRole("default"); @@ -245,6 +245,24 @@ unsigned int MySQL::registerAccount(const std::string& login, const std::string& return id; } +std::string MySQL::getAccountHash(const std::string& login) { + std::string l = login; + MYSQL* con = &connection; + + Statement getHash(con, selectHash); + getHash.bind(l.data(), MYSQL_TYPE_STRING); + getHash.execute(); + + std::vector> result = getHash.fetchResult(); + if (result.empty()) + throw NoLogin("Couldn't find login " + l); + + if (result[0].empty()) + throw std::runtime_error("Error with the query \"selectHash\""); + + return result[0][0]; +} + unsigned int MySQL::lastInsertedId() { MYSQL* con = &connection; int result = mysql_query(con, lastIdQuery); diff --git a/database/mysql/mysql.h b/database/mysql/mysql.h index 4246e51..75e865d 100644 --- a/database/mysql/mysql.h +++ b/database/mysql/mysql.h @@ -15,6 +15,8 @@ class MySQL : public DBInterface { class Statement; class Transaction; + + public: MySQL(); ~MySQL() override; @@ -29,6 +31,7 @@ public: void setVersion(uint8_t version) override; unsigned int registerAccount(const std::string& login, const std::string& hash) override; + std::string getAccountHash(const std::string& login) override; private: void executeFile(const std::filesystem::path& relativePath); @@ -40,4 +43,10 @@ protected: std::string login; std::string password; std::string database; + +struct ResDeleter { + void operator () (MYSQL_RES* res) { + mysql_free_result(res); + } +}; }; diff --git a/database/mysql/statement.cpp b/database/mysql/statement.cpp index 5cbf482..acfb0e6 100644 --- a/database/mysql/statement.cpp +++ b/database/mysql/statement.cpp @@ -5,6 +5,10 @@ #include +#include "mysqld_error.h" + +#include "database/exceptions.h" + static uint64_t TIME_LENGTH = sizeof(MYSQL_TIME); MySQL::Statement::Statement(MYSQL* connection, const char* statement): @@ -44,11 +48,68 @@ void MySQL::Statement::bind(void* value, enum_field_types type, bool usigned) { } void MySQL::Statement::execute() { - int result = mysql_stmt_bind_param(stmt.get(), param.data()); + MYSQL_STMT* raw = stmt.get(); + int result = mysql_stmt_bind_param(raw, param.data()); if (result != 0) - throw std::runtime_error(std::string("Error binding statement: ") + mysql_stmt_error(stmt.get())); + throw std::runtime_error(std::string("Error binding statement: ") + mysql_stmt_error(raw)); - result = mysql_stmt_execute(stmt.get()); - if (result != 0) - throw std::runtime_error(std::string("Error executing statement: ") + mysql_stmt_error(stmt.get())); + result = mysql_stmt_execute(raw); + if (result != 0) { + int errcode = mysql_stmt_errno(raw); + std::string text = mysql_stmt_error(raw); + switch (errcode) { + case ER_DUP_ENTRY: + throw Duplicate("Error executing statement: " + text); + default: + throw std::runtime_error("Error executing statement: " + text); + } + } } + +std::vector> MySQL::Statement::fetchResult() { + MYSQL_STMT* raw = stmt.get(); + if (mysql_stmt_store_result(raw) != 0) + throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here + + MYSQL_RES* meta = mysql_stmt_result_metadata(raw); + if (meta == nullptr) + throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here + + std::unique_ptr mt(meta); + unsigned int numColumns = mysql_num_fields(meta); + MYSQL_BIND bind[numColumns]; + memset(bind, 0, sizeof(bind)); + + std::vector line(numColumns); + std::vector lengths(numColumns); + for (unsigned int i = 0; i < numColumns; ++i) { + MYSQL_FIELD *field = mysql_fetch_field_direct(meta, i); + + switch (field->type) { + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_VARCHAR: + break; + default: + throw std::runtime_error("Unsupported data fetching statement result " + std::to_string(field->type)); + } + line[i].resize(field->length); + bind[i].buffer_type = field->type; + bind[i].buffer = line[i].data(); + bind[i].buffer_length = field->length; + bind[i].length = &lengths[i]; + } + + if (mysql_stmt_bind_result(raw, bind) != 0) + throw std::runtime_error(std::string("Error binding on fetching statement result: ") + mysql_stmt_error(raw)); + + std::vector> result; + while (mysql_stmt_fetch(raw) == 0) { + std::vector& row = result.emplace_back(numColumns); + for (unsigned int i = 0; i < numColumns; ++i) + row[i] = std::string(line[i].data(), lengths[i]); + } + + return result; +} + diff --git a/database/mysql/statement.h b/database/mysql/statement.h index 3dca291..3209155 100644 --- a/database/mysql/statement.h +++ b/database/mysql/statement.h @@ -18,6 +18,7 @@ public: void bind(void* value, enum_field_types type, bool usigned = false); void execute(); + std::vector> fetchResult(); private: std::unique_ptr stmt; diff --git a/database/mysql/transaction.cpp b/database/mysql/transaction.cpp index ff921cc..1a9b92c 100644 --- a/database/mysql/transaction.cpp +++ b/database/mysql/transaction.cpp @@ -1,3 +1,6 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + #include "transaction.h" MySQL::Transaction::Transaction(MYSQL* connection): diff --git a/database/mysql/transaction.h b/database/mysql/transaction.h index 87e6713..5a6a15c 100644 --- a/database/mysql/transaction.h +++ b/database/mysql/transaction.h @@ -1,3 +1,6 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + #pragma once #include "mysql.h" diff --git a/handler/CMakeLists.txt b/handler/CMakeLists.txt index 1f58f4f..07b134a 100644 --- a/handler/CMakeLists.txt +++ b/handler/CMakeLists.txt @@ -3,6 +3,7 @@ set(HEADERS info.h env.h register.h + login.h ) set(SOURCES @@ -10,6 +11,7 @@ set(SOURCES info.cpp env.cpp register.cpp + login.cpp ) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) diff --git a/handler/env.cpp b/handler/env.cpp index a615e17..08910c2 100644 --- a/handler/env.cpp +++ b/handler/env.cpp @@ -11,7 +11,7 @@ void Handler::Env::handle(Request& request) { nlohmann::json body = nlohmann::json::object(); request.printEnvironment(body); - Response res(request); + Response& res = request.createResponse(); res.setBody(body); res.send(); } diff --git a/handler/env.h b/handler/env.h index b4a5ea4..cb361f0 100644 --- a/handler/env.h +++ b/handler/env.h @@ -7,10 +7,10 @@ namespace Handler { -class Env : public Handler::Handler { +class Env : public Handler { public: Env(); - virtual void handle(Request& request); + void handle(Request& request) override; }; } diff --git a/handler/info.cpp b/handler/info.cpp index ba2dfec..3a65cd7 100644 --- a/handler/info.cpp +++ b/handler/info.cpp @@ -8,7 +8,7 @@ Handler::Info::Info(): {} void Handler::Info::handle(Request& request) { - Response res(request); + Response& res = request.createResponse(); nlohmann::json body = nlohmann::json::object(); body["type"] = PROJECT_NAME; body["version"] = PROJECT_VERSION; diff --git a/handler/login.cpp b/handler/login.cpp new file mode 100644 index 0000000..f09ecec --- /dev/null +++ b/handler/login.cpp @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + +#include "login.h" + +#include "server/server.h" +#include "database/exceptions.h" + +Handler::Login::Login(Server* server): + Handler("login", Request::Method::post), + server(server) +{} + +void Handler::Login::handle(Request& request) { + std::map form = request.getForm(); + std::map::const_iterator itr = form.find("login"); + if (itr == form.end()) + return error(request, Result::noLogin, Response::Status::badRequest); + + const std::string& login = itr->second; + if (login.empty()) + return error(request, Result::emptyLogin, Response::Status::badRequest); + + itr = form.find("password"); + if (itr == form.end()) + return error(request, Result::noPassword, Response::Status::badRequest); + + const std::string& password = itr->second; + if (password.empty()) + return error(request, Result::emptyPassword, Response::Status::badRequest); + + bool success = false; + try { + success = server->validatePassword(login, password); + } catch (const DBInterface::NoLogin& e) { + std::cerr << "Exception on registration:\n\t" << e.what() << std::endl; + return error(request, Result::noLogin, Response::Status::badRequest); //can send unauthed instead, to exclude login spoofing + } catch (const std::exception& e) { + std::cerr << "Exception on registration:\n\t" << e.what() << std::endl; + return error(request, Result::unknownError, Response::Status::internalError); + } catch (...) { + std::cerr << "Unknown exception on registration" << std::endl; + return error(request, Result::unknownError, Response::Status::internalError); + } + if (!success) + return error(request, Result::noLogin, Response::Status::badRequest); + + //TODO opening the session + + Response& res = request.createResponse(); + nlohmann::json body = nlohmann::json::object(); + body["result"] = Result::success; + + res.setBody(body); + res.send(); +} + +void Handler::Login::error(Request& request, Result result, Response::Status code) { + Response& res = request.createResponse(code); + nlohmann::json body = nlohmann::json::object(); + body["result"] = result; + + res.setBody(body); + res.send(); +} diff --git a/handler/login.h b/handler/login.h new file mode 100644 index 0000000..ef0986e --- /dev/null +++ b/handler/login.h @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: 2023 Yury Gubich +// SPDX-License-Identifier: GPL-3.0-or-later + +#pragma once + +#include "handler.h" + +class Server; +namespace Handler { + +class Login : public Handler { +public: + Login(Server* server); + void handle(Request& request) override; + + enum class Result { + success, + noLogin, + emptyLogin, + noPassword, + emptyPassword, + unknownError + }; + +private: + void error(Request& request, Result result, Response::Status code); + +private: + Server* server; +}; +} + diff --git a/handler/register.cpp b/handler/register.cpp index c36e8d4..57fefb9 100644 --- a/handler/register.cpp +++ b/handler/register.cpp @@ -4,6 +4,7 @@ #include "register.h" #include "server/server.h" +#include "database/exceptions.h" Handler::Register::Register(Server* server): Handler("register", Request::Method::post), @@ -14,35 +15,38 @@ void Handler::Register::handle(Request& request) { std::map form = request.getForm(); std::map::const_iterator itr = form.find("login"); if (itr == form.end()) - return error(request, Result::noLogin); + return error(request, Result::noLogin, Response::Status::badRequest); const std::string& login = itr->second; if (login.empty()) - return error(request, Result::emptyLogin); + return error(request, Result::emptyLogin, Response::Status::badRequest); //TODO login policies checkup itr = form.find("password"); if (itr == form.end()) - return error(request, Result::noPassword); + return error(request, Result::noPassword, Response::Status::badRequest); const std::string& password = itr->second; if (password.empty()) - return error(request, Result::emptyPassword); + return error(request, Result::emptyPassword, Response::Status::badRequest); //TODO password policies checkup try { server->registerAccount(login, password); + } catch (const DBInterface::DuplicateLogin& e) { + std::cerr << "Exception on registration:\n\t" << e.what() << std::endl; + return error(request, Result::loginExists, Response::Status::conflict); } catch (const std::exception& e) { std::cerr << "Exception on registration:\n\t" << e.what() << std::endl; - return error(request, Result::unknownError); - } catch (...) { + return error(request, Result::unknownError, Response::Status::internalError); + } catch (...) { std::cerr << "Unknown exception on registration" << std::endl; - return error(request, Result::unknownError); + return error(request, Result::unknownError, Response::Status::internalError); } - Response res(request); + Response& res = request.createResponse(); nlohmann::json body = nlohmann::json::object(); body["result"] = Result::success; @@ -50,8 +54,8 @@ void Handler::Register::handle(Request& request) { res.send(); } -void Handler::Register::error(Request& request, Result result) { - Response res(request); +void Handler::Register::error(Request& request, Result result, Response::Status code) { + Response& res = request.createResponse(code); nlohmann::json body = nlohmann::json::object(); body["result"] = result; diff --git a/handler/register.h b/handler/register.h index 6aedb2c..6e36655 100644 --- a/handler/register.h +++ b/handler/register.h @@ -8,10 +8,10 @@ class Server; namespace Handler { -class Register : public Handler::Handler { +class Register : public Handler { public: Register(Server* server); - virtual void handle(Request& request); + void handle(Request& request) override; enum class Result { success, @@ -26,7 +26,7 @@ public: }; private: - void error(Request& request, Result result); + void error(Request& request, Result result, Response::Status code); private: Server* server; diff --git a/request/request.cpp b/request/request.cpp index 2859e00..a3e8701 100644 --- a/request/request.cpp +++ b/request/request.cpp @@ -5,8 +5,6 @@ #include "response/response.h" -constexpr static const char* GET("GET"); - constexpr static const char* REQUEST_METHOD("REQUEST_METHOD"); constexpr static const char* SCRIPT_FILENAME("SCRIPT_FILENAME"); constexpr static const char* SERVER_NAME("SERVER_NAME"); @@ -51,11 +49,15 @@ void Request::terminate() { } } -Request::Method Request::method() const { +std::string_view Request::methodName() const { if (state == State::initial) throw std::runtime_error("An attempt to read request method on not accepted request"); - std::string_view method(FCGX_GetParam(REQUEST_METHOD, raw.envp)); + return FCGX_GetParam(REQUEST_METHOD, raw.envp); +} + +Request::Method Request::method() const { + std::string_view method = methodName(); for (const auto& pair : methods) { if (pair.first == method) return pair.second; @@ -79,17 +81,42 @@ bool Request::wait(int socketDescriptor) { return result; } -OStream Request::getOutputStream(const Response* response) { - validateResponse(response); +OStream Request::getOutputStream() { return OStream(raw.out); } -OStream Request::getErrorStream(const Response* response) { - validateResponse(response); +OStream Request::getErrorStream() { return OStream(raw.err); } -void Request::responseIsComplete(const Response* response) { +Response& Request::createResponse() { + if (state != State::accepted) + throw std::runtime_error("An attempt create response to the request in the wrong state"); + + response = std::unique_ptr(new Response(*this)); + state = State::responding; + + return *response.get(); +} + +Response& Request::createResponse(Response::Status status) { + if (state != State::accepted) + throw std::runtime_error("An attempt create response to the request in the wrong state"); + + response = std::unique_ptr(new Response(*this, status)); + state = State::responding; + + return *response.get(); +} + +uint16_t Request::responseCode() const { + if (state != State::responded) + throw std::runtime_error("An attempt create read response code on the wrong state"); + + return response->statusCode(); +} + +void Request::responseIsComplete() { switch (state) { case State::initial: throw std::runtime_error("An attempt to mark the request as complete, but it wasn't even accepted yet"); @@ -98,10 +125,6 @@ void Request::responseIsComplete(const Response* response) { throw std::runtime_error("An attempt to mark the request as complete, but it wasn't responded"); break; case State::responding: - if (Request::response != response) - throw std::runtime_error("An attempt to mark the request as complete by the different response who actually started responding"); - - Request::response = nullptr; state = State::responded; break; case State::responded: @@ -110,26 +133,6 @@ void Request::responseIsComplete(const Response* response) { } } -void Request::validateResponse(const Response* response) { - switch (state) { - case State::initial: - throw std::runtime_error("An attempt to request stream while the request wasn't even accepted yet"); - break; - case State::accepted: - Request::response = response; - state = State::responding; - break; - case State::responding: - if (Request::response != response) - throw std::runtime_error("Error handling a request: first time one response started replying, then another continued"); - - break; - case State::responded: - throw std::runtime_error("An attempt to request stream on a request that was already done responding"); - break; - } -} - Request::State Request::currentState() const { return state; } diff --git a/request/request.h b/request/request.h index e28e283..dc2ca3c 100644 --- a/request/request.h +++ b/request/request.h @@ -16,9 +16,10 @@ #include "stream/ostream.h" #include "utils/formdecode.h" +#include "response/response.h" -class Response; class Request { + friend class Response; public: enum class State { initial, @@ -43,26 +44,29 @@ public: bool wait(int socketDescriptor); void terminate(); + Response& createResponse(); + Response& createResponse(Response::Status status); + + uint16_t responseCode() const; Method method() const; + std::string_view methodName() const; State currentState() const; bool isFormUrlEncoded() const; unsigned int contentLength() const; std::map getForm() const; - OStream getOutputStream(const Response* response); - OStream getErrorStream(const Response* response); - void responseIsComplete(const Response* response); - std::string getPath(const std::string& serverName) const; std::string getServerName() const; void printEnvironment(std::ostream& out); void printEnvironment(nlohmann::json& out); private: - void validateResponse(const Response* response); + OStream getOutputStream(); + OStream getErrorStream(); + void responseIsComplete(); private: State state; FCGX_Request raw; - const Response* response; + std::unique_ptr response; }; diff --git a/response/response.cpp b/response/response.cpp index 7d3d394..e9a483a 100644 --- a/response/response.cpp +++ b/response/response.cpp @@ -3,10 +3,25 @@ #include "response.h" -constexpr std::array(Response::Status::__size)> statusCodes = { +#include "request/request.h" + +constexpr std::array(Response::Status::__size)> statusCodes = { + 200, + 400, + 401, + 404, + 405, + 409, + 500 +}; + +constexpr std::array(Response::Status::__size)> statuses = { "Status: 200 OK", + "Status: 400 Bad Request", + "Status: 401 Unauthorized", "Status: 404 Not Found", "Status: 405 Method Not Allowed", + "Status: 409 Conflict", "Status: 500 Internal Error" }; @@ -33,9 +48,9 @@ void Response::send() const { // OStream out = status == Status::ok ? // request.getOutputStream() : // request.getErrorStream(); - OStream out = request.getOutputStream(this); + OStream out = request.getOutputStream(); - out << statusCodes[static_cast(status)]; + out << statuses[static_cast(status)]; if (!body.empty()) out << '\n' << contentTypes[static_cast(type)] @@ -43,7 +58,11 @@ void Response::send() const { << '\n' << body; - request.responseIsComplete(this); + request.responseIsComplete(); +} + +uint16_t Response::statusCode() const { + return statusCodes[static_cast(status)]; } void Response::setBody(const std::string& body) { diff --git a/response/response.h b/response/response.h index 946ecbb..bb18c5a 100644 --- a/response/response.h +++ b/response/response.h @@ -9,15 +9,20 @@ #include -#include "request/request.h" #include "stream/ostream.h" +class Request; class Response { + friend class Request; + public: enum class Status { ok, + badRequest, + unauthorized, notFound, methodNotAllowed, + conflict, internalError, __size }; @@ -27,13 +32,17 @@ public: json, __size }; - Response(Request& request); - Response(Request& request, Status status); + + uint16_t statusCode() const; void send() const; void setBody(const std::string& body); void setBody(const nlohmann::json& body); +private: + Response(Request& request); + Response(Request& request, Status status); + private: Request& request; Status status; diff --git a/server/router.cpp b/server/router.cpp index 1c7b10d..93c459c 100644 --- a/server/router.cpp +++ b/server/router.cpp @@ -50,29 +50,29 @@ void Router::route(const std::string& path, std::unique_ptr request) { if (request->currentState() != Request::State::responded) handleInternalError(path, std::runtime_error("handler failed to handle the request"), std::move(request)); else - std::cout << "Success:\t" << path << std::endl; + std::cout << request->responseCode() << '\t' << request->methodName() << '\t' << path << std::endl; } catch (const std::exception& e) { handleInternalError(path, e, std::move(request)); } } void Router::handleNotFound(const std::string& path, std::unique_ptr request) { - Response notFound(*request.get(), Response::Status::notFound); + Response& notFound = request->createResponse(Response::Status::notFound); notFound.setBody(std::string("Path \"") + path + "\" was not found"); notFound.send(); - std::cerr << "Not found:\t" << path << std::endl; + std::cerr << notFound.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl; } void Router::handleInternalError(const std::string& path, const std::exception& exception, std::unique_ptr request) { - Response error(*request.get(), Response::Status::internalError); + Response& error = request->createResponse(Response::Status::internalError); error.setBody(std::string(exception.what())); error.send(); - std::cerr << "Internal error:\t" << path << "\n\t" << exception.what() << std::endl; + std::cerr << error.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl; } void Router::handleMethodNotAllowed(const std::string& path, std::unique_ptr request) { - Response error(*request.get(), Response::Status::methodNotAllowed); + Response& error = request->createResponse(Response::Status::methodNotAllowed); error.setBody(std::string("Method not allowed")); error.send(); - std::cerr << "Method not allowed:\t" << path << std::endl; + std::cerr << error.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl; } diff --git a/server/server.cpp b/server/server.cpp index 6f02de1..a605012 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -8,6 +8,7 @@ #include "handler/info.h" #include "handler/env.h" #include "handler/register.h" +#include "handler/login.h" constexpr const char* pepper = "well, not much of a secret, huh?"; constexpr uint8_t currentDbVesion = 1; @@ -39,6 +40,7 @@ Server::Server(): router.addRoute(std::make_unique()); router.addRoute(std::make_unique()); router.addRoute(std::make_unique(this)); + router.addRoute(std::make_unique(this)); } Server::~Server() {} @@ -63,7 +65,7 @@ void Server::handleRequest(std::unique_ptr request) { std::cout << "received server name " << serverName.value() << std::endl; } catch (...) { std::cerr << "failed to read server name" << std::endl; - Response error(*request.get(), Response::Status::internalError); + Response& error = request->createResponse(Response::Status::internalError); error.send(); return; } @@ -107,3 +109,19 @@ unsigned int Server::registerAccount(const std::string& login, const std::string return db->registerAccount(login, hash); } + +bool Server::validatePassword(const std::string& login, const std::string& password) { + std::string hash = db->getAccountHash(login); + + std::string spiced = password + pepper; + int result = argon2id_verify(hash.data(), spiced.data(), spiced.size()); + + switch (result) { + case ARGON2_OK: + return true; + case ARGON2_VERIFY_MISMATCH: + return false; + default: + throw std::runtime_error(std::string("Failed to verify password: ") + argon2_error_message(result)); + } +} diff --git a/server/server.h b/server/server.h index 5a5fb1b..4291ee9 100644 --- a/server/server.h +++ b/server/server.h @@ -32,6 +32,7 @@ public: void run(int socketDescriptor); unsigned int registerAccount(const std::string& login, const std::string& password); + bool validatePassword(const std::string& login, const std::string& password); private: void handleRequest(std::unique_ptr request);