diff --git a/CMakeLists.txt b/CMakeLists.txt index bca2822..183ff6d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,7 @@ message("Compile options: " ${COMPILE_OPTIONS_STRING}) find_package(nlohmann_json REQUIRED) find_package(FCGI REQUIRED) +find_package(Argon2 REQUIRED) add_executable(${PROJECT_NAME} main.cpp) target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) @@ -59,6 +60,7 @@ target_link_libraries(${PROJECT_NAME} PRIVATE FCGI::FCGI FCGI::FCGI++ nlohmann_json::nlohmann_json + Argon2::Argon2 ) install(TARGETS ${PROJECT_NAME} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/README.md b/README.md index e65dfcc..4c48ad8 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ - fcgi - nlohmann_json - mariadb-client +- argon2 ### Building diff --git a/cmake/FindArgon2.cmake b/cmake/FindArgon2.cmake new file mode 100644 index 0000000..75d70a6 --- /dev/null +++ b/cmake/FindArgon2.cmake @@ -0,0 +1,26 @@ +find_library(Argon2_LIBRARIES argon2) +find_path(Argon2_INCLUDE_DIR argon2.h) + +if (Argon2_LIBRARIES AND Argon2_INCLUDE_DIR) + set(Argon2_FOUND TRUE) +endif() + +if (Argon2_FOUND) + add_library(Argon2::Argon2 SHARED IMPORTED) + set_target_properties(Argon2::Argon2 PROPERTIES + IMPORTED_LOCATION "${Argon2_LIBRARIES}" + INTERFACE_LINK_LIBRARIES "${Argon2_LIBRARIES}" + INTERFACE_INCLUDE_DIRECTORIES ${Argon2_INCLUDE_DIR} + ) + + if (NOT Argon2_FIND_QUIETLY) + message(STATUS "Found Argon2 includes: ${Argon2_INCLUDE_DIR}") + message(STATUS "Found Argon2 library: ${Argon2_LIBRARIES}") + endif () +else () + if (Argon2_FIND_REQUIRED) + message(FATAL_ERROR "Could NOT find Argon2 development files") + endif () +endif () + + diff --git a/database/dbinterface.h b/database/dbinterface.h index 839672d..ea03191 100644 --- a/database/dbinterface.h +++ b/database/dbinterface.h @@ -36,6 +36,8 @@ public: virtual uint8_t getVersion() = 0; virtual void setVersion(uint8_t version) = 0; + virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0; + protected: DBInterface(Type type); diff --git a/database/migrations/m0.sql b/database/migrations/m0.sql index 53b7028..e52e9ef 100644 --- a/database/migrations/m0.sql +++ b/database/migrations/m0.sql @@ -17,17 +17,36 @@ CREATE TABLE IF NOT EXISTS accounts ( `login` VARCHAR(256) UNIQUE NOT NULL, `nick` VARCHAR(256), `type` INTEGER UNSIGNED NOT NULL, - `password` VARCHAR(64), - `salt` VARCHAR(32), - `role` INTEGER UNSIGNED NOT NULL, - `created` TIMESTAMP DEFAULT UTC_TIMESTAMP(), + `password` VARCHAR(128), + `created` TIMESTAMP DEFAULT UTC_TIMESTAMP() +); +--creating role bindings table +CREATE TABLE IF NOT EXISTS roleBindings ( + `account` INTEGER UNSIGNED NOT NULL, + `role` INTEGER UNSIGNED NOT NULL, + + PRIMARY KEY (account, role), + FOREIGN KEY (account) REFERENCES accounts(id), FOREIGN KEY (role) REFERENCES roles(id) ); +--creating sessings table +CREATE TABLE IF NOT EXISTS sessions ( + `id` INTEGER AUTO_INCREMENT PRIMARY KEY, + `owner` INTEGER UNSIGNED NOT NULL, + `started` TIMESTAMP DEFAULT UTC_TIMESTAMP(), + `latest` TIMESTAMP DEFAULT UTC_TIMESTAMP(), + `salt` CHAR(16), + `persist` BOOLEAN NOT NULL, + + FOREIGN KEY (owner) REFERENCES accounts(id) +); + --creating defailt roles INSERT IGNORE INTO roles (`name`) -VALUES ('root'); +VALUES ('root'), + ('default'); --inserting initial version INSERT INTO system (`key`, `value`) VALUES ('version', '0'); diff --git a/database/mysql/CMakeLists.txt b/database/mysql/CMakeLists.txt index 64d1920..b1cec73 100644 --- a/database/mysql/CMakeLists.txt +++ b/database/mysql/CMakeLists.txt @@ -1,11 +1,13 @@ set(HEADERS mysql.h statement.h + transaction.h ) set(SOURCES mysql.cpp statement.cpp + transaction.cpp ) find_package(MariaDB REQUIRED) diff --git a/database/mysql/mysql.cpp b/database/mysql/mysql.cpp index 0532f7c..5d781f7 100644 --- a/database/mysql/mysql.cpp +++ b/database/mysql/mysql.cpp @@ -9,8 +9,13 @@ #include "mysqld_error.h" #include "statement.h" +#include "transaction.h" +constexpr const char* versionQuery = "SELECT value FROM system WHERE `key` = 'version'"; constexpr const char* updateQuery = "UPDATE system SET `value` = ? WHERE `key` = 'version'"; +constexpr const char* registerQuery = "INSERT INTO accounts (`login`, `type`, `password`) VALUES (?, 1, ?)"; +constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id"; +constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?"; static const std::filesystem::path buildSQLPath = "database"; @@ -34,7 +39,6 @@ MySQL::~MySQL() { mysql_close(&connection); } - void MySQL::connect(const std::string& path) { if (state != State::disconnected) return; @@ -138,7 +142,7 @@ void MySQL::executeFile(const std::filesystem::path& relativePath) { uint8_t MySQL::getVersion() { MYSQL* con = &connection; - int result = mysql_query(con, "SELECT value FROM system WHERE `key` = 'version'"); + int result = mysql_query(con, versionQuery); if (result != 0) { unsigned int errcode = mysql_errno(con); @@ -216,3 +220,48 @@ std::optional MySQL::getComment(std::string& string) { return std::nullopt; } +unsigned int MySQL::registerAccount(const std::string& login, const std::string& hash) { + //TODO validate filed lengths! + MYSQL* con = &connection; + MySQL::Transaction txn(con); + + Statement addAcc(con, registerQuery); + + std::string l = login; //I hate copying just to please this horible API + std::string h = hash; + addAcc.bind(l.data(), MYSQL_TYPE_STRING); + addAcc.bind(h.data(), MYSQL_TYPE_STRING); + addAcc.execute(); + + unsigned int id = lastInsertedId(); + static std::string defaultRole("default"); + + Statement addRole(con, assignRoleQuery); + addRole.bind(&id, MYSQL_TYPE_LONG, true); + addRole.bind(defaultRole.data(), MYSQL_TYPE_STRING); + addRole.execute(); + + txn.commit(); + return id; +} + +unsigned int MySQL::lastInsertedId() { + MYSQL* con = &connection; + int result = mysql_query(con, lastIdQuery); + + if (result != 0) + throw std::runtime_error(std::string("Error executing last inserted id: ") + mysql_error(con)); + + std::unique_ptr res(mysql_store_result(con)); + if (!res) + throw std::runtime_error(std::string("Querying last inserted id returned no result: ") + mysql_error(con)); + + MYSQL_ROW row = mysql_fetch_row(res.get()); + if (row) + return std::stoi(row[0]); + else + throw std::runtime_error(std::string("Querying last inserted id returned no rows")); +} + + + diff --git a/database/mysql/mysql.h b/database/mysql/mysql.h index 007c985..4246e51 100644 --- a/database/mysql/mysql.h +++ b/database/mysql/mysql.h @@ -14,6 +14,7 @@ class MySQL : public DBInterface { class Statement; + class Transaction; public: MySQL(); ~MySQL() override; @@ -27,9 +28,12 @@ public: uint8_t getVersion() override; void setVersion(uint8_t version) override; + unsigned int registerAccount(const std::string& login, const std::string& hash) override; + private: void executeFile(const std::filesystem::path& relativePath); static std::optional getComment(std::string& string); + unsigned int lastInsertedId(); protected: MYSQL connection; diff --git a/database/mysql/statement.cpp b/database/mysql/statement.cpp index 44c7592..5cbf482 100644 --- a/database/mysql/statement.cpp +++ b/database/mysql/statement.cpp @@ -9,15 +9,14 @@ static uint64_t TIME_LENGTH = sizeof(MYSQL_TIME); MySQL::Statement::Statement(MYSQL* connection, const char* statement): stmt(mysql_stmt_init(connection)), - param(), - lengths() + param() { int result = mysql_stmt_prepare(stmt.get(), statement, strlen(statement)); if (result != 0) throw std::runtime_error(std::string("Error preparing statement: ") + mysql_stmt_error(stmt.get())); } -void MySQL::Statement::bind(void* value, enum_field_types type) { +void MySQL::Statement::bind(void* value, enum_field_types type, bool usigned) { MYSQL_BIND& result = param.emplace_back(); std::memset(&result, 0, sizeof(result)); @@ -27,13 +26,18 @@ void MySQL::Statement::bind(void* value, enum_field_types type) { switch (type) { case MYSQL_TYPE_STRING: case MYSQL_TYPE_VAR_STRING: - result.length = &lengths.emplace_back(strlen(static_cast(value))); + result.buffer_length = strlen(static_cast(value)); break; case MYSQL_TYPE_DATE: - result.length = &TIME_LENGTH; + result.buffer_length = TIME_LENGTH; + break; + case MYSQL_TYPE_LONG: + case MYSQL_TYPE_LONGLONG: + case MYSQL_TYPE_SHORT: + case MYSQL_TYPE_TINY: + result.is_unsigned = usigned; break; default: - lengths.pop_back(); throw std::runtime_error("Type: " + std::to_string(type) + " is not yet supported in bind"); break; } diff --git a/database/mysql/statement.h b/database/mysql/statement.h index 9cfc129..3dca291 100644 --- a/database/mysql/statement.h +++ b/database/mysql/statement.h @@ -7,7 +7,6 @@ #include "mysql.h" - class MySQL::Statement { struct STMTDeleter { void operator () (MYSQL_STMT* stmt) { @@ -17,11 +16,10 @@ class MySQL::Statement { public: Statement(MYSQL* connection, const char* statement); - void bind(void* value, enum_field_types type); + void bind(void* value, enum_field_types type, bool usigned = false); void execute(); private: std::unique_ptr stmt; std::vector param; - std::vector lengths; }; diff --git a/database/mysql/transaction.cpp b/database/mysql/transaction.cpp new file mode 100644 index 0000000..ff921cc --- /dev/null +++ b/database/mysql/transaction.cpp @@ -0,0 +1,34 @@ +#include "transaction.h" + +MySQL::Transaction::Transaction(MYSQL* connection): + con(connection), + opened(false) +{ + if (mysql_autocommit(con, 0) != 0) + throw std::runtime_error(std::string("Failed to start transaction") + mysql_error(con)); + + opened = true; +} + +MySQL::Transaction::~Transaction() { + if (opened) + abort(); +} + +void MySQL::Transaction::commit() { + if (mysql_commit(con) != 0) + throw std::runtime_error(std::string("Failed to commit transaction") + mysql_error(con)); + + opened = false; + if (mysql_autocommit(con, 1) != 0) + throw std::runtime_error(std::string("Failed to return autocommit") + mysql_error(con)); +} + +void MySQL::Transaction::abort() { + opened = false; + if (mysql_rollback(con) != 0) + throw std::runtime_error(std::string("Failed to rollback transaction") + mysql_error(con)); + + if (mysql_autocommit(con, 1) != 0) + throw std::runtime_error(std::string("Failed to return autocommit") + mysql_error(con)); +} diff --git a/database/mysql/transaction.h b/database/mysql/transaction.h new file mode 100644 index 0000000..87e6713 --- /dev/null +++ b/database/mysql/transaction.h @@ -0,0 +1,16 @@ +#pragma once + +#include "mysql.h" + +class MySQL::Transaction { +public: + Transaction(MYSQL* connection); + ~Transaction(); + + void commit(); + void abort(); + +private: + MYSQL* con; + bool opened; +}; diff --git a/handler/register.cpp b/handler/register.cpp index 4928627..c36e8d4 100644 --- a/handler/register.cpp +++ b/handler/register.cpp @@ -3,20 +3,57 @@ #include "register.h" -Handler::Register::Register(): - Handler("register", Request::Method::post) +#include "server/server.h" + +Handler::Register::Register(Server* server): + Handler("register", Request::Method::post), + server(server) {} void Handler::Register::handle(Request& request) { std::map form = request.getForm(); + std::map::const_iterator itr = form.find("login"); + if (itr == form.end()) + return error(request, Result::noLogin); - std::cout << "Received form:" << std::endl; - for (const auto& pair : form) - std::cout << '\t' << pair.first << ": " << pair.second << std::endl; + const std::string& login = itr->second; + if (login.empty()) + return error(request, Result::emptyLogin); + + //TODO login policies checkup + + itr = form.find("password"); + if (itr == form.end()) + return error(request, Result::noPassword); + + const std::string& password = itr->second; + if (password.empty()) + return error(request, Result::emptyPassword); + + //TODO password policies checkup + + try { + server->registerAccount(login, password); + } catch (const std::exception& e) { + std::cerr << "Exception on registration:\n\t" << e.what() << std::endl; + return error(request, Result::unknownError); + } catch (...) { + std::cerr << "Unknown exception on registration" << std::endl; + return error(request, Result::unknownError); + } Response res(request); nlohmann::json body = nlohmann::json::object(); - body["result"] = "ok"; + body["result"] = Result::success; + + res.setBody(body); + res.send(); +} + +void Handler::Register::error(Request& request, Result result) { + Response res(request); + nlohmann::json body = nlohmann::json::object(); + body["result"] = result; res.setBody(body); res.send(); diff --git a/handler/register.h b/handler/register.h index 047dde3..6aedb2c 100644 --- a/handler/register.h +++ b/handler/register.h @@ -5,12 +5,30 @@ #include "handler.h" +class Server; namespace Handler { class Register : public Handler::Handler { public: - Register(); + Register(Server* server); virtual void handle(Request& request); + enum class Result { + success, + noLogin, + emptyLogin, + loginExists, + loginPolicyViolation, + noPassword, + emptyPassword, + passwordPolicyViolation, + unknownError + }; + +private: + void error(Request& request, Result result); + +private: + Server* server; }; } diff --git a/server/router.cpp b/server/router.cpp index 75fe696..1c7b10d 100644 --- a/server/router.cpp +++ b/server/router.cpp @@ -44,6 +44,7 @@ void Router::route(const std::string& path, std::unique_ptr request) { return handleNotFound(path, std::move(request)); try { + std::cout << "Handling " << path << "..." << std::endl; itr->second->handle(*request.get()); if (request->currentState() != Request::State::responded) diff --git a/server/server.cpp b/server/server.cpp index dad6a3a..6f02de1 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -3,11 +3,20 @@ #include "server.h" +#include + #include "handler/info.h" #include "handler/env.h" #include "handler/register.h" +constexpr const char* pepper = "well, not much of a secret, huh?"; constexpr uint8_t currentDbVesion = 1; +constexpr const char* randomChars = "0123456789abcdef"; +constexpr uint8_t saltSize = 16; +constexpr uint8_t hashSize = 32; +constexpr uint8_t hashParallel = 1; +constexpr uint8_t hashIterations = 2; +constexpr uint32_t hashMemoryCost = 65536; Server::Server(): terminating(false), @@ -29,7 +38,7 @@ Server::Server(): router.addRoute(std::make_unique()); router.addRoute(std::make_unique()); - router.addRoute(std::make_unique()); + router.addRoute(std::make_unique(this)); } Server::~Server() {} @@ -63,3 +72,38 @@ void Server::handleRequest(std::unique_ptr request) { std::string path = request->getPath(serverName.value()); router.route(path.data(), std::move(request)); } + +std::string Server::generateRandomString(std::size_t length) { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution distribution(0, std::strlen(randomChars)); + + std::string result(length, 0); + for (size_t i = 0; i < length; ++i) + result[i] = randomChars[distribution(gen)]; + + return result; +} + +unsigned int Server::registerAccount(const std::string& login, const std::string& password) { + std::size_t encSize = argon2_encodedlen( + hashIterations, hashMemoryCost, + hashParallel, saltSize, hashSize, Argon2_id + ); + + std::string hash(encSize, 0); + std::string salt = generateRandomString(saltSize); + std::string spiced = password + pepper; + + int result = argon2id_hash_encoded( + hashIterations, hashMemoryCost, hashParallel, + spiced.data(), spiced.size(), + salt.data(), saltSize, + hashSize, hash.data(), encSize + ); + + if (result != ARGON2_OK) + throw std::runtime_error(std::string("Hashing failed: ") + argon2_error_message(result)); + + return db->registerAccount(login, hash); +} diff --git a/server/server.h b/server/server.h index 13d17d0..5a5fb1b 100644 --- a/server/server.h +++ b/server/server.h @@ -15,7 +15,7 @@ #include #include -#include +#include #include "request/request.h" #include "response/response.h" @@ -31,8 +31,11 @@ public: void run(int socketDescriptor); + unsigned int registerAccount(const std::string& login, const std::string& password); + private: void handleRequest(std::unique_ptr request); + static std::string generateRandomString(std::size_t length); private: bool terminating;