1
0
forked from blue/pica

first what so ever registration

This commit is contained in:
Blue 2023-12-20 19:42:13 -03:00
parent 0c50cfa639
commit 99a9fd507e
Signed by untrusted user: blue
GPG Key ID: 9B203B252A63EE38
17 changed files with 285 additions and 25 deletions

View File

@ -39,6 +39,7 @@ message("Compile options: " ${COMPILE_OPTIONS_STRING})
find_package(nlohmann_json REQUIRED) find_package(nlohmann_json REQUIRED)
find_package(FCGI REQUIRED) find_package(FCGI REQUIRED)
find_package(Argon2 REQUIRED)
add_executable(${PROJECT_NAME} main.cpp) add_executable(${PROJECT_NAME} main.cpp)
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
@ -59,6 +60,7 @@ target_link_libraries(${PROJECT_NAME} PRIVATE
FCGI::FCGI FCGI::FCGI
FCGI::FCGI++ FCGI::FCGI++
nlohmann_json::nlohmann_json nlohmann_json::nlohmann_json
Argon2::Argon2
) )
install(TARGETS ${PROJECT_NAME} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) install(TARGETS ${PROJECT_NAME} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})

View File

@ -7,6 +7,7 @@
- fcgi - fcgi
- nlohmann_json - nlohmann_json
- mariadb-client - mariadb-client
- argon2
### Building ### Building

26
cmake/FindArgon2.cmake Normal file
View File

@ -0,0 +1,26 @@
find_library(Argon2_LIBRARIES argon2)
find_path(Argon2_INCLUDE_DIR argon2.h)
if (Argon2_LIBRARIES AND Argon2_INCLUDE_DIR)
set(Argon2_FOUND TRUE)
endif()
if (Argon2_FOUND)
add_library(Argon2::Argon2 SHARED IMPORTED)
set_target_properties(Argon2::Argon2 PROPERTIES
IMPORTED_LOCATION "${Argon2_LIBRARIES}"
INTERFACE_LINK_LIBRARIES "${Argon2_LIBRARIES}"
INTERFACE_INCLUDE_DIRECTORIES ${Argon2_INCLUDE_DIR}
)
if (NOT Argon2_FIND_QUIETLY)
message(STATUS "Found Argon2 includes: ${Argon2_INCLUDE_DIR}")
message(STATUS "Found Argon2 library: ${Argon2_LIBRARIES}")
endif ()
else ()
if (Argon2_FIND_REQUIRED)
message(FATAL_ERROR "Could NOT find Argon2 development files")
endif ()
endif ()

View File

@ -36,6 +36,8 @@ public:
virtual uint8_t getVersion() = 0; virtual uint8_t getVersion() = 0;
virtual void setVersion(uint8_t version) = 0; virtual void setVersion(uint8_t version) = 0;
virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0;
protected: protected:
DBInterface(Type type); DBInterface(Type type);

View File

@ -17,17 +17,36 @@ CREATE TABLE IF NOT EXISTS accounts (
`login` VARCHAR(256) UNIQUE NOT NULL, `login` VARCHAR(256) UNIQUE NOT NULL,
`nick` VARCHAR(256), `nick` VARCHAR(256),
`type` INTEGER UNSIGNED NOT NULL, `type` INTEGER UNSIGNED NOT NULL,
`password` VARCHAR(64), `password` VARCHAR(128),
`salt` VARCHAR(32), `created` TIMESTAMP DEFAULT UTC_TIMESTAMP()
`role` INTEGER UNSIGNED NOT NULL, );
`created` TIMESTAMP DEFAULT UTC_TIMESTAMP(),
--creating role bindings table
CREATE TABLE IF NOT EXISTS roleBindings (
`account` INTEGER UNSIGNED NOT NULL,
`role` INTEGER UNSIGNED NOT NULL,
PRIMARY KEY (account, role),
FOREIGN KEY (account) REFERENCES accounts(id),
FOREIGN KEY (role) REFERENCES roles(id) FOREIGN KEY (role) REFERENCES roles(id)
); );
--creating sessings table
CREATE TABLE IF NOT EXISTS sessions (
`id` INTEGER AUTO_INCREMENT PRIMARY KEY,
`owner` INTEGER UNSIGNED NOT NULL,
`started` TIMESTAMP DEFAULT UTC_TIMESTAMP(),
`latest` TIMESTAMP DEFAULT UTC_TIMESTAMP(),
`salt` CHAR(16),
`persist` BOOLEAN NOT NULL,
FOREIGN KEY (owner) REFERENCES accounts(id)
);
--creating defailt roles --creating defailt roles
INSERT IGNORE INTO roles (`name`) INSERT IGNORE INTO roles (`name`)
VALUES ('root'); VALUES ('root'),
('default');
--inserting initial version --inserting initial version
INSERT INTO system (`key`, `value`) VALUES ('version', '0'); INSERT INTO system (`key`, `value`) VALUES ('version', '0');

