1
0
forked from blue/pica

password hash cheching

This commit is contained in:
Blue 2023-12-22 20:25:20 -03:00
parent 99a9fd507e
commit 534c282226
Signed by untrusted user: blue
GPG Key ID: 9B203B252A63EE38
25 changed files with 390 additions and 84 deletions

View File

@ -1,9 +1,11 @@
set(HEADERS set(HEADERS
dbinterface.h dbinterface.h
exceptions.h
) )
set(SOURCES set(SOURCES
dbinterface.cpp dbinterface.cpp
exceptions.cpp
) )
target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES})

View File

@ -26,6 +26,11 @@ public:
const Type type; const Type type;
class Duplicate;
class DuplicateLogin;
class EmptyResult;
class NoLogin;
public: public:
virtual void connect(const std::string& path) = 0; virtual void connect(const std::string& path) = 0;
virtual void disconnect() = 0; virtual void disconnect() = 0;
@ -37,6 +42,7 @@ public:
virtual void setVersion(uint8_t version) = 0; virtual void setVersion(uint8_t version) = 0;
virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0; virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0;
virtual std::string getAccountHash(const std::string& login) = 0;
protected: protected:
DBInterface(Type type); DBInterface(Type type);

20
database/exceptions.cpp Normal file
View File

@ -0,0 +1,20 @@
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
// SPDX-License-Identifier: GPL-3.0-or-later
#include "exceptions.h"
DBInterface::Duplicate::Duplicate(const std::string& text):
std::runtime_error(text)
{}
DBInterface::DuplicateLogin::DuplicateLogin(const std::string& text):
Duplicate(text)
{}
DBInterface::EmptyResult::EmptyResult(const std::string& text):
std::runtime_error(text)
{}
DBInterface::NoLogin::NoLogin(const std::string& text):
EmptyResult(text)
{}

26
database/exceptions.h Normal file
View File

@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
// SPDX-License-Identifier: GPL-3.0-or-later
#pragma once
#include "dbinterface.h"
class DBInterface::Duplicate : public std::runtime_error {
public:
explicit Duplicate(const std::string& text);
};
class DBInterface::DuplicateLogin : public DBInterface::Duplicate {
public:
explicit DuplicateLogin(const std::string& text);
};
class DBInterface::EmptyResult : public std::runtime_error {
public:
explicit EmptyResult(const std::string& text);
};
class DBInterface::NoLogin : public DBInterface::EmptyResult {
public:
explicit NoLogin(const std::string& text);
};

View File

