password hash cheching
This commit is contained in:
parent
99a9fd507e
commit
534c282226
@ -1,9 +1,11 @@
|
||||
set(HEADERS
|
||||
dbinterface.h
|
||||
exceptions.h
|
||||
)
|
||||
|
||||
set(SOURCES
|
||||
dbinterface.cpp
|
||||
exceptions.cpp
|
||||
)
|
||||
|
||||
target_sources(${PROJECT_NAME} PRIVATE ${SOURCES})
|
||||
|
@ -26,6 +26,11 @@ public:
|
||||
|
||||
const Type type;
|
||||
|
||||
class Duplicate;
|
||||
class DuplicateLogin;
|
||||
class EmptyResult;
|
||||
class NoLogin;
|
||||
|
||||
public:
|
||||
virtual void connect(const std::string& path) = 0;
|
||||
virtual void disconnect() = 0;
|
||||
@ -37,6 +42,7 @@ public:
|
||||
virtual void setVersion(uint8_t version) = 0;
|
||||
|
||||
virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0;
|
||||
virtual std::string getAccountHash(const std::string& login) = 0;
|
||||
|
||||
protected:
|
||||
DBInterface(Type type);
|
||||
|
20
database/exceptions.cpp
Normal file
20
database/exceptions.cpp
Normal file
@ -0,0 +1,20 @@
|
||||
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
#include "exceptions.h"
|
||||
|
||||
DBInterface::Duplicate::Duplicate(const std::string& text):
|
||||
std::runtime_error(text)
|
||||
{}
|
||||
|
||||
DBInterface::DuplicateLogin::DuplicateLogin(const std::string& text):
|
||||
Duplicate(text)
|
||||
{}
|
||||
|
||||
DBInterface::EmptyResult::EmptyResult(const std::string& text):
|
||||
std::runtime_error(text)
|
||||
{}
|
||||
|
||||
DBInterface::NoLogin::NoLogin(const std::string& text):
|
||||
EmptyResult(text)
|
||||
{}
|
26
database/exceptions.h
Normal file
26
database/exceptions.h
Normal file
@ -0,0 +1,26 @@
|
||||
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "dbinterface.h"
|
||||
|
||||
class DBInterface::Duplicate : public std::runtime_error {
|
||||
public:
|
||||
explicit Duplicate(const std::string& text);
|
||||
};
|
||||
|
||||
class DBInterface::DuplicateLogin : public DBInterface::Duplicate {
|
||||
public:
|
||||
explicit DuplicateLogin(const std::string& text);
|
||||
};
|
||||
|
||||
class DBInterface::EmptyResult : public std::runtime_error {
|
||||
public:
|
||||
explicit EmptyResult(const std::string& text);
|
||||
};
|
||||
|
||||
class DBInterface::NoLogin : public DBInterface::EmptyResult {
|
||||
public:
|
||||
explicit NoLogin(const std::string& text);
|
||||
};
|
@ -10,21 +10,17 @@
|
||||
|
||||
#include "statement.h"
|
||||
#include "transaction.h"
|
||||
#include "database/exceptions.h"
|
||||
|
||||
constexpr const char* versionQuery = "SELECT value FROM system WHERE `key` = 'version'";
|
||||
constexpr const char* updateQuery = "UPDATE system SET `value` = ? WHERE `key` = 'version'";
|
||||
constexpr const char* registerQuery = "INSERT INTO accounts (`login`, `type`, `password`) VALUES (?, 1, ?)";
|
||||
constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id";
|
||||
constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?";
|
||||
constexpr const char* selectHash = "SELECT password FROM accounts where login = ?";
|
||||
|
||||
static const std::filesystem::path buildSQLPath = "database";
|
||||
|
||||
struct ResDeleter {
|
||||
void operator () (MYSQL_RES* res) {
|
||||
mysql_free_result(res);
|
||||
}
|
||||
};
|
||||
|
||||
MySQL::MySQL():
|
||||
DBInterface(Type::mysql),
|
||||
connection(),
|
||||
@ -231,7 +227,11 @@ unsigned int MySQL::registerAccount(const std::string& login, const std::string&
|
||||
std::string h = hash;
|
||||
addAcc.bind(l.data(), MYSQL_TYPE_STRING);
|
||||
addAcc.bind(h.data(), MYSQL_TYPE_STRING);
|
||||
addAcc.execute();
|
||||
try {
|
||||
addAcc.execute();
|
||||
} catch (const Duplicate& dup) {
|
||||
throw DuplicateLogin(dup.what());
|
||||
}
|
||||
|
||||
unsigned int id = lastInsertedId();
|
||||
static std::string defaultRole("default");
|
||||
@ -245,6 +245,24 @@ unsigned int MySQL::registerAccount(const std::string& login, const std::string&
|
||||
return id;
|
||||
}
|
||||
|
||||
std::string MySQL::getAccountHash(const std::string& login) {
|
||||
std::string l = login;
|
||||
MYSQL* con = &connection;
|
||||
|
||||
Statement getHash(con, selectHash);
|
||||
getHash.bind(l.data(), MYSQL_TYPE_STRING);
|
||||
getHash.execute();
|
||||
|
||||
std::vector<std::vector<std::string>> result = getHash.fetchResult();
|
||||
if (result.empty())
|
||||
throw NoLogin("Couldn't find login " + l);
|
||||
|
||||
if (result[0].empty())
|
||||
throw std::runtime_error("Error with the query \"selectHash\"");
|
||||
|
||||
return result[0][0];
|
||||
}
|
||||
|
||||
unsigned int MySQL::lastInsertedId() {
|
||||
MYSQL* con = &connection;
|
||||
int result = mysql_query(con, lastIdQuery);
|
||||
|
@ -15,6 +15,8 @@
|
||||
class MySQL : public DBInterface {
|
||||
class Statement;
|
||||
class Transaction;
|
||||
|
||||
|
||||
public:
|
||||
MySQL();
|
||||
~MySQL() override;
|
||||
@ -29,6 +31,7 @@ public:
|
||||
void setVersion(uint8_t version) override;
|
||||
|
||||
unsigned int registerAccount(const std::string& login, const std::string& hash) override;
|
||||
std::string getAccountHash(const std::string& login) override;
|
||||
|
||||
private:
|
||||
void executeFile(const std::filesystem::path& relativePath);
|
||||
@ -40,4 +43,10 @@ protected:
|
||||
std::string login;
|
||||
std::string password;
|
||||
std::string database;
|
||||
|
||||
struct ResDeleter {
|
||||
void operator () (MYSQL_RES* res) {
|
||||
mysql_free_result(res);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
@ -5,6 +5,10 @@
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "mysqld_error.h"
|
||||
|
||||
#include "database/exceptions.h"
|
||||
|
||||
static uint64_t TIME_LENGTH = sizeof(MYSQL_TIME);
|
||||
|
||||
MySQL::Statement::Statement(MYSQL* connection, const char* statement):
|
||||
@ -44,11 +48,68 @@ void MySQL::Statement::bind(void* value, enum_field_types type, bool usigned) {
|
||||
}
|
||||
|
||||
void MySQL::Statement::execute() {
|
||||
int result = mysql_stmt_bind_param(stmt.get(), param.data());
|
||||
MYSQL_STMT* raw = stmt.get();
|
||||
int result = mysql_stmt_bind_param(raw, param.data());
|
||||
if (result != 0)
|
||||
throw std::runtime_error(std::string("Error binding statement: ") + mysql_stmt_error(stmt.get()));
|
||||
throw std::runtime_error(std::string("Error binding statement: ") + mysql_stmt_error(raw));
|
||||
|
||||
result = mysql_stmt_execute(stmt.get());
|
||||
if (result != 0)
|
||||
throw std::runtime_error(std::string("Error executing statement: ") + mysql_stmt_error(stmt.get()));
|
||||
result = mysql_stmt_execute(raw);
|
||||
if (result != 0) {
|
||||
int errcode = mysql_stmt_errno(raw);
|
||||
std::string text = mysql_stmt_error(raw);
|
||||
switch (errcode) {
|
||||
case ER_DUP_ENTRY:
|
||||
throw Duplicate("Error executing statement: " + text);
|
||||
default:
|
||||
throw std::runtime_error("Error executing statement: " + text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::string>> MySQL::Statement::fetchResult() {
|
||||
MYSQL_STMT* raw = stmt.get();
|
||||
if (mysql_stmt_store_result(raw) != 0)
|
||||
throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here
|
||||
|
||||
MYSQL_RES* meta = mysql_stmt_result_metadata(raw);
|
||||
if (meta == nullptr)
|
||||
throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here
|
||||
|
||||
std::unique_ptr<MYSQL_RES, ResDeleter> mt(meta);
|
||||
unsigned int numColumns = mysql_num_fields(meta);
|
||||
MYSQL_BIND bind[numColumns];
|
||||
memset(bind, 0, sizeof(bind));
|
||||
|
||||
std::vector<std::string> line(numColumns);
|
||||
std::vector<long unsigned int> lengths(numColumns);
|
||||
for (unsigned int i = 0; i < numColumns; ++i) {
|
||||
MYSQL_FIELD *field = mysql_fetch_field_direct(meta, i);
|
||||
|
||||
switch (field->type) {
|
||||
case MYSQL_TYPE_STRING:
|
||||
case MYSQL_TYPE_VAR_STRING:
|
||||
case MYSQL_TYPE_VARCHAR:
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Unsupported data fetching statement result " + std::to_string(field->type));
|
||||
}
|
||||
line[i].resize(field->length);
|
||||
bind[i].buffer_type = field->type;
|
||||
bind[i].buffer = line[i].data();
|
||||
bind[i].buffer_length = field->length;
|
||||
bind[i].length = &lengths[i];
|
||||
}
|
||||
|
||||
if (mysql_stmt_bind_result(raw, bind) != 0)
|
||||
throw std::runtime_error(std::string("Error binding on fetching statement result: ") + mysql_stmt_error(raw));
|
||||
|
||||
std::vector<std::vector<std::string>> result;
|
||||
while (mysql_stmt_fetch(raw) == 0) {
|
||||
std::vector<std::string>& row = result.emplace_back(numColumns);
|
||||
for (unsigned int i = 0; i < numColumns; ++i)
|
||||
row[i] = std::string(line[i].data(), lengths[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -18,6 +18,7 @@ public:
|
||||
|
||||
void bind(void* value, enum_field_types type, bool usigned = false);
|
||||
void execute();
|
||||
std::vector<std::vector<std::string>> fetchResult();
|
||||
|
||||
private:
|
||||
std::unique_ptr<MYSQL_STMT, STMTDeleter> stmt;
|
||||
|
@ -1,3 +1,6 @@
|
||||
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
#include "transaction.h"
|
||||
|
||||
MySQL::Transaction::Transaction(MYSQL* connection):
|
||||
|
@ -1,3 +1,6 @@
|
||||
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mysql.h"
|
||||
|
@ -3,6 +3,7 @@ set(HEADERS
|
||||
info.h
|
||||
env.h
|
||||
register.h
|
||||
login.h
|
||||
)
|
||||
|
||||
set(SOURCES
|
||||
@ -10,6 +11,7 @@ set(SOURCES
|
||||
info.cpp
|
||||
env.cpp
|
||||
register.cpp
|
||||
login.cpp
|
||||
)
|
||||
|
||||
target_sources(${PROJECT_NAME} PRIVATE ${SOURCES})
|
||||
|
@ -11,7 +11,7 @@ void Handler::Env::handle(Request& request) {
|
||||
nlohmann::json body = nlohmann::json::object();
|
||||
request.printEnvironment(body);
|
||||
|
||||
Response res(request);
|
||||
Response& res = request.createResponse();
|
||||
res.setBody(body);
|
||||
res.send();
|
||||
}
|
||||
|
@ -7,10 +7,10 @@
|
||||
|
||||
namespace Handler {
|
||||
|
||||
class Env : public Handler::Handler {
|
||||
class Env : public Handler {
|
||||
public:
|
||||
Env();
|
||||
virtual void handle(Request& request);
|
||||
void handle(Request& request) override;
|
||||
|
||||
};
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ Handler::Info::Info():
|
||||
{}
|
||||
|
||||
void Handler::Info::handle(Request& request) {
|
||||
Response res(request);
|
||||
Response& res = request.createResponse();
|
||||
nlohmann::json body = nlohmann::json::object();
|
||||
body["type"] = PROJECT_NAME;
|
||||
body["version"] = PROJECT_VERSION;
|
||||
|
65
handler/login.cpp
Normal file
65
handler/login.cpp
Normal file
@ -0,0 +1,65 @@
|
||||
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
#include "login.h"
|
||||
|
||||
#include "server/server.h"
|
||||
#include "database/exceptions.h"
|
||||
|
||||
Handler::Login::Login(Server* server):
|
||||
Handler("login", Request::Method::post),
|
||||
server(server)
|
||||
{}
|
||||
|
||||
void Handler::Login::handle(Request& request) {
|
||||
std::map form = request.getForm();
|
||||
std::map<std::string, std::string>::const_iterator itr = form.find("login");
|
||||
if (itr == form.end())
|
||||
return error(request, Result::noLogin, Response::Status::badRequest);
|
||||
|
||||
const std::string& login = itr->second;
|
||||
if (login.empty())
|
||||
return error(request, Result::emptyLogin, Response::Status::badRequest);
|
||||
|
||||
itr = form.find("password");
|
||||
if (itr == form.end())
|
||||
return error(request, Result::noPassword, Response::Status::badRequest);
|
||||
|
||||
const std::string& password = itr->second;
|
||||
if (password.empty())
|
||||
return error(request, Result::emptyPassword, Response::Status::badRequest);
|
||||
|
||||
bool success = false;
|
||||
try {
|
||||
success = server->validatePassword(login, password);
|
||||
} catch (const DBInterface::NoLogin& e) {
|
||||
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
|
||||
return error(request, Result::noLogin, Response::Status::badRequest); //can send unauthed instead, to exclude login spoofing
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
|
||||
return error(request, Result::unknownError, Response::Status::internalError);
|
||||
} catch (...) {
|
||||
std::cerr << "Unknown exception on registration" << std::endl;
|
||||
return error(request, Result::unknownError, Response::Status::internalError);
|
||||
}
|
||||
if (!success)
|
||||
return error(request, Result::noLogin, Response::Status::badRequest);
|
||||
|
||||
//TODO opening the session
|
||||
|
||||
Response& res = request.createResponse();
|
||||
nlohmann::json body = nlohmann::json::object();
|
||||
body["result"] = Result::success;
|
||||
|
||||
res.setBody(body);
|
||||
res.send();
|
||||
}
|
||||
|
||||
void Handler::Login::error(Request& request, Result result, Response::Status code) {
|
||||
Response& res = request.createResponse(code);
|
||||
nlohmann::json body = nlohmann::json::object();
|
||||
body["result"] = result;
|
||||
|
||||
res.setBody(body);
|
||||
res.send();
|
||||
}
|
32
handler/login.h
Normal file
32
handler/login.h
Normal file
@ -0,0 +1,32 @@
|
||||
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "handler.h"
|
||||
|
||||
class Server;
|
||||
namespace Handler {
|
||||
|
||||
class Login : public Handler {
|
||||
public:
|
||||
Login(Server* server);
|
||||
void handle(Request& request) override;
|
||||
|
||||
enum class Result {
|
||||
success,
|
||||
noLogin,
|
||||
emptyLogin,
|
||||
noPassword,
|
||||
emptyPassword,
|
||||
unknownError
|
||||
};
|
||||
|
||||
private:
|
||||
void error(Request& request, Result result, Response::Status code);
|
||||
|
||||
private:
|
||||
Server* server;
|
||||
};
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include "register.h"
|
||||
|
||||
#include "server/server.h"
|
||||
#include "database/exceptions.h"
|
||||
|
||||
Handler::Register::Register(Server* server):
|
||||
Handler("register", Request::Method::post),
|
||||
@ -14,35 +15,38 @@ void Handler::Register::handle(Request& request) {
|
||||
std::map form = request.getForm();
|
||||
std::map<std::string, std::string>::const_iterator itr = form.find("login");
|
||||
if (itr == form.end())
|
||||
return error(request, Result::noLogin);
|
||||
return error(request, Result::noLogin, Response::Status::badRequest);
|
||||
|
||||
const std::string& login = itr->second;
|
||||
if (login.empty())
|
||||
return error(request, Result::emptyLogin);
|
||||
return error(request, Result::emptyLogin, Response::Status::badRequest);
|
||||
|
||||
//TODO login policies checkup
|
||||
|
||||
itr = form.find("password");
|
||||
if (itr == form.end())
|
||||
return error(request, Result::noPassword);
|
||||
return error(request, Result::noPassword, Response::Status::badRequest);
|
||||
|
||||
const std::string& password = itr->second;
|
||||
if (password.empty())
|
||||
return error(request, Result::emptyPassword);
|
||||
return error(request, Result::emptyPassword, Response::Status::badRequest);
|
||||
|
||||
//TODO password policies checkup
|
||||
|
||||
try {
|
||||
server->registerAccount(login, password);
|
||||
} catch (const DBInterface::DuplicateLogin& e) {
|
||||
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
|
||||
return error(request, Result::loginExists, Response::Status::conflict);
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
|
||||
return error(request, Result::unknownError);
|
||||
} catch (...) {
|
||||
return error(request, Result::unknownError, Response::Status::internalError);
|
||||
} catch (...) {
|
||||
std::cerr << "Unknown exception on registration" << std::endl;
|
||||
return error(request, Result::unknownError);
|
||||
return error(request, Result::unknownError, Response::Status::internalError);
|
||||
}
|
||||
|
||||
Response res(request);
|
||||
Response& res = request.createResponse();
|
||||
nlohmann::json body = nlohmann::json::object();
|
||||
body["result"] = Result::success;
|
||||
|
||||
@ -50,8 +54,8 @@ void Handler::Register::handle(Request& request) {
|
||||
res.send();
|
||||
}
|
||||
|
||||
void Handler::Register::error(Request& request, Result result) {
|
||||
Response res(request);
|
||||
void Handler::Register::error(Request& request, Result result, Response::Status code) {
|
||||
Response& res = request.createResponse(code);
|
||||
nlohmann::json body = nlohmann::json::object();
|
||||
body["result"] = result;
|
||||
|
||||
|
@ -8,10 +8,10 @@
|
||||
class Server;
|
||||
namespace Handler {
|
||||
|
||||
class Register : public Handler::Handler {
|
||||
class Register : public Handler {
|
||||
public:
|
||||
Register(Server* server);
|
||||
virtual void handle(Request& request);
|
||||
void handle(Request& request) override;
|
||||
|
||||
enum class Result {
|
||||
success,
|
||||
@ -26,7 +26,7 @@ public:
|
||||
};
|
||||
|
||||
private:
|
||||
void error(Request& request, Result result);
|
||||
void error(Request& request, Result result, Response::Status code);
|
||||
|
||||
private:
|
||||
Server* server;
|
||||
|
@ -5,8 +5,6 @@
|
||||
|
||||
#include "response/response.h"
|
||||
|
||||
constexpr static const char* GET("GET");
|
||||
|
||||
constexpr static const char* REQUEST_METHOD("REQUEST_METHOD");
|
||||
constexpr static const char* SCRIPT_FILENAME("SCRIPT_FILENAME");
|
||||
constexpr static const char* SERVER_NAME("SERVER_NAME");
|
||||
@ -51,11 +49,15 @@ void Request::terminate() {
|
||||
}
|
||||
}
|
||||
|
||||
Request::Method Request::method() const {
|
||||
std::string_view Request::methodName() const {
|
||||
if (state == State::initial)
|
||||
throw std::runtime_error("An attempt to read request method on not accepted request");
|
||||
|
||||
std::string_view method(FCGX_GetParam(REQUEST_METHOD, raw.envp));
|
||||
return FCGX_GetParam(REQUEST_METHOD, raw.envp);
|
||||
}
|
||||
|
||||
Request::Method Request::method() const {
|
||||
std::string_view method = methodName();
|
||||
for (const auto& pair : methods) {
|
||||
if (pair.first == method)
|
||||
return pair.second;
|
||||
@ -79,17 +81,42 @@ bool Request::wait(int socketDescriptor) {
|
||||
return result;
|
||||
}
|
||||
|
||||
OStream Request::getOutputStream(const Response* response) {
|
||||
validateResponse(response);
|
||||
OStream Request::getOutputStream() {
|
||||
return OStream(raw.out);
|
||||
}
|
||||
|
||||
OStream Request::getErrorStream(const Response* response) {
|
||||
validateResponse(response);
|
||||
OStream Request::getErrorStream() {
|
||||
return OStream(raw.err);
|
||||
}
|
||||
|
||||
void Request::responseIsComplete(const Response* response) {
|
||||
Response& Request::createResponse() {
|
||||
if (state != State::accepted)
|
||||
throw std::runtime_error("An attempt create response to the request in the wrong state");
|
||||
|
||||
response = std::unique_ptr<Response>(new Response(*this));
|
||||
state = State::responding;
|
||||
|
||||
return *response.get();
|
||||
}
|
||||
|
||||
Response& Request::createResponse(Response::Status status) {
|
||||
if (state != State::accepted)
|
||||
throw std::runtime_error("An attempt create response to the request in the wrong state");
|
||||
|
||||
response = std::unique_ptr<Response>(new Response(*this, status));
|
||||
state = State::responding;
|
||||
|
||||
return *response.get();
|
||||
}
|
||||
|
||||
uint16_t Request::responseCode() const {
|
||||
if (state != State::responded)
|
||||
throw std::runtime_error("An attempt create read response code on the wrong state");
|
||||
|
||||
return response->statusCode();
|
||||
}
|
||||
|
||||
void Request::responseIsComplete() {
|
||||
switch (state) {
|
||||
case State::initial:
|
||||
throw std::runtime_error("An attempt to mark the request as complete, but it wasn't even accepted yet");
|
||||
@ -98,10 +125,6 @@ void Request::responseIsComplete(const Response* response) {
|
||||
throw std::runtime_error("An attempt to mark the request as complete, but it wasn't responded");
|
||||
break;
|
||||
case State::responding:
|
||||
if (Request::response != response)
|
||||
throw std::runtime_error("An attempt to mark the request as complete by the different response who actually started responding");
|
||||
|
||||
Request::response = nullptr;
|
||||
state = State::responded;
|
||||
break;
|
||||
case State::responded:
|
||||
@ -110,26 +133,6 @@ void Request::responseIsComplete(const Response* response) {
|
||||
}
|
||||
}
|
||||
|
||||
void Request::validateResponse(const Response* response) {
|
||||
switch (state) {
|
||||
case State::initial:
|
||||
throw std::runtime_error("An attempt to request stream while the request wasn't even accepted yet");
|
||||
break;
|
||||
case State::accepted:
|
||||
Request::response = response;
|
||||
state = State::responding;
|
||||
break;
|
||||
case State::responding:
|
||||
if (Request::response != response)
|
||||
throw std::runtime_error("Error handling a request: first time one response started replying, then another continued");
|
||||
|
||||
break;
|
||||
case State::responded:
|
||||
throw std::runtime_error("An attempt to request stream on a request that was already done responding");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Request::State Request::currentState() const {
|
||||
return state;
|
||||
}
|
||||
|
@ -16,9 +16,10 @@
|
||||
|
||||
#include "stream/ostream.h"
|
||||
#include "utils/formdecode.h"
|
||||
#include "response/response.h"
|
||||
|
||||
class Response;
|
||||
class Request {
|
||||
friend class Response;
|
||||
public:
|
||||
enum class State {
|
||||
initial,
|
||||
@ -43,26 +44,29 @@ public:
|
||||
bool wait(int socketDescriptor);
|
||||
void terminate();
|
||||
|
||||
Response& createResponse();
|
||||
Response& createResponse(Response::Status status);
|
||||
|
||||
uint16_t responseCode() const;
|
||||
Method method() const;
|
||||
std::string_view methodName() const;
|
||||
State currentState() const;
|
||||
bool isFormUrlEncoded() const;
|
||||
unsigned int contentLength() const;
|
||||
std::map<std::string, std::string> getForm() const;
|
||||
|
||||
OStream getOutputStream(const Response* response);
|
||||
OStream getErrorStream(const Response* response);
|
||||
void responseIsComplete(const Response* response);
|
||||
|
||||
std::string getPath(const std::string& serverName) const;
|
||||
std::string getServerName() const;
|
||||
void printEnvironment(std::ostream& out);
|
||||
void printEnvironment(nlohmann::json& out);
|
||||
|
||||
private:
|
||||
void validateResponse(const Response* response);
|
||||
OStream getOutputStream();
|
||||
OStream getErrorStream();
|
||||
void responseIsComplete();
|
||||
|
||||
private:
|
||||
State state;
|
||||
FCGX_Request raw;
|
||||
const Response* response;
|
||||
std::unique_ptr<Response> response;
|
||||
};
|
||||
|
@ -3,10 +3,25 @@
|
||||
|
||||
#include "response.h"
|
||||
|
||||
constexpr std::array<std::string_view, static_cast<uint8_t>(Response::Status::__size)> statusCodes = {
|
||||
#include "request/request.h"
|
||||
|
||||
constexpr std::array<uint16_t, static_cast<uint8_t>(Response::Status::__size)> statusCodes = {
|
||||
200,
|
||||
400,
|
||||
401,
|
||||
404,
|
||||
405,
|
||||
409,
|
||||
500
|
||||
};
|
||||
|
||||
constexpr std::array<std::string_view, static_cast<uint8_t>(Response::Status::__size)> statuses = {
|
||||
"Status: 200 OK",
|
||||
"Status: 400 Bad Request",
|
||||
"Status: 401 Unauthorized",
|
||||
"Status: 404 Not Found",
|
||||
"Status: 405 Method Not Allowed",
|
||||
"Status: 409 Conflict",
|
||||
"Status: 500 Internal Error"
|
||||
};
|
||||
|
||||
@ -33,9 +48,9 @@ void Response::send() const {
|
||||
// OStream out = status == Status::ok ?
|
||||
// request.getOutputStream() :
|
||||
// request.getErrorStream();
|
||||
OStream out = request.getOutputStream(this);
|
||||
OStream out = request.getOutputStream();
|
||||
|
||||
out << statusCodes[static_cast<uint8_t>(status)];
|
||||
out << statuses[static_cast<uint8_t>(status)];
|
||||
if (!body.empty())
|
||||
out << '\n'
|
||||
<< contentTypes[static_cast<uint8_t>(type)]
|
||||
@ -43,7 +58,11 @@ void Response::send() const {
|
||||
<< '\n'
|
||||
<< body;
|
||||
|
||||
request.responseIsComplete(this);
|
||||
request.responseIsComplete();
|
||||
}
|
||||
|
||||
uint16_t Response::statusCode() const {
|
||||
return statusCodes[static_cast<uint8_t>(status)];
|
||||
}
|
||||
|
||||
void Response::setBody(const std::string& body) {
|
||||
|
@ -9,15 +9,20 @@
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "request/request.h"
|
||||
#include "stream/ostream.h"
|
||||
|
||||
class Request;
|
||||
class Response {
|
||||
friend class Request;
|
||||
|
||||
public:
|
||||
enum class Status {
|
||||
ok,
|
||||
badRequest,
|
||||
unauthorized,
|
||||
notFound,
|
||||
methodNotAllowed,
|
||||
conflict,
|
||||
internalError,
|
||||
__size
|
||||
};
|
||||
@ -27,13 +32,17 @@ public:
|
||||
json,
|
||||
__size
|
||||
};
|
||||
Response(Request& request);
|
||||
Response(Request& request, Status status);
|
||||
|
||||
uint16_t statusCode() const;
|
||||
|
||||
void send() const;
|
||||
void setBody(const std::string& body);
|
||||
void setBody(const nlohmann::json& body);
|
||||
|
||||
private:
|
||||
Response(Request& request);
|
||||
Response(Request& request, Status status);
|
||||
|
||||
private:
|
||||
Request& request;
|
||||
Status status;
|
||||
|
@ -50,29 +50,29 @@ void Router::route(const std::string& path, std::unique_ptr<Request> request) {
|
||||
if (request->currentState() != Request::State::responded)
|
||||
handleInternalError(path, std::runtime_error("handler failed to handle the request"), std::move(request));
|
||||
else
|
||||
std::cout << "Success:\t" << path << std::endl;
|
||||
std::cout << request->responseCode() << '\t' << request->methodName() << '\t' << path << std::endl;
|
||||
} catch (const std::exception& e) {
|
||||
handleInternalError(path, e, std::move(request));
|
||||
}
|
||||
}
|
||||
|
||||
void Router::handleNotFound(const std::string& path, std::unique_ptr<Request> request) {
|
||||
Response notFound(*request.get(), Response::Status::notFound);
|
||||
Response& notFound = request->createResponse(Response::Status::notFound);
|
||||
notFound.setBody(std::string("Path \"") + path + "\" was not found");
|
||||
notFound.send();
|
||||
std::cerr << "Not found:\t" << path << std::endl;
|
||||
std::cerr << notFound.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl;
|
||||
}
|
||||
|
||||
void Router::handleInternalError(const std::string& path, const std::exception& exception, std::unique_ptr<Request> request) {
|
||||
Response error(*request.get(), Response::Status::internalError);
|
||||
Response& error = request->createResponse(Response::Status::internalError);
|
||||
error.setBody(std::string(exception.what()));
|
||||
error.send();
|
||||
std::cerr << "Internal error:\t" << path << "\n\t" << exception.what() << std::endl;
|
||||
std::cerr << error.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl;
|
||||
}
|
||||
|
||||
void Router::handleMethodNotAllowed(const std::string& path, std::unique_ptr<Request> request) {
|
||||
Response error(*request.get(), Response::Status::methodNotAllowed);
|
||||
Response& error = request->createResponse(Response::Status::methodNotAllowed);
|
||||
error.setBody(std::string("Method not allowed"));
|
||||
error.send();
|
||||
std::cerr << "Method not allowed:\t" << path << std::endl;
|
||||
std::cerr << error.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl;
|
||||
}
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include "handler/info.h"
|
||||
#include "handler/env.h"
|
||||
#include "handler/register.h"
|
||||
#include "handler/login.h"
|
||||
|
||||
constexpr const char* pepper = "well, not much of a secret, huh?";
|
||||
constexpr uint8_t currentDbVesion = 1;
|
||||
@ -39,6 +40,7 @@ Server::Server():
|
||||
router.addRoute(std::make_unique<Handler::Info>());
|
||||
router.addRoute(std::make_unique<Handler::Env>());
|
||||
router.addRoute(std::make_unique<Handler::Register>(this));
|
||||
router.addRoute(std::make_unique<Handler::Login>(this));
|
||||
}
|
||||
|
||||
Server::~Server() {}
|
||||
@ -63,7 +65,7 @@ void Server::handleRequest(std::unique_ptr<Request> request) {
|
||||
std::cout << "received server name " << serverName.value() << std::endl;
|
||||
} catch (...) {
|
||||
std::cerr << "failed to read server name" << std::endl;
|
||||
Response error(*request.get(), Response::Status::internalError);
|
||||
Response& error = request->createResponse(Response::Status::internalError);
|
||||
error.send();
|
||||
return;
|
||||
}
|
||||
@ -107,3 +109,19 @@ unsigned int Server::registerAccount(const std::string& login, const std::string
|
||||
|
||||
return db->registerAccount(login, hash);
|
||||
}
|
||||
|
||||
bool Server::validatePassword(const std::string& login, const std::string& password) {
|
||||
std::string hash = db->getAccountHash(login);
|
||||
|
||||
std::string spiced = password + pepper;
|
||||
int result = argon2id_verify(hash.data(), spiced.data(), spiced.size());
|
||||
|
||||
switch (result) {
|
||||
case ARGON2_OK:
|
||||
return true;
|
||||
case ARGON2_VERIFY_MISMATCH:
|
||||
return false;
|
||||
default:
|
||||
throw std::runtime_error(std::string("Failed to verify password: ") + argon2_error_message(result));
|
||||
}
|
||||
}
|
||||
|
@ -32,6 +32,7 @@ public:
|
||||
void run(int socketDescriptor);
|
||||
|
||||
unsigned int registerAccount(const std::string& login, const std::string& password);
|
||||
bool validatePassword(const std::string& login, const std::string& password);
|
||||
|
||||
private:
|
||||
void handleRequest(std::unique_ptr<Request> request);
|
||||
|
Loading…
Reference in New Issue
Block a user