View File

@ -1,11 +1,13 @@
set(HEADERS set(HEADERS
mysql.h mysql.h
statement.h statement.h
transaction.h
) )
set(SOURCES set(SOURCES
mysql.cpp mysql.cpp
statement.cpp statement.cpp
transaction.cpp
) )
find_package(MariaDB REQUIRED) find_package(MariaDB REQUIRED)

View File

@ -9,8 +9,13 @@
#include "mysqld_error.h" #include "mysqld_error.h"
#include "statement.h" #include "statement.h"
#include "transaction.h"
constexpr const char* versionQuery = "SELECT value FROM system WHERE `key` = 'version'";
constexpr const char* updateQuery = "UPDATE system SET `value` = ? WHERE `key` = 'version'"; constexpr const char* updateQuery = "UPDATE system SET `value` = ? WHERE `key` = 'version'";
constexpr const char* registerQuery = "INSERT INTO accounts (`login`, `type`, `password`) VALUES (?, 1, ?)";
constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id";
constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?";
static const std::filesystem::path buildSQLPath = "database"; static const std::filesystem::path buildSQLPath = "database";
@ -34,7 +39,6 @@ MySQL::~MySQL() {
mysql_close(&connection); mysql_close(&connection);
} }
void MySQL::connect(const std::string& path) { void MySQL::connect(const std::string& path) {
if (state != State::disconnected) if (state != State::disconnected)
return; return;
@ -138,7 +142,7 @@ void MySQL::executeFile(const std::filesystem::path& relativePath) {
uint8_t MySQL::getVersion() { uint8_t MySQL::getVersion() {
MYSQL* con = &connection; MYSQL* con = &connection;
int result = mysql_query(con, "SELECT value FROM system WHERE `key` = 'version'"); int result = mysql_query(con, versionQuery);
if (result != 0) { if (result != 0) {
unsigned int errcode = mysql_errno(con); unsigned int errcode = mysql_errno(con);
@ -216,3 +220,48 @@ std::optional<std::string> MySQL::getComment(std::string& string) {
return std::nullopt; return std::nullopt;
} }
unsigned int MySQL::registerAccount(const std::string& login, const std::string& hash) {
//TODO validate filed lengths!
MYSQL* con = &connection;
MySQL::Transaction txn(con);
Statement addAcc(con, registerQuery);
std::string l = login; //I hate copying just to please this horible API
std::string h = hash;
addAcc.bind(l.data(), MYSQL_TYPE_STRING);
addAcc.bind(h.data(), MYSQL_TYPE_STRING);
addAcc.execute();
unsigned int id = lastInsertedId();
static std::string defaultRole("default");
Statement addRole(con, assignRoleQuery);
addRole.bind(&id, MYSQL_TYPE_LONG, true);
addRole.bind(defaultRole.data(), MYSQL_TYPE_STRING);
addRole.execute();
txn.commit();
return id;
}
unsigned int MySQL::lastInsertedId() {
MYSQL* con = &connection;
int result = mysql_query(con, lastIdQuery);
if (result != 0)
throw std::runtime_error(std::string("Error executing last inserted id: ") + mysql_error(con));
std::unique_ptr<MYSQL_RES, ResDeleter> res(mysql_store_result(con));
if (!res)
throw std::runtime_error(std::string("Querying last inserted id returned no result: ") + mysql_error(con));
MYSQL_ROW row = mysql_fetch_row(res.get());
if (row)
return std::stoi(row[0]);
else
throw std::runtime_error(std::string("Querying last inserted id returned no rows"));
}

View File

@ -14,6 +14,7 @@
class MySQL : public DBInterface { class MySQL : public DBInterface {
class Statement; class Statement;
class Transaction;
public: public:
MySQL(); MySQL();
~MySQL() override; ~MySQL() override;
@ -27,9 +28,12 @@ public:
uint8_t getVersion() override; uint8_t getVersion() override;
void setVersion(uint8_t version) override; void setVersion(uint8_t version) override;
unsigned int registerAccount(const std::string& login, const std::string& hash) override;
private: private:
void executeFile(const std::filesystem::path& relativePath); void executeFile(const std::filesystem::path& relativePath);
static std::optional<std::string> getComment(std::string& string); static std::optional<std::string> getComment(std::string& string);
unsigned int lastInsertedId();
protected: protected:
MYSQL connection; MYSQL connection;

View File

@ -9,15 +9,14 @@ static uint64_t TIME_LENGTH = sizeof(MYSQL_TIME);
MySQL::Statement::Statement(MYSQL* connection, const char* statement): MySQL::Statement::Statement(MYSQL* connection, const char* statement):
stmt(mysql_stmt_init(connection)), stmt(mysql_stmt_init(connection)),
param(), param()
lengths()
{ {
int result = mysql_stmt_prepare(stmt.get(), statement, strlen(statement)); int result = mysql_stmt_prepare(stmt.get(), statement, strlen(statement));
if (result != 0) if (result != 0)
throw std::runtime_error(std::string("Error preparing statement: ") + mysql_stmt_error(stmt.get())); throw std::runtime_error(std::string("Error preparing statement: ") + mysql_stmt_error(stmt.get()));
} }
void MySQL::Statement::bind(void* value, enum_field_types type) { void MySQL::Statement::bind(void* value, enum_field_types type, bool usigned) {
MYSQL_BIND& result = param.emplace_back(); MYSQL_BIND& result = param.emplace_back();
std::memset(&result, 0, sizeof(result)); std::memset(&result, 0, sizeof(result));
@ -27,13 +26,18 @@ void MySQL::Statement::bind(void* value, enum_field_types type) {
switch (type) { switch (type) {
case MYSQL_TYPE_STRING: case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING: case MYSQL_TYPE_VAR_STRING:
result.length = &lengths.emplace_back(strlen(static_cast<char*>(value))); result.buffer_length = strlen(static_cast<char*>(value));
break; break;
case MYSQL_TYPE_DATE: case MYSQL_TYPE_DATE:
result.length = &TIME_LENGTH; result.buffer_length = TIME_LENGTH;
break;
case MYSQL_TYPE_LONG:
case MYSQL_TYPE_LONGLONG:
case MYSQL_TYPE_SHORT:
case MYSQL_TYPE_TINY:
result.is_unsigned = usigned;
break; break;
default: default:
lengths.pop_back();
throw std::runtime_error("Type: " + std::to_string(type) + " is not yet supported in bind"); throw std::runtime_error("Type: " + std::to_string(type) + " is not yet supported in bind");
break; break;
} }

View File

@ -7,7 +7,6 @@
#include "mysql.h" #include "mysql.h"
class MySQL::Statement { class MySQL::Statement {
struct STMTDeleter { struct STMTDeleter {
void operator () (MYSQL_STMT* stmt) { void operator () (MYSQL_STMT* stmt) {
@ -17,11 +16,10 @@ class MySQL::Statement {
public: public:
Statement(MYSQL* connection, const char* statement); Statement(MYSQL* connection, const char* statement);
void bind(void* value, enum_field_types type); void bind(void* value, enum_field_types type, bool usigned = false);
void execute(); void execute();
private: private:
std::unique_ptr<MYSQL_STMT, STMTDeleter> stmt; std::unique_ptr<MYSQL_STMT, STMTDeleter> stmt;
std::vector<MYSQL_BIND> param; std::vector<MYSQL_BIND> param;
std::vector<uint64_t> lengths;
}; };

View File

@ -0,0 +1,34 @@
#include "transaction.h"
MySQL::Transaction::Transaction(MYSQL* connection):
con(connection),
opened(false)
{
if (mysql_autocommit(con, 0) != 0)
throw std::runtime_error(std::string("Failed to start transaction") + mysql_error(con));
opened = true;
}
MySQL::Transaction::~Transaction() {
if (opened)
abort();
}
void MySQL::Transaction::commit() {
if (mysql_commit(con) != 0)
throw std::runtime_error(std::string("Failed to commit transaction") + mysql_error(con));
opened = false;
if (mysql_autocommit(con, 1) != 0)
throw std::runtime_error(std::string("Failed to return autocommit") + mysql_error(con));
}
void MySQL::Transaction::abort() {
opened = false;
if (mysql_rollback(con) != 0)
throw std::runtime_error(std::string("Failed to rollback transaction") + mysql_error(con));
if (mysql_autocommit(con, 1) != 0)
throw std::runtime_error(std::string("Failed to return autocommit") + mysql_error(con));
}

View File

@ -0,0 +1,16 @@
#pragma once
#include "mysql.h"
class MySQL::Transaction {
public:
Transaction(MYSQL* connection);
~Transaction();
void commit();
void abort();
private:
MYSQL* con;
bool opened;
};

View File

@ -3,20 +3,57 @@
#include "register.h" #include "register.h"
Handler::Register::Register(): #include "server/server.h"
Handler("register", Request::Method::post)
Handler::Register::Register(Server* server):
Handler("register", Request::Method::post),
server(server)
{} {}
void Handler::Register::handle(Request& request) { void Handler::Register::handle(Request& request) {
std::map form = request.getForm(); std::map form = request.getForm();
std::map<std::string, std::string>::const_iterator itr = form.find("login");
if (itr == form.end())
return error(request, Result::noLogin);
std::cout << "Received form:" << std::endl; const std::string& login = itr->second;
for (const auto& pair : form) if (login.empty())
std::cout << '\t' << pair.first << ": " << pair.second << std::endl; return error(request, Result::emptyLogin);
//TODO login policies checkup
itr = form.find("password");
if (itr == form.end())
return error(request, Result::noPassword);
const std::string& password = itr->second;
if (password.empty())
return error(request, Result::emptyPassword);
//TODO password policies checkup
try {
server->registerAccount(login, password);
} catch (const std::exception& e) {
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
return error(request, Result::unknownError);
} catch (...) {
std::cerr << "Unknown exception on registration" << std::endl;
return error(request, Result::unknownError);
}
Response res(request); Response res(request);
nlohmann::json body = nlohmann::json::object(); nlohmann::json body = nlohmann::json::object();
body["result"] = "ok"; body["result"] = Result::success;
res.setBody(body);
res.send();
}
void Handler::Register::error(Request& request, Result result) {
Response res(request);
nlohmann::json body = nlohmann::json::object();
body["result"] = result;
res.setBody(body); res.setBody(body);
res.send(); res.send();

View File

@ -5,12 +5,30 @@
#include "handler.h" #include "handler.h"
class Server;
namespace Handler { namespace Handler {
class Register : public Handler::Handler { class Register : public Handler::Handler {
public: public:
Register(); Register(Server* server);
virtual void handle(Request& request); virtual void handle(Request& request);
enum class Result {
success,
noLogin,
emptyLogin,
loginExists,
loginPolicyViolation,
noPassword,
emptyPassword,
passwordPolicyViolation,
unknownError
};
private:
void error(Request& request, Result result);
private:
Server* server;
}; };
} }

View File

@ -44,6 +44,7 @@ void Router::route(const std::string& path, std::unique_ptr<Request> request) {
return handleNotFound(path, std::move(request)); return handleNotFound(path, std::move(request));
try { try {
std::cout << "Handling " << path << "..." << std::endl;
itr->second->handle(*request.get()); itr->second->handle(*request.get());
if (request->currentState() != Request::State::responded) if (request->currentState() != Request::State::responded)

View File

@ -3,11 +3,20 @@
#include "server.h" #include "server.h"
#include <random>
#include "handler/info.h" #include "handler/info.h"
#include "handler/env.h" #include "handler/env.h"
#include "handler/register.h" #include "handler/register.h"
constexpr const char* pepper = "well, not much of a secret, huh?";
constexpr uint8_t currentDbVesion = 1; constexpr uint8_t currentDbVesion = 1;
constexpr const char* randomChars = "0123456789abcdef";
constexpr uint8_t saltSize = 16;
constexpr uint8_t hashSize = 32;
constexpr uint8_t hashParallel = 1;
constexpr uint8_t hashIterations = 2;
constexpr uint32_t hashMemoryCost = 65536;
Server::Server(): Server::Server():
terminating(false), terminating(false),
@ -29,7 +38,7 @@ Server::Server():
router.addRoute(std::make_unique<Handler::Info>()); router.addRoute(std::make_unique<Handler::Info>());
router.addRoute(std::make_unique<Handler::Env>()); router.addRoute(std::make_unique<Handler::Env>());
router.addRoute(std::make_unique<Handler::Register>()); router.addRoute(std::make_unique<Handler::Register>(this));
} }
Server::~Server() {} Server::~Server() {}
@ -63,3 +72,38 @@ void Server::handleRequest(std::unique_ptr<Request> request) {
std::string path = request->getPath(serverName.value()); std::string path = request->getPath(serverName.value());
router.route(path.data(), std::move(request)); router.route(path.data(), std::move(request));
} }
std::string Server::generateRandomString(std::size_t length) {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<uint8_t> distribution(0, std::strlen(randomChars));
std::string result(length, 0);
for (size_t i = 0; i < length; ++i)
result[i] = randomChars[distribution(gen)];
return result;
}
unsigned int Server::registerAccount(const std::string& login, const std::string& password) {
std::size_t encSize = argon2_encodedlen(
hashIterations, hashMemoryCost,
hashParallel, saltSize, hashSize, Argon2_id
);
std::string hash(encSize, 0);
std::string salt = generateRandomString(saltSize);
std::string spiced = password + pepper;
int result = argon2id_hash_encoded(
hashIterations, hashMemoryCost, hashParallel,
spiced.data(), spiced.size(),
salt.data(), saltSize,
hashSize, hash.data(), encSize
);
if (result != ARGON2_OK)
throw std::runtime_error(std::string("Hashing failed: ") + argon2_error_message(result));
return db->registerAccount(login, hash);
}

View File

@ -15,7 +15,7 @@
#include <fcgio.h> #include <fcgio.h>
#include <stdint.h> #include <stdint.h>
#include <nlohmann/json.hpp> #include <argon2.h>
#include "request/request.h" #include "request/request.h"
#include "response/response.h" #include "response/response.h"
@ -31,8 +31,11 @@ public:
void run(int socketDescriptor); void run(int socketDescriptor);
unsigned int registerAccount(const std::string& login, const std::string& password);
private: private:
void handleRequest(std::unique_ptr<Request> request); void handleRequest(std::unique_ptr<Request> request);
static std::string generateRandomString(std::size_t length);
private: private:
bool terminating; bool terminating;