@ -10,21 +10,17 @@
#include "statement.h" #include "statement.h"
#include "transaction.h" #include "transaction.h"
#include "database/exceptions.h"
constexpr const char* versionQuery = "SELECT value FROM system WHERE `key` = 'version'"; constexpr const char* versionQuery = "SELECT value FROM system WHERE `key` = 'version'";
constexpr const char* updateQuery = "UPDATE system SET `value` = ? WHERE `key` = 'version'"; constexpr const char* updateQuery = "UPDATE system SET `value` = ? WHERE `key` = 'version'";
constexpr const char* registerQuery = "INSERT INTO accounts (`login`, `type`, `password`) VALUES (?, 1, ?)"; constexpr const char* registerQuery = "INSERT INTO accounts (`login`, `type`, `password`) VALUES (?, 1, ?)";
constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id"; constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id";
constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?"; constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?";
constexpr const char* selectHash = "SELECT password FROM accounts where login = ?";
static const std::filesystem::path buildSQLPath = "database"; static const std::filesystem::path buildSQLPath = "database";
struct ResDeleter {
void operator () (MYSQL_RES* res) {
mysql_free_result(res);
}
};
MySQL::MySQL(): MySQL::MySQL():
DBInterface(Type::mysql), DBInterface(Type::mysql),
connection(), connection(),
@ -231,7 +227,11 @@ unsigned int MySQL::registerAccount(const std::string& login, const std::string&
std::string h = hash; std::string h = hash;
addAcc.bind(l.data(), MYSQL_TYPE_STRING); addAcc.bind(l.data(), MYSQL_TYPE_STRING);
addAcc.bind(h.data(), MYSQL_TYPE_STRING); addAcc.bind(h.data(), MYSQL_TYPE_STRING);
addAcc.execute(); try {
addAcc.execute();
} catch (const Duplicate& dup) {
throw DuplicateLogin(dup.what());
}
unsigned int id = lastInsertedId(); unsigned int id = lastInsertedId();
static std::string defaultRole("default"); static std::string defaultRole("default");
@ -245,6 +245,24 @@ unsigned int MySQL::registerAccount(const std::string& login, const std::string&
return id; return id;
} }
std::string MySQL::getAccountHash(const std::string& login) {
std::string l = login;
MYSQL* con = &connection;
Statement getHash(con, selectHash);
getHash.bind(l.data(), MYSQL_TYPE_STRING);
getHash.execute();
std::vector<std::vector<std::string>> result = getHash.fetchResult();
if (result.empty())
throw NoLogin("Couldn't find login " + l);
if (result[0].empty())
throw std::runtime_error("Error with the query \"selectHash\"");
return result[0][0];
}
unsigned int MySQL::lastInsertedId() { unsigned int MySQL::lastInsertedId() {
MYSQL* con = &connection; MYSQL* con = &connection;
int result = mysql_query(con, lastIdQuery); int result = mysql_query(con, lastIdQuery);

View File

@ -15,6 +15,8 @@
class MySQL : public DBInterface { class MySQL : public DBInterface {
class Statement; class Statement;
class Transaction; class Transaction;
public: public:
MySQL(); MySQL();
~MySQL() override; ~MySQL() override;
@ -29,6 +31,7 @@ public:
void setVersion(uint8_t version) override; void setVersion(uint8_t version) override;
unsigned int registerAccount(const std::string& login, const std::string& hash) override; unsigned int registerAccount(const std::string& login, const std::string& hash) override;
std::string getAccountHash(const std::string& login) override;
private: private:
void executeFile(const std::filesystem::path& relativePath); void executeFile(const std::filesystem::path& relativePath);
@ -40,4 +43,10 @@ protected:
std::string login; std::string login;
std::string password; std::string password;
std::string database; std::string database;
struct ResDeleter {
void operator () (MYSQL_RES* res) {
mysql_free_result(res);
}
};
}; };

View File

@ -5,6 +5,10 @@
#include <cstring> #include <cstring>
#include "mysqld_error.h"
#include "database/exceptions.h"
static uint64_t TIME_LENGTH = sizeof(MYSQL_TIME); static uint64_t TIME_LENGTH = sizeof(MYSQL_TIME);
MySQL::Statement::Statement(MYSQL* connection, const char* statement): MySQL::Statement::Statement(MYSQL* connection, const char* statement):
@ -44,11 +48,68 @@ void MySQL::Statement::bind(void* value, enum_field_types type, bool usigned) {
} }
void MySQL::Statement::execute() { void MySQL::Statement::execute() {
int result = mysql_stmt_bind_param(stmt.get(), param.data()); MYSQL_STMT* raw = stmt.get();
int result = mysql_stmt_bind_param(raw, param.data());
if (result != 0) if (result != 0)
throw std::runtime_error(std::string("Error binding statement: ") + mysql_stmt_error(stmt.get())); throw std::runtime_error(std::string("Error binding statement: ") + mysql_stmt_error(raw));
result = mysql_stmt_execute(stmt.get()); result = mysql_stmt_execute(raw);
if (result != 0) if (result != 0) {
throw std::runtime_error(std::string("Error executing statement: ") + mysql_stmt_error(stmt.get())); int errcode = mysql_stmt_errno(raw);
std::string text = mysql_stmt_error(raw);
switch (errcode) {
case ER_DUP_ENTRY:
throw Duplicate("Error executing statement: " + text);
default:
throw std::runtime_error("Error executing statement: " + text);
}
}
} }
std::vector<std::vector<std::string>> MySQL::Statement::fetchResult() {
MYSQL_STMT* raw = stmt.get();
if (mysql_stmt_store_result(raw) != 0)
throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here
MYSQL_RES* meta = mysql_stmt_result_metadata(raw);
if (meta == nullptr)
throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here
std::unique_ptr<MYSQL_RES, ResDeleter> mt(meta);
unsigned int numColumns = mysql_num_fields(meta);
MYSQL_BIND bind[numColumns];
memset(bind, 0, sizeof(bind));
std::vector<std::string> line(numColumns);
std::vector<long unsigned int> lengths(numColumns);
for (unsigned int i = 0; i < numColumns; ++i) {
MYSQL_FIELD *field = mysql_fetch_field_direct(meta, i);
switch (field->type) {
case MYSQL_TYPE_STRING:
case MYSQL_TYPE_VAR_STRING:
case MYSQL_TYPE_VARCHAR:
break;
default:
throw std::runtime_error("Unsupported data fetching statement result " + std::to_string(field->type));
}
line[i].resize(field->length);
bind[i].buffer_type = field->type;
bind[i].buffer = line[i].data();
bind[i].buffer_length = field->length;
bind[i].length = &lengths[i];
}
if (mysql_stmt_bind_result(raw, bind) != 0)
throw std::runtime_error(std::string("Error binding on fetching statement result: ") + mysql_stmt_error(raw));
std::vector<std::vector<std::string>> result;
while (mysql_stmt_fetch(raw) == 0) {
std::vector<std::string>& row = result.emplace_back(numColumns);
for (unsigned int i = 0; i < numColumns; ++i)
row[i] = std::string(line[i].data(), lengths[i]);
}
return result;
}

View File

@ -18,6 +18,7 @@ public:
void bind(void* value, enum_field_types type, bool usigned = false); void bind(void* value, enum_field_types type, bool usigned = false);
void execute(); void execute();
std::vector<std::vector<std::string>> fetchResult();
private: private:
std::unique_ptr<MYSQL_STMT, STMTDeleter> stmt; std::unique_ptr<MYSQL_STMT, STMTDeleter> stmt;

View File

@ -1,3 +1,6 @@
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
// SPDX-License-Identifier: GPL-3.0-or-later
#include "transaction.h" #include "transaction.h"
MySQL::Transaction::Transaction(MYSQL* connection): MySQL::Transaction::Transaction(MYSQL* connection):

View File

@ -1,3 +1,6 @@
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
// SPDX-License-Identifier: GPL-3.0-or-later
#pragma once #pragma once
#include "mysql.h" #include "mysql.h"

View File

@ -3,6 +3,7 @@ set(HEADERS
info.h info.h
env.h env.h
register.h register.h
login.h
) )
set(SOURCES set(SOURCES
@ -10,6 +11,7 @@ set(SOURCES
info.cpp info.cpp
env.cpp env.cpp
register.cpp register.cpp
login.cpp
) )
target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES})

View File

@ -11,7 +11,7 @@ void Handler::Env::handle(Request& request) {
nlohmann::json body = nlohmann::json::object(); nlohmann::json body = nlohmann::json::object();
request.printEnvironment(body); request.printEnvironment(body);
Response res(request); Response& res = request.createResponse();
res.setBody(body); res.setBody(body);
res.send(); res.send();
} }

View File

@ -7,10 +7,10 @@
namespace Handler { namespace Handler {
class Env : public Handler::Handler { class Env : public Handler {
public: public:
Env(); Env();
virtual void handle(Request& request); void handle(Request& request) override;
}; };
} }

View File

@ -8,7 +8,7 @@ Handler::Info::Info():
{} {}
void Handler::Info::handle(Request& request) { void Handler::Info::handle(Request& request) {
Response res(request); Response& res = request.createResponse();
nlohmann::json body = nlohmann::json::object(); nlohmann::json body = nlohmann::json::object();
body["type"] = PROJECT_NAME; body["type"] = PROJECT_NAME;
body["version"] = PROJECT_VERSION; body["version"] = PROJECT_VERSION;

65
handler/login.cpp Normal file
View File

@ -0,0 +1,65 @@
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
// SPDX-License-Identifier: GPL-3.0-or-later
#include "login.h"
#include "server/server.h"
#include "database/exceptions.h"
Handler::Login::Login(Server* server):
Handler("login", Request::Method::post),
server(server)
{}
void Handler::Login::handle(Request& request) {
std::map form = request.getForm();
std::map<std::string, std::string>::const_iterator itr = form.find("login");
if (itr == form.end())
return error(request, Result::noLogin, Response::Status::badRequest);
const std::string& login = itr->second;
if (login.empty())
return error(request, Result::emptyLogin, Response::Status::badRequest);
itr = form.find("password");
if (itr == form.end())
return error(request, Result::noPassword, Response::Status::badRequest);
const std::string& password = itr->second;
if (password.empty())
return error(request, Result::emptyPassword, Response::Status::badRequest);
bool success = false;
try {
success = server->validatePassword(login, password);
} catch (const DBInterface::NoLogin& e) {
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
return error(request, Result::noLogin, Response::Status::badRequest); //can send unauthed instead, to exclude login spoofing
} catch (const std::exception& e) {
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
return error(request, Result::unknownError, Response::Status::internalError);
} catch (...) {
std::cerr << "Unknown exception on registration" << std::endl;
return error(request, Result::unknownError, Response::Status::internalError);
}
if (!success)
return error(request, Result::noLogin, Response::Status::badRequest);
//TODO opening the session
Response& res = request.createResponse();
nlohmann::json body = nlohmann::json::object();
body["result"] = Result::success;
res.setBody(body);
res.send();
}
void Handler::Login::error(Request& request, Result result, Response::Status code) {
Response& res = request.createResponse(code);
nlohmann::json body = nlohmann::json::object();
body["result"] = result;
res.setBody(body);
res.send();
}

32
handler/login.h Normal file
View File

@ -0,0 +1,32 @@
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
// SPDX-License-Identifier: GPL-3.0-or-later
#pragma once
#include "handler.h"
class Server;
namespace Handler {
class Login : public Handler {
public:
Login(Server* server);
void handle(Request& request) override;
enum class Result {
success,
noLogin,
emptyLogin,
noPassword,
emptyPassword,
unknownError
};
private:
void error(Request& request, Result result, Response::Status code);
private:
Server* server;
};
}

View File

@ -4,6 +4,7 @@
#include "register.h" #include "register.h"
#include "server/server.h" #include "server/server.h"
#include "database/exceptions.h"
Handler::Register::Register(Server* server): Handler::Register::Register(Server* server):
Handler("register", Request::Method::post), Handler("register", Request::Method::post),
@ -14,35 +15,38 @@ void Handler::Register::handle(Request& request) {
std::map form = request.getForm(); std::map form = request.getForm();
std::map<std::string, std::string>::const_iterator itr = form.find("login"); std::map<std::string, std::string>::const_iterator itr = form.find("login");
if (itr == form.end()) if (itr == form.end())
return error(request, Result::noLogin); return error(request, Result::noLogin, Response::Status::badRequest);
const std::string& login = itr->second; const std::string& login = itr->second;
if (login.empty()) if (login.empty())
return error(request, Result::emptyLogin); return error(request, Result::emptyLogin, Response::Status::badRequest);
//TODO login policies checkup //TODO login policies checkup
itr = form.find("password"); itr = form.find("password");
if (itr == form.end()) if (itr == form.end())
return error(request, Result::noPassword); return error(request, Result::noPassword, Response::Status::badRequest);
const std::string& password = itr->second; const std::string& password = itr->second;
if (password.empty()) if (password.empty())
return error(request, Result::emptyPassword); return error(request, Result::emptyPassword, Response::Status::badRequest);
//TODO password policies checkup //TODO password policies checkup
try { try {
server->registerAccount(login, password); server->registerAccount(login, password);
} catch (const DBInterface::DuplicateLogin& e) {
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
return error(request, Result::loginExists, Response::Status::conflict);
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl; std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
return error(request, Result::unknownError); return error(request, Result::unknownError, Response::Status::internalError);
} catch (...) { } catch (...) {
std::cerr << "Unknown exception on registration" << std::endl; std::cerr << "Unknown exception on registration" << std::endl;
return error(request, Result::unknownError); return error(request, Result::unknownError, Response::Status::internalError);
} }
Response res(request); Response& res = request.createResponse();
nlohmann::json body = nlohmann::json::object(); nlohmann::json body = nlohmann::json::object();
body["result"] = Result::success; body["result"] = Result::success;
@ -50,8 +54,8 @@ void Handler::Register::handle(Request& request) {
res.send(); res.send();
} }
void Handler::Register::error(Request& request, Result result) { void Handler::Register::error(Request& request, Result result, Response::Status code) {
Response res(request); Response& res = request.createResponse(code);
nlohmann::json body = nlohmann::json::object(); nlohmann::json body = nlohmann::json::object();
body["result"] = result; body["result"] = result;

View File

@ -8,10 +8,10 @@
class Server; class Server;
namespace Handler { namespace Handler {
class Register : public Handler::Handler { class Register : public Handler {
public: public:
Register(Server* server); Register(Server* server);
virtual void handle(Request& request); void handle(Request& request) override;
enum class Result { enum class Result {
success, success,
@ -26,7 +26,7 @@ public:
}; };
private: private:
void error(Request& request, Result result); void error(Request& request, Result result, Response::Status code);
private: private:
Server* server; Server* server;

View File

@ -5,8 +5,6 @@
#include "response/response.h" #include "response/response.h"
constexpr static const char* GET("GET");
constexpr static const char* REQUEST_METHOD("REQUEST_METHOD"); constexpr static const char* REQUEST_METHOD("REQUEST_METHOD");
constexpr static const char* SCRIPT_FILENAME("SCRIPT_FILENAME"); constexpr static const char* SCRIPT_FILENAME("SCRIPT_FILENAME");
constexpr static const char* SERVER_NAME("SERVER_NAME"); constexpr static const char* SERVER_NAME("SERVER_NAME");
@ -51,11 +49,15 @@ void Request::terminate() {
} }
} }
Request::Method Request::method() const { std::string_view Request::methodName() const {
if (state == State::initial) if (state == State::initial)
throw std::runtime_error("An attempt to read request method on not accepted request"); throw std::runtime_error("An attempt to read request method on not accepted request");
std::string_view method(FCGX_GetParam(REQUEST_METHOD, raw.envp)); return FCGX_GetParam(REQUEST_METHOD, raw.envp);
}
Request::Method Request::method() const {
std::string_view method = methodName();
for (const auto& pair : methods) { for (const auto& pair : methods) {
if (pair.first == method) if (pair.first == method)
return pair.second; return pair.second;
@ -79,17 +81,42 @@ bool Request::wait(int socketDescriptor) {
return result; return result;
} }
OStream Request::getOutputStream(const Response* response) { OStream Request::getOutputStream() {
validateResponse(response);
return OStream(raw.out); return OStream(raw.out);
} }
OStream Request::getErrorStream(const Response* response) { OStream Request::getErrorStream() {
validateResponse(response);
return OStream(raw.err); return OStream(raw.err);
} }
void Request::responseIsComplete(const Response* response) { Response& Request::createResponse() {
if (state != State::accepted)
throw std::runtime_error("An attempt create response to the request in the wrong state");
response = std::unique_ptr<Response>(new Response(*this));
state = State::responding;
return *response.get();
}
Response& Request::createResponse(Response::Status status) {
if (state != State::accepted)
throw std::runtime_error("An attempt create response to the request in the wrong state");
response = std::unique_ptr<Response>(new Response(*this, status));
state = State::responding;
return *response.get();
}
uint16_t Request::responseCode() const {
if (state != State::responded)
throw std::runtime_error("An attempt create read response code on the wrong state");
return response->statusCode();
}
void Request::responseIsComplete() {
switch (state) { switch (state) {
case State::initial: case State::initial:
throw std::runtime_error("An attempt to mark the request as complete, but it wasn't even accepted yet"); throw std::runtime_error("An attempt to mark the request as complete, but it wasn't even accepted yet");
@ -98,10 +125,6 @@ void Request::responseIsComplete(const Response* response) {
throw std::runtime_error("An attempt to mark the request as complete, but it wasn't responded"); throw std::runtime_error("An attempt to mark the request as complete, but it wasn't responded");
break; break;
case State::responding: case State::responding:
if (Request::response != response)
throw std::runtime_error("An attempt to mark the request as complete by the different response who actually started responding");
Request::response = nullptr;
state = State::responded; state = State::responded;
break; break;
case State::responded: case State::responded:
@ -110,26 +133,6 @@ void Request::responseIsComplete(const Response* response) {
} }
} }
void Request::validateResponse(const Response* response) {
switch (state) {
case State::initial:
throw std::runtime_error("An attempt to request stream while the request wasn't even accepted yet");
break;
case State::accepted:
Request::response = response;
state = State::responding;
break;
case State::responding:
if (Request::response != response)
throw std::runtime_error("Error handling a request: first time one response started replying, then another continued");
break;
case State::responded:
throw std::runtime_error("An attempt to request stream on a request that was already done responding");
break;
}
}
Request::State Request::currentState() const { Request::State Request::currentState() const {
return state; return state;
} }

View File

@ -16,9 +16,10 @@
#include "stream/ostream.h" #include "stream/ostream.h"
#include "utils/formdecode.h" #include "utils/formdecode.h"
#include "response/response.h"
class Response;
class Request { class Request {
friend class Response;
public: public:
enum class State { enum class State {
initial, initial,
@ -43,26 +44,29 @@ public:
bool wait(int socketDescriptor); bool wait(int socketDescriptor);
void terminate(); void terminate();
Response& createResponse();
Response& createResponse(Response::Status status);
uint16_t responseCode() const;
Method method() const; Method method() const;
std::string_view methodName() const;
State currentState() const; State currentState() const;
bool isFormUrlEncoded() const; bool isFormUrlEncoded() const;
unsigned int contentLength() const; unsigned int contentLength() const;
std::map<std::string, std::string> getForm() const; std::map<std::string, std::string> getForm() const;
OStream getOutputStream(const Response* response);
OStream getErrorStream(const Response* response);
void responseIsComplete(const Response* response);
std::string getPath(const std::string& serverName) const; std::string getPath(const std::string& serverName) const;
std::string getServerName() const; std::string getServerName() const;
void printEnvironment(std::ostream& out); void printEnvironment(std::ostream& out);
void printEnvironment(nlohmann::json& out); void printEnvironment(nlohmann::json& out);
private: private:
void validateResponse(const Response* response); OStream getOutputStream();
OStream getErrorStream();
void responseIsComplete();
private: private:
State state; State state;
FCGX_Request raw; FCGX_Request raw;
const Response* response; std::unique_ptr<Response> response;
}; };

View File

@ -3,10 +3,25 @@
#include "response.h" #include "response.h"
constexpr std::array<std::string_view, static_cast<uint8_t>(Response::Status::__size)> statusCodes = { #include "request/request.h"
constexpr std::array<uint16_t, static_cast<uint8_t>(Response::Status::__size)> statusCodes = {
200,
400,
401,
404,
405,
409,
500
};
constexpr std::array<std::string_view, static_cast<uint8_t>(Response::Status::__size)> statuses = {
"Status: 200 OK", "Status: 200 OK",
"Status: 400 Bad Request",
"Status: 401 Unauthorized",
"Status: 404 Not Found", "Status: 404 Not Found",
"Status: 405 Method Not Allowed", "Status: 405 Method Not Allowed",
"Status: 409 Conflict",
"Status: 500 Internal Error" "Status: 500 Internal Error"
}; };
@ -33,9 +48,9 @@ void Response::send() const {
// OStream out = status == Status::ok ? // OStream out = status == Status::ok ?
// request.getOutputStream() : // request.getOutputStream() :
// request.getErrorStream(); // request.getErrorStream();
OStream out = request.getOutputStream(this); OStream out = request.getOutputStream();
out << statusCodes[static_cast<uint8_t>(status)]; out << statuses[static_cast<uint8_t>(status)];
if (!body.empty()) if (!body.empty())
out << '\n' out << '\n'
<< contentTypes[static_cast<uint8_t>(type)] << contentTypes[static_cast<uint8_t>(type)]
@ -43,7 +58,11 @@ void Response::send() const {
<< '\n' << '\n'
<< body; << body;
request.responseIsComplete(this); request.responseIsComplete();
}
uint16_t Response::statusCode() const {
return statusCodes[static_cast<uint8_t>(status)];
} }
void Response::setBody(const std::string& body) { void Response::setBody(const std::string& body) {

View File

@ -9,15 +9,20 @@
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include "request/request.h"
#include "stream/ostream.h" #include "stream/ostream.h"
class Request;
class Response { class Response {
friend class Request;
public: public:
enum class Status { enum class Status {
ok, ok,
badRequest,
unauthorized,
notFound, notFound,
methodNotAllowed, methodNotAllowed,
conflict,
internalError, internalError,
__size __size
}; };
@ -27,13 +32,17 @@ public:
json, json,
__size __size
}; };
Response(Request& request);
Response(Request& request, Status status); uint16_t statusCode() const;
void send() const; void send() const;
void setBody(const std::string& body); void setBody(const std::string& body);
void setBody(const nlohmann::json& body); void setBody(const nlohmann::json& body);
private:
Response(Request& request);
Response(Request& request, Status status);
private: private:
Request& request; Request& request;
Status status; Status status;

View File

@ -50,29 +50,29 @@ void Router::route(const std::string& path, std::unique_ptr<Request> request) {
if (request->currentState() != Request::State::responded) if (request->currentState() != Request::State::responded)
handleInternalError(path, std::runtime_error("handler failed to handle the request"), std::move(request)); handleInternalError(path, std::runtime_error("handler failed to handle the request"), std::move(request));
else else
std::cout << "Success:\t" << path << std::endl; std::cout << request->responseCode() << '\t' << request->methodName() << '\t' << path << std::endl;
} catch (const std::exception& e) { } catch (const std::exception& e) {
handleInternalError(path, e, std::move(request)); handleInternalError(path, e, std::move(request));
} }
} }
void Router::handleNotFound(const std::string& path, std::unique_ptr<Request> request) { void Router::handleNotFound(const std::string& path, std::unique_ptr<Request> request) {
Response notFound(*request.get(), Response::Status::notFound); Response& notFound = request->createResponse(Response::Status::notFound);
notFound.setBody(std::string("Path \"") + path + "\" was not found"); notFound.setBody(std::string("Path \"") + path + "\" was not found");
notFound.send(); notFound.send();
std::cerr << "Not found:\t" << path << std::endl; std::cerr << notFound.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl;
} }
void Router::handleInternalError(const std::string& path, const std::exception& exception, std::unique_ptr<Request> request) { void Router::handleInternalError(const std::string& path, const std::exception& exception, std::unique_ptr<Request> request) {
Response error(*request.get(), Response::Status::internalError); Response& error = request->createResponse(Response::Status::internalError);
error.setBody(std::string(exception.what())); error.setBody(std::string(exception.what()));
error.send(); error.send();
std::cerr << "Internal error:\t" << path << "\n\t" << exception.what() << std::endl; std::cerr << error.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl;
} }
void Router::handleMethodNotAllowed(const std::string& path, std::unique_ptr<Request> request) { void Router::handleMethodNotAllowed(const std::string& path, std::unique_ptr<Request> request) {
Response error(*request.get(), Response::Status::methodNotAllowed); Response& error = request->createResponse(Response::Status::methodNotAllowed);
error.setBody(std::string("Method not allowed")); error.setBody(std::string("Method not allowed"));
error.send(); error.send();
std::cerr << "Method not allowed:\t" << path << std::endl; std::cerr << error.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl;
} }

View File

@ -8,6 +8,7 @@
#include "handler/info.h" #include "handler/info.h"
#include "handler/env.h" #include "handler/env.h"
#include "handler/register.h" #include "handler/register.h"
#include "handler/login.h"
constexpr const char* pepper = "well, not much of a secret, huh?"; constexpr const char* pepper = "well, not much of a secret, huh?";
constexpr uint8_t currentDbVesion = 1; constexpr uint8_t currentDbVesion = 1;
@ -39,6 +40,7 @@ Server::Server():
router.addRoute(std::make_unique<Handler::Info>()); router.addRoute(std::make_unique<Handler::Info>());
router.addRoute(std::make_unique<Handler::Env>()); router.addRoute(std::make_unique<Handler::Env>());
router.addRoute(std::make_unique<Handler::Register>(this)); router.addRoute(std::make_unique<Handler::Register>(this));
router.addRoute(std::make_unique<Handler::Login>(this));
} }
Server::~Server() {} Server::~Server() {}
@ -63,7 +65,7 @@ void Server::handleRequest(std::unique_ptr<Request> request) {
std::cout << "received server name " << serverName.value() << std::endl; std::cout << "received server name " << serverName.value() << std::endl;
} catch (...) { } catch (...) {
std::cerr << "failed to read server name" << std::endl; std::cerr << "failed to read server name" << std::endl;
Response error(*request.get(), Response::Status::internalError); Response& error = request->createResponse(Response::Status::internalError);
error.send(); error.send();
return; return;
} }
@ -107,3 +109,19 @@ unsigned int Server::registerAccount(const std::string& login, const std::string
return db->registerAccount(login, hash); return db->registerAccount(login, hash);
} }
bool Server::validatePassword(const std::string& login, const std::string& password) {
std::string hash = db->getAccountHash(login);
std::string spiced = password + pepper;
int result = argon2id_verify(hash.data(), spiced.data(), spiced.size());
switch (result) {
case ARGON2_OK:
return true;
case ARGON2_VERIFY_MISMATCH:
return false;
default:
throw std::runtime_error(std::string("Failed to verify password: ") + argon2_error_message(result));
}
}

View File

@ -32,6 +32,7 @@ public:
void run(int socketDescriptor); void run(int socketDescriptor);
unsigned int registerAccount(const std::string& login, const std::string& password); unsigned int registerAccount(const std::string& login, const std::string& password);
bool validatePassword(const std::string& login, const std::string& password);
private: private:
void handleRequest(std::unique_ptr<Request> request); void handleRequest(std::unique_ptr<Request> request);