password hash cheching
This commit is contained in:
parent
99a9fd507e
commit
534c282226
@ -1,9 +1,11 @@
|
|||||||
set(HEADERS
|
set(HEADERS
|
||||||
dbinterface.h
|
dbinterface.h
|
||||||
|
exceptions.h
|
||||||
)
|
)
|
||||||
|
|
||||||
set(SOURCES
|
set(SOURCES
|
||||||
dbinterface.cpp
|
dbinterface.cpp
|
||||||
|
exceptions.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
target_sources(${PROJECT_NAME} PRIVATE ${SOURCES})
|
target_sources(${PROJECT_NAME} PRIVATE ${SOURCES})
|
||||||
|
@ -26,6 +26,11 @@ public:
|
|||||||
|
|
||||||
const Type type;
|
const Type type;
|
||||||
|
|
||||||
|
class Duplicate;
|
||||||
|
class DuplicateLogin;
|
||||||
|
class EmptyResult;
|
||||||
|
class NoLogin;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
virtual void connect(const std::string& path) = 0;
|
virtual void connect(const std::string& path) = 0;
|
||||||
virtual void disconnect() = 0;
|
virtual void disconnect() = 0;
|
||||||
@ -37,6 +42,7 @@ public:
|
|||||||
virtual void setVersion(uint8_t version) = 0;
|
virtual void setVersion(uint8_t version) = 0;
|
||||||
|
|
||||||
virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0;
|
virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0;
|
||||||
|
virtual std::string getAccountHash(const std::string& login) = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
DBInterface(Type type);
|
DBInterface(Type type);
|
||||||
|
20
database/exceptions.cpp
Normal file
20
database/exceptions.cpp
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||||
|
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
|
#include "exceptions.h"
|
||||||
|
|
||||||
|
DBInterface::Duplicate::Duplicate(const std::string& text):
|
||||||
|
std::runtime_error(text)
|
||||||
|
{}
|
||||||
|
|
||||||
|
DBInterface::DuplicateLogin::DuplicateLogin(const std::string& text):
|
||||||
|
Duplicate(text)
|
||||||
|
{}
|
||||||
|
|
||||||
|
DBInterface::EmptyResult::EmptyResult(const std::string& text):
|
||||||
|
std::runtime_error(text)
|
||||||
|
{}
|
||||||
|
|
||||||
|
DBInterface::NoLogin::NoLogin(const std::string& text):
|
||||||
|
EmptyResult(text)
|
||||||
|
{}
|
26
database/exceptions.h
Normal file
26
database/exceptions.h
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||||
|
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "dbinterface.h"
|
||||||
|
|
||||||
|
class DBInterface::Duplicate : public std::runtime_error {
|
||||||
|
public:
|
||||||
|
explicit Duplicate(const std::string& text);
|
||||||
|
};
|
||||||
|
|
||||||
|
class DBInterface::DuplicateLogin : public DBInterface::Duplicate {
|
||||||
|
public:
|
||||||
|
explicit DuplicateLogin(const std::string& text);
|
||||||
|
};
|
||||||
|
|
||||||
|
class DBInterface::EmptyResult : public std::runtime_error {
|
||||||
|
public:
|
||||||
|
explicit EmptyResult(const std::string& text);
|
||||||
|
};
|
||||||
|
|
||||||
|
class DBInterface::NoLogin : public DBInterface::EmptyResult {
|
||||||
|
public:
|
||||||
|
explicit NoLogin(const std::string& text);
|
||||||
|
};
|
@ -10,21 +10,17 @@
|
|||||||
|
|
||||||
#include "statement.h"
|
#include "statement.h"
|
||||||
#include "transaction.h"
|
#include "transaction.h"
|
||||||
|
#include "database/exceptions.h"
|
||||||
|
|
||||||
constexpr const char* versionQuery = "SELECT value FROM system WHERE `key` = 'version'";
|
constexpr const char* versionQuery = "SELECT value FROM system WHERE `key` = 'version'";
|
||||||
constexpr const char* updateQuery = "UPDATE system SET `value` = ? WHERE `key` = 'version'";
|
constexpr const char* updateQuery = "UPDATE system SET `value` = ? WHERE `key` = 'version'";
|
||||||
constexpr const char* registerQuery = "INSERT INTO accounts (`login`, `type`, `password`) VALUES (?, 1, ?)";
|
constexpr const char* registerQuery = "INSERT INTO accounts (`login`, `type`, `password`) VALUES (?, 1, ?)";
|
||||||
constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id";
|
constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id";
|
||||||
constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?";
|
constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?";
|
||||||
|
constexpr const char* selectHash = "SELECT password FROM accounts where login = ?";
|
||||||
|
|
||||||
static const std::filesystem::path buildSQLPath = "database";
|
static const std::filesystem::path buildSQLPath = "database";
|
||||||
|
|
||||||
struct ResDeleter {
|
|
||||||
void operator () (MYSQL_RES* res) {
|
|
||||||
mysql_free_result(res);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
MySQL::MySQL():
|
MySQL::MySQL():
|
||||||
DBInterface(Type::mysql),
|
DBInterface(Type::mysql),
|
||||||
connection(),
|
connection(),
|
||||||
@ -231,7 +227,11 @@ unsigned int MySQL::registerAccount(const std::string& login, const std::string&
|
|||||||
std::string h = hash;
|
std::string h = hash;
|
||||||
addAcc.bind(l.data(), MYSQL_TYPE_STRING);
|
addAcc.bind(l.data(), MYSQL_TYPE_STRING);
|
||||||
addAcc.bind(h.data(), MYSQL_TYPE_STRING);
|
addAcc.bind(h.data(), MYSQL_TYPE_STRING);
|
||||||
|
try {
|
||||||
addAcc.execute();
|
addAcc.execute();
|
||||||
|
} catch (const Duplicate& dup) {
|
||||||
|
throw DuplicateLogin(dup.what());
|
||||||
|
}
|
||||||
|
|
||||||
unsigned int id = lastInsertedId();
|
unsigned int id = lastInsertedId();
|
||||||
static std::string defaultRole("default");
|
static std::string defaultRole("default");
|
||||||
@ -245,6 +245,24 @@ unsigned int MySQL::registerAccount(const std::string& login, const std::string&
|
|||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string MySQL::getAccountHash(const std::string& login) {
|
||||||
|
std::string l = login;
|
||||||
|
MYSQL* con = &connection;
|
||||||
|
|
||||||
|
Statement getHash(con, selectHash);
|
||||||
|
getHash.bind(l.data(), MYSQL_TYPE_STRING);
|
||||||
|
getHash.execute();
|
||||||
|
|
||||||
|
std::vector<std::vector<std::string>> result = getHash.fetchResult();
|
||||||
|
if (result.empty())
|
||||||
|
throw NoLogin("Couldn't find login " + l);
|
||||||
|
|
||||||
|
if (result[0].empty())
|
||||||
|
throw std::runtime_error("Error with the query \"selectHash\"");
|
||||||
|
|
||||||
|
return result[0][0];
|
||||||
|
}
|
||||||
|
|
||||||
unsigned int MySQL::lastInsertedId() {
|
unsigned int MySQL::lastInsertedId() {
|
||||||
MYSQL* con = &connection;
|
MYSQL* con = &connection;
|
||||||
int result = mysql_query(con, lastIdQuery);
|
int result = mysql_query(con, lastIdQuery);
|
||||||
|
@ -15,6 +15,8 @@
|
|||||||
class MySQL : public DBInterface {
|
class MySQL : public DBInterface {
|
||||||
class Statement;
|
class Statement;
|
||||||
class Transaction;
|
class Transaction;
|
||||||
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MySQL();
|
MySQL();
|
||||||
~MySQL() override;
|
~MySQL() override;
|
||||||
@ -29,6 +31,7 @@ public:
|
|||||||
void setVersion(uint8_t version) override;
|
void setVersion(uint8_t version) override;
|
||||||
|
|
||||||
unsigned int registerAccount(const std::string& login, const std::string& hash) override;
|
unsigned int registerAccount(const std::string& login, const std::string& hash) override;
|
||||||
|
std::string getAccountHash(const std::string& login) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void executeFile(const std::filesystem::path& relativePath);
|
void executeFile(const std::filesystem::path& relativePath);
|
||||||
@ -40,4 +43,10 @@ protected:
|
|||||||
std::string login;
|
std::string login;
|
||||||
std::string password;
|
std::string password;
|
||||||
std::string database;
|
std::string database;
|
||||||
|
|
||||||
|
struct ResDeleter {
|
||||||
|
void operator () (MYSQL_RES* res) {
|
||||||
|
mysql_free_result(res);
|
||||||
|
}
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
@ -5,6 +5,10 @@
|
|||||||
|
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
|
||||||
|
#include "mysqld_error.h"
|
||||||
|
|
||||||
|
#include "database/exceptions.h"
|
||||||
|
|
||||||
static uint64_t TIME_LENGTH = sizeof(MYSQL_TIME);
|
static uint64_t TIME_LENGTH = sizeof(MYSQL_TIME);
|
||||||
|
|
||||||
MySQL::Statement::Statement(MYSQL* connection, const char* statement):
|
MySQL::Statement::Statement(MYSQL* connection, const char* statement):
|
||||||
@ -44,11 +48,68 @@ void MySQL::Statement::bind(void* value, enum_field_types type, bool usigned) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void MySQL::Statement::execute() {
|
void MySQL::Statement::execute() {
|
||||||
int result = mysql_stmt_bind_param(stmt.get(), param.data());
|
MYSQL_STMT* raw = stmt.get();
|
||||||
|
int result = mysql_stmt_bind_param(raw, param.data());
|
||||||
if (result != 0)
|
if (result != 0)
|
||||||
throw std::runtime_error(std::string("Error binding statement: ") + mysql_stmt_error(stmt.get()));
|
throw std::runtime_error(std::string("Error binding statement: ") + mysql_stmt_error(raw));
|
||||||
|
|
||||||
result = mysql_stmt_execute(stmt.get());
|
result = mysql_stmt_execute(raw);
|
||||||
if (result != 0)
|
if (result != 0) {
|
||||||
throw std::runtime_error(std::string("Error executing statement: ") + mysql_stmt_error(stmt.get()));
|
int errcode = mysql_stmt_errno(raw);
|
||||||
|
std::string text = mysql_stmt_error(raw);
|
||||||
|
switch (errcode) {
|
||||||
|
case ER_DUP_ENTRY:
|
||||||
|
throw Duplicate("Error executing statement: " + text);
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Error executing statement: " + text);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<std::string>> MySQL::Statement::fetchResult() {
|
||||||
|
MYSQL_STMT* raw = stmt.get();
|
||||||
|
if (mysql_stmt_store_result(raw) != 0)
|
||||||
|
throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here
|
||||||
|
|
||||||
|
MYSQL_RES* meta = mysql_stmt_result_metadata(raw);
|
||||||
|
if (meta == nullptr)
|
||||||
|
throw std::runtime_error(std::string("Error fetching statement result: ") + mysql_stmt_error(raw)); //TODO not sure if it's valid here
|
||||||
|
|
||||||
|
std::unique_ptr<MYSQL_RES, ResDeleter> mt(meta);
|
||||||
|
unsigned int numColumns = mysql_num_fields(meta);
|
||||||
|
MYSQL_BIND bind[numColumns];
|
||||||
|
memset(bind, 0, sizeof(bind));
|
||||||
|
|
||||||
|
std::vector<std::string> line(numColumns);
|
||||||
|
std::vector<long unsigned int> lengths(numColumns);
|
||||||
|
for (unsigned int i = 0; i < numColumns; ++i) {
|
||||||
|
MYSQL_FIELD *field = mysql_fetch_field_direct(meta, i);
|
||||||
|
|
||||||
|
switch (field->type) {
|
||||||
|
case MYSQL_TYPE_STRING:
|
||||||
|
case MYSQL_TYPE_VAR_STRING:
|
||||||
|
case MYSQL_TYPE_VARCHAR:
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Unsupported data fetching statement result " + std::to_string(field->type));
|
||||||
|
}
|
||||||
|
line[i].resize(field->length);
|
||||||
|
bind[i].buffer_type = field->type;
|
||||||
|
bind[i].buffer = line[i].data();
|
||||||
|
bind[i].buffer_length = field->length;
|
||||||
|
bind[i].length = &lengths[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mysql_stmt_bind_result(raw, bind) != 0)
|
||||||
|
throw std::runtime_error(std::string("Error binding on fetching statement result: ") + mysql_stmt_error(raw));
|
||||||
|
|
||||||
|
std::vector<std::vector<std::string>> result;
|
||||||
|
while (mysql_stmt_fetch(raw) == 0) {
|
||||||
|
std::vector<std::string>& row = result.emplace_back(numColumns);
|
||||||
|
for (unsigned int i = 0; i < numColumns; ++i)
|
||||||
|
row[i] = std::string(line[i].data(), lengths[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ public:
|
|||||||
|
|
||||||
void bind(void* value, enum_field_types type, bool usigned = false);
|
void bind(void* value, enum_field_types type, bool usigned = false);
|
||||||
void execute();
|
void execute();
|
||||||
|
std::vector<std::vector<std::string>> fetchResult();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<MYSQL_STMT, STMTDeleter> stmt;
|
std::unique_ptr<MYSQL_STMT, STMTDeleter> stmt;
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||||
|
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
#include "transaction.h"
|
#include "transaction.h"
|
||||||
|
|
||||||
MySQL::Transaction::Transaction(MYSQL* connection):
|
MySQL::Transaction::Transaction(MYSQL* connection):
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||||
|
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mysql.h"
|
#include "mysql.h"
|
||||||
|
@ -3,6 +3,7 @@ set(HEADERS
|
|||||||
info.h
|
info.h
|
||||||
env.h
|
env.h
|
||||||
register.h
|
register.h
|
||||||
|
login.h
|
||||||
)
|
)
|
||||||
|
|
||||||
set(SOURCES
|
set(SOURCES
|
||||||
@ -10,6 +11,7 @@ set(SOURCES
|
|||||||
info.cpp
|
info.cpp
|
||||||
env.cpp
|
env.cpp
|
||||||
register.cpp
|
register.cpp
|
||||||
|
login.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
target_sources(${PROJECT_NAME} PRIVATE ${SOURCES})
|
target_sources(${PROJECT_NAME} PRIVATE ${SOURCES})
|
||||||
|
@ -11,7 +11,7 @@ void Handler::Env::handle(Request& request) {
|
|||||||
nlohmann::json body = nlohmann::json::object();
|
nlohmann::json body = nlohmann::json::object();
|
||||||
request.printEnvironment(body);
|
request.printEnvironment(body);
|
||||||
|
|
||||||
Response res(request);
|
Response& res = request.createResponse();
|
||||||
res.setBody(body);
|
res.setBody(body);
|
||||||
res.send();
|
res.send();
|
||||||
}
|
}
|
||||||
|
@ -7,10 +7,10 @@
|
|||||||
|
|
||||||
namespace Handler {
|
namespace Handler {
|
||||||
|
|
||||||
class Env : public Handler::Handler {
|
class Env : public Handler {
|
||||||
public:
|
public:
|
||||||
Env();
|
Env();
|
||||||
virtual void handle(Request& request);
|
void handle(Request& request) override;
|
||||||
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ Handler::Info::Info():
|
|||||||
{}
|
{}
|
||||||
|
|
||||||
void Handler::Info::handle(Request& request) {
|
void Handler::Info::handle(Request& request) {
|
||||||
Response res(request);
|
Response& res = request.createResponse();
|
||||||
nlohmann::json body = nlohmann::json::object();
|
nlohmann::json body = nlohmann::json::object();
|
||||||
body["type"] = PROJECT_NAME;
|
body["type"] = PROJECT_NAME;
|
||||||
body["version"] = PROJECT_VERSION;
|
body["version"] = PROJECT_VERSION;
|
||||||
|
65
handler/login.cpp
Normal file
65
handler/login.cpp
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||||
|
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
|
#include "login.h"
|
||||||
|
|
||||||
|
#include "server/server.h"
|
||||||
|
#include "database/exceptions.h"
|
||||||
|
|
||||||
|
Handler::Login::Login(Server* server):
|
||||||
|
Handler("login", Request::Method::post),
|
||||||
|
server(server)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void Handler::Login::handle(Request& request) {
|
||||||
|
std::map form = request.getForm();
|
||||||
|
std::map<std::string, std::string>::const_iterator itr = form.find("login");
|
||||||
|
if (itr == form.end())
|
||||||
|
return error(request, Result::noLogin, Response::Status::badRequest);
|
||||||
|
|
||||||
|
const std::string& login = itr->second;
|
||||||
|
if (login.empty())
|
||||||
|
return error(request, Result::emptyLogin, Response::Status::badRequest);
|
||||||
|
|
||||||
|
itr = form.find("password");
|
||||||
|
if (itr == form.end())
|
||||||
|
return error(request, Result::noPassword, Response::Status::badRequest);
|
||||||
|
|
||||||
|
const std::string& password = itr->second;
|
||||||
|
if (password.empty())
|
||||||
|
return error(request, Result::emptyPassword, Response::Status::badRequest);
|
||||||
|
|
||||||
|
bool success = false;
|
||||||
|
try {
|
||||||
|
success = server->validatePassword(login, password);
|
||||||
|
} catch (const DBInterface::NoLogin& e) {
|
||||||
|
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
|
||||||
|
return error(request, Result::noLogin, Response::Status::badRequest); //can send unauthed instead, to exclude login spoofing
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
|
||||||
|
return error(request, Result::unknownError, Response::Status::internalError);
|
||||||
|
} catch (...) {
|
||||||
|
std::cerr << "Unknown exception on registration" << std::endl;
|
||||||
|
return error(request, Result::unknownError, Response::Status::internalError);
|
||||||
|
}
|
||||||
|
if (!success)
|
||||||
|
return error(request, Result::noLogin, Response::Status::badRequest);
|
||||||
|
|
||||||
|
//TODO opening the session
|
||||||
|
|
||||||
|
Response& res = request.createResponse();
|
||||||
|
nlohmann::json body = nlohmann::json::object();
|
||||||
|
body["result"] = Result::success;
|
||||||
|
|
||||||
|
res.setBody(body);
|
||||||
|
res.send();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Handler::Login::error(Request& request, Result result, Response::Status code) {
|
||||||
|
Response& res = request.createResponse(code);
|
||||||
|
nlohmann::json body = nlohmann::json::object();
|
||||||
|
body["result"] = result;
|
||||||
|
|
||||||
|
res.setBody(body);
|
||||||
|
res.send();
|
||||||
|
}
|
32
handler/login.h
Normal file
32
handler/login.h
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
|
||||||
|
// SPDX-License-Identifier: GPL-3.0-or-later
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "handler.h"
|
||||||
|
|
||||||
|
class Server;
|
||||||
|
namespace Handler {
|
||||||
|
|
||||||
|
class Login : public Handler {
|
||||||
|
public:
|
||||||
|
Login(Server* server);
|
||||||
|
void handle(Request& request) override;
|
||||||
|
|
||||||
|
enum class Result {
|
||||||
|
success,
|
||||||
|
noLogin,
|
||||||
|
emptyLogin,
|
||||||
|
noPassword,
|
||||||
|
emptyPassword,
|
||||||
|
unknownError
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
void error(Request& request, Result result, Response::Status code);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Server* server;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
@ -4,6 +4,7 @@
|
|||||||
#include "register.h"
|
#include "register.h"
|
||||||
|
|
||||||
#include "server/server.h"
|
#include "server/server.h"
|
||||||
|
#include "database/exceptions.h"
|
||||||
|
|
||||||
Handler::Register::Register(Server* server):
|
Handler::Register::Register(Server* server):
|
||||||
Handler("register", Request::Method::post),
|
Handler("register", Request::Method::post),
|
||||||
@ -14,35 +15,38 @@ void Handler::Register::handle(Request& request) {
|
|||||||
std::map form = request.getForm();
|
std::map form = request.getForm();
|
||||||
std::map<std::string, std::string>::const_iterator itr = form.find("login");
|
std::map<std::string, std::string>::const_iterator itr = form.find("login");
|
||||||
if (itr == form.end())
|
if (itr == form.end())
|
||||||
return error(request, Result::noLogin);
|
return error(request, Result::noLogin, Response::Status::badRequest);
|
||||||
|
|
||||||
const std::string& login = itr->second;
|
const std::string& login = itr->second;
|
||||||
if (login.empty())
|
if (login.empty())
|
||||||
return error(request, Result::emptyLogin);
|
return error(request, Result::emptyLogin, Response::Status::badRequest);
|
||||||
|
|
||||||
//TODO login policies checkup
|
//TODO login policies checkup
|
||||||
|
|
||||||
itr = form.find("password");
|
itr = form.find("password");
|
||||||
if (itr == form.end())
|
if (itr == form.end())
|
||||||
return error(request, Result::noPassword);
|
return error(request, Result::noPassword, Response::Status::badRequest);
|
||||||
|
|
||||||
const std::string& password = itr->second;
|
const std::string& password = itr->second;
|
||||||
if (password.empty())
|
if (password.empty())
|
||||||
return error(request, Result::emptyPassword);
|
return error(request, Result::emptyPassword, Response::Status::badRequest);
|
||||||
|
|
||||||
//TODO password policies checkup
|
//TODO password policies checkup
|
||||||
|
|
||||||
try {
|
try {
|
||||||
server->registerAccount(login, password);
|
server->registerAccount(login, password);
|
||||||
|
} catch (const DBInterface::DuplicateLogin& e) {
|
||||||
|
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
|
||||||
|
return error(request, Result::loginExists, Response::Status::conflict);
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
|
std::cerr << "Exception on registration:\n\t" << e.what() << std::endl;
|
||||||
return error(request, Result::unknownError);
|
return error(request, Result::unknownError, Response::Status::internalError);
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
std::cerr << "Unknown exception on registration" << std::endl;
|
std::cerr << "Unknown exception on registration" << std::endl;
|
||||||
return error(request, Result::unknownError);
|
return error(request, Result::unknownError, Response::Status::internalError);
|
||||||
}
|
}
|
||||||
|
|
||||||
Response res(request);
|
Response& res = request.createResponse();
|
||||||
nlohmann::json body = nlohmann::json::object();
|
nlohmann::json body = nlohmann::json::object();
|
||||||
body["result"] = Result::success;
|
body["result"] = Result::success;
|
||||||
|
|
||||||
@ -50,8 +54,8 @@ void Handler::Register::handle(Request& request) {
|
|||||||
res.send();
|
res.send();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Handler::Register::error(Request& request, Result result) {
|
void Handler::Register::error(Request& request, Result result, Response::Status code) {
|
||||||
Response res(request);
|
Response& res = request.createResponse(code);
|
||||||
nlohmann::json body = nlohmann::json::object();
|
nlohmann::json body = nlohmann::json::object();
|
||||||
body["result"] = result;
|
body["result"] = result;
|
||||||
|
|
||||||
|
@ -8,10 +8,10 @@
|
|||||||
class Server;
|
class Server;
|
||||||
namespace Handler {
|
namespace Handler {
|
||||||
|
|
||||||
class Register : public Handler::Handler {
|
class Register : public Handler {
|
||||||
public:
|
public:
|
||||||
Register(Server* server);
|
Register(Server* server);
|
||||||
virtual void handle(Request& request);
|
void handle(Request& request) override;
|
||||||
|
|
||||||
enum class Result {
|
enum class Result {
|
||||||
success,
|
success,
|
||||||
@ -26,7 +26,7 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void error(Request& request, Result result);
|
void error(Request& request, Result result, Response::Status code);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Server* server;
|
Server* server;
|
||||||
|
@ -5,8 +5,6 @@
|
|||||||
|
|
||||||
#include "response/response.h"
|
#include "response/response.h"
|
||||||
|
|
||||||
constexpr static const char* GET("GET");
|
|
||||||
|
|
||||||
constexpr static const char* REQUEST_METHOD("REQUEST_METHOD");
|
constexpr static const char* REQUEST_METHOD("REQUEST_METHOD");
|
||||||
constexpr static const char* SCRIPT_FILENAME("SCRIPT_FILENAME");
|
constexpr static const char* SCRIPT_FILENAME("SCRIPT_FILENAME");
|
||||||
constexpr static const char* SERVER_NAME("SERVER_NAME");
|
constexpr static const char* SERVER_NAME("SERVER_NAME");
|
||||||
@ -51,11 +49,15 @@ void Request::terminate() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Request::Method Request::method() const {
|
std::string_view Request::methodName() const {
|
||||||
if (state == State::initial)
|
if (state == State::initial)
|
||||||
throw std::runtime_error("An attempt to read request method on not accepted request");
|
throw std::runtime_error("An attempt to read request method on not accepted request");
|
||||||
|
|
||||||
std::string_view method(FCGX_GetParam(REQUEST_METHOD, raw.envp));
|
return FCGX_GetParam(REQUEST_METHOD, raw.envp);
|
||||||
|
}
|
||||||
|
|
||||||
|
Request::Method Request::method() const {
|
||||||
|
std::string_view method = methodName();
|
||||||
for (const auto& pair : methods) {
|
for (const auto& pair : methods) {
|
||||||
if (pair.first == method)
|
if (pair.first == method)
|
||||||
return pair.second;
|
return pair.second;
|
||||||
@ -79,17 +81,42 @@ bool Request::wait(int socketDescriptor) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
OStream Request::getOutputStream(const Response* response) {
|
OStream Request::getOutputStream() {
|
||||||
validateResponse(response);
|
|
||||||
return OStream(raw.out);
|
return OStream(raw.out);
|
||||||
}
|
}
|
||||||
|
|
||||||
OStream Request::getErrorStream(const Response* response) {
|
OStream Request::getErrorStream() {
|
||||||
validateResponse(response);
|
|
||||||
return OStream(raw.err);
|
return OStream(raw.err);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Request::responseIsComplete(const Response* response) {
|
Response& Request::createResponse() {
|
||||||
|
if (state != State::accepted)
|
||||||
|
throw std::runtime_error("An attempt create response to the request in the wrong state");
|
||||||
|
|
||||||
|
response = std::unique_ptr<Response>(new Response(*this));
|
||||||
|
state = State::responding;
|
||||||
|
|
||||||
|
return *response.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
Response& Request::createResponse(Response::Status status) {
|
||||||
|
if (state != State::accepted)
|
||||||
|
throw std::runtime_error("An attempt create response to the request in the wrong state");
|
||||||
|
|
||||||
|
response = std::unique_ptr<Response>(new Response(*this, status));
|
||||||
|
state = State::responding;
|
||||||
|
|
||||||
|
return *response.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
uint16_t Request::responseCode() const {
|
||||||
|
if (state != State::responded)
|
||||||
|
throw std::runtime_error("An attempt create read response code on the wrong state");
|
||||||
|
|
||||||
|
return response->statusCode();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Request::responseIsComplete() {
|
||||||
switch (state) {
|
switch (state) {
|
||||||
case State::initial:
|
case State::initial:
|
||||||
throw std::runtime_error("An attempt to mark the request as complete, but it wasn't even accepted yet");
|
throw std::runtime_error("An attempt to mark the request as complete, but it wasn't even accepted yet");
|
||||||
@ -98,10 +125,6 @@ void Request::responseIsComplete(const Response* response) {
|
|||||||
throw std::runtime_error("An attempt to mark the request as complete, but it wasn't responded");
|
throw std::runtime_error("An attempt to mark the request as complete, but it wasn't responded");
|
||||||
break;
|
break;
|
||||||
case State::responding:
|
case State::responding:
|
||||||
if (Request::response != response)
|
|
||||||
throw std::runtime_error("An attempt to mark the request as complete by the different response who actually started responding");
|
|
||||||
|
|
||||||
Request::response = nullptr;
|
|
||||||
state = State::responded;
|
state = State::responded;
|
||||||
break;
|
break;
|
||||||
case State::responded:
|
case State::responded:
|
||||||
@ -110,26 +133,6 @@ void Request::responseIsComplete(const Response* response) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Request::validateResponse(const Response* response) {
|
|
||||||
switch (state) {
|
|
||||||
case State::initial:
|
|
||||||
throw std::runtime_error("An attempt to request stream while the request wasn't even accepted yet");
|
|
||||||
break;
|
|
||||||
case State::accepted:
|
|
||||||
Request::response = response;
|
|
||||||
state = State::responding;
|
|
||||||
break;
|
|
||||||
case State::responding:
|
|
||||||
if (Request::response != response)
|
|
||||||
throw std::runtime_error("Error handling a request: first time one response started replying, then another continued");
|
|
||||||
|
|
||||||
break;
|
|
||||||
case State::responded:
|
|
||||||
throw std::runtime_error("An attempt to request stream on a request that was already done responding");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Request::State Request::currentState() const {
|
Request::State Request::currentState() const {
|
||||||
return state;
|
return state;
|
||||||
}
|
}
|
||||||
|
@ -16,9 +16,10 @@
|
|||||||
|
|
||||||
#include "stream/ostream.h"
|
#include "stream/ostream.h"
|
||||||
#include "utils/formdecode.h"
|
#include "utils/formdecode.h"
|
||||||
|
#include "response/response.h"
|
||||||
|
|
||||||
class Response;
|
|
||||||
class Request {
|
class Request {
|
||||||
|
friend class Response;
|
||||||
public:
|
public:
|
||||||
enum class State {
|
enum class State {
|
||||||
initial,
|
initial,
|
||||||
@ -43,26 +44,29 @@ public:
|
|||||||
bool wait(int socketDescriptor);
|
bool wait(int socketDescriptor);
|
||||||
void terminate();
|
void terminate();
|
||||||
|
|
||||||
|
Response& createResponse();
|
||||||
|
Response& createResponse(Response::Status status);
|
||||||
|
|
||||||
|
uint16_t responseCode() const;
|
||||||
Method method() const;
|
Method method() const;
|
||||||
|
std::string_view methodName() const;
|
||||||
State currentState() const;
|
State currentState() const;
|
||||||
bool isFormUrlEncoded() const;
|
bool isFormUrlEncoded() const;
|
||||||
unsigned int contentLength() const;
|
unsigned int contentLength() const;
|
||||||
std::map<std::string, std::string> getForm() const;
|
std::map<std::string, std::string> getForm() const;
|
||||||
|
|
||||||
OStream getOutputStream(const Response* response);
|
|
||||||
OStream getErrorStream(const Response* response);
|
|
||||||
void responseIsComplete(const Response* response);
|
|
||||||
|
|
||||||
std::string getPath(const std::string& serverName) const;
|
std::string getPath(const std::string& serverName) const;
|
||||||
std::string getServerName() const;
|
std::string getServerName() const;
|
||||||
void printEnvironment(std::ostream& out);
|
void printEnvironment(std::ostream& out);
|
||||||
void printEnvironment(nlohmann::json& out);
|
void printEnvironment(nlohmann::json& out);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void validateResponse(const Response* response);
|
OStream getOutputStream();
|
||||||
|
OStream getErrorStream();
|
||||||
|
void responseIsComplete();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
State state;
|
State state;
|
||||||
FCGX_Request raw;
|
FCGX_Request raw;
|
||||||
const Response* response;
|
std::unique_ptr<Response> response;
|
||||||
};
|
};
|
||||||
|
@ -3,10 +3,25 @@
|
|||||||
|
|
||||||
#include "response.h"
|
#include "response.h"
|
||||||
|
|
||||||
constexpr std::array<std::string_view, static_cast<uint8_t>(Response::Status::__size)> statusCodes = {
|
#include "request/request.h"
|
||||||
|
|
||||||
|
constexpr std::array<uint16_t, static_cast<uint8_t>(Response::Status::__size)> statusCodes = {
|
||||||
|
200,
|
||||||
|
400,
|
||||||
|
401,
|
||||||
|
404,
|
||||||
|
405,
|
||||||
|
409,
|
||||||
|
500
|
||||||
|
};
|
||||||
|
|
||||||
|
constexpr std::array<std::string_view, static_cast<uint8_t>(Response::Status::__size)> statuses = {
|
||||||
"Status: 200 OK",
|
"Status: 200 OK",
|
||||||
|
"Status: 400 Bad Request",
|
||||||
|
"Status: 401 Unauthorized",
|
||||||
"Status: 404 Not Found",
|
"Status: 404 Not Found",
|
||||||
"Status: 405 Method Not Allowed",
|
"Status: 405 Method Not Allowed",
|
||||||
|
"Status: 409 Conflict",
|
||||||
"Status: 500 Internal Error"
|
"Status: 500 Internal Error"
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -33,9 +48,9 @@ void Response::send() const {
|
|||||||
// OStream out = status == Status::ok ?
|
// OStream out = status == Status::ok ?
|
||||||
// request.getOutputStream() :
|
// request.getOutputStream() :
|
||||||
// request.getErrorStream();
|
// request.getErrorStream();
|
||||||
OStream out = request.getOutputStream(this);
|
OStream out = request.getOutputStream();
|
||||||
|
|
||||||
out << statusCodes[static_cast<uint8_t>(status)];
|
out << statuses[static_cast<uint8_t>(status)];
|
||||||
if (!body.empty())
|
if (!body.empty())
|
||||||
out << '\n'
|
out << '\n'
|
||||||
<< contentTypes[static_cast<uint8_t>(type)]
|
<< contentTypes[static_cast<uint8_t>(type)]
|
||||||
@ -43,7 +58,11 @@ void Response::send() const {
|
|||||||
<< '\n'
|
<< '\n'
|
||||||
<< body;
|
<< body;
|
||||||
|
|
||||||
request.responseIsComplete(this);
|
request.responseIsComplete();
|
||||||
|
}
|
||||||
|
|
||||||
|
uint16_t Response::statusCode() const {
|
||||||
|
return statusCodes[static_cast<uint8_t>(status)];
|
||||||
}
|
}
|
||||||
|
|
||||||
void Response::setBody(const std::string& body) {
|
void Response::setBody(const std::string& body) {
|
||||||
|
@ -9,15 +9,20 @@
|
|||||||
|
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
#include "request/request.h"
|
|
||||||
#include "stream/ostream.h"
|
#include "stream/ostream.h"
|
||||||
|
|
||||||
|
class Request;
|
||||||
class Response {
|
class Response {
|
||||||
|
friend class Request;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
enum class Status {
|
enum class Status {
|
||||||
ok,
|
ok,
|
||||||
|
badRequest,
|
||||||
|
unauthorized,
|
||||||
notFound,
|
notFound,
|
||||||
methodNotAllowed,
|
methodNotAllowed,
|
||||||
|
conflict,
|
||||||
internalError,
|
internalError,
|
||||||
__size
|
__size
|
||||||
};
|
};
|
||||||
@ -27,13 +32,17 @@ public:
|
|||||||
json,
|
json,
|
||||||
__size
|
__size
|
||||||
};
|
};
|
||||||
Response(Request& request);
|
|
||||||
Response(Request& request, Status status);
|
uint16_t statusCode() const;
|
||||||
|
|
||||||
void send() const;
|
void send() const;
|
||||||
void setBody(const std::string& body);
|
void setBody(const std::string& body);
|
||||||
void setBody(const nlohmann::json& body);
|
void setBody(const nlohmann::json& body);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Response(Request& request);
|
||||||
|
Response(Request& request, Status status);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Request& request;
|
Request& request;
|
||||||
Status status;
|
Status status;
|
||||||
|
@ -50,29 +50,29 @@ void Router::route(const std::string& path, std::unique_ptr<Request> request) {
|
|||||||
if (request->currentState() != Request::State::responded)
|
if (request->currentState() != Request::State::responded)
|
||||||
handleInternalError(path, std::runtime_error("handler failed to handle the request"), std::move(request));
|
handleInternalError(path, std::runtime_error("handler failed to handle the request"), std::move(request));
|
||||||
else
|
else
|
||||||
std::cout << "Success:\t" << path << std::endl;
|
std::cout << request->responseCode() << '\t' << request->methodName() << '\t' << path << std::endl;
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
handleInternalError(path, e, std::move(request));
|
handleInternalError(path, e, std::move(request));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Router::handleNotFound(const std::string& path, std::unique_ptr<Request> request) {
|
void Router::handleNotFound(const std::string& path, std::unique_ptr<Request> request) {
|
||||||
Response notFound(*request.get(), Response::Status::notFound);
|
Response& notFound = request->createResponse(Response::Status::notFound);
|
||||||
notFound.setBody(std::string("Path \"") + path + "\" was not found");
|
notFound.setBody(std::string("Path \"") + path + "\" was not found");
|
||||||
notFound.send();
|
notFound.send();
|
||||||
std::cerr << "Not found:\t" << path << std::endl;
|
std::cerr << notFound.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Router::handleInternalError(const std::string& path, const std::exception& exception, std::unique_ptr<Request> request) {
|
void Router::handleInternalError(const std::string& path, const std::exception& exception, std::unique_ptr<Request> request) {
|
||||||
Response error(*request.get(), Response::Status::internalError);
|
Response& error = request->createResponse(Response::Status::internalError);
|
||||||
error.setBody(std::string(exception.what()));
|
error.setBody(std::string(exception.what()));
|
||||||
error.send();
|
error.send();
|
||||||
std::cerr << "Internal error:\t" << path << "\n\t" << exception.what() << std::endl;
|
std::cerr << error.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Router::handleMethodNotAllowed(const std::string& path, std::unique_ptr<Request> request) {
|
void Router::handleMethodNotAllowed(const std::string& path, std::unique_ptr<Request> request) {
|
||||||
Response error(*request.get(), Response::Status::methodNotAllowed);
|
Response& error = request->createResponse(Response::Status::methodNotAllowed);
|
||||||
error.setBody(std::string("Method not allowed"));
|
error.setBody(std::string("Method not allowed"));
|
||||||
error.send();
|
error.send();
|
||||||
std::cerr << "Method not allowed:\t" << path << std::endl;
|
std::cerr << error.statusCode() << '\t' << request->methodName() << '\t' << path << std::endl;
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
#include "handler/info.h"
|
#include "handler/info.h"
|
||||||
#include "handler/env.h"
|
#include "handler/env.h"
|
||||||
#include "handler/register.h"
|
#include "handler/register.h"
|
||||||
|
#include "handler/login.h"
|
||||||
|
|
||||||
constexpr const char* pepper = "well, not much of a secret, huh?";
|
constexpr const char* pepper = "well, not much of a secret, huh?";
|
||||||
constexpr uint8_t currentDbVesion = 1;
|
constexpr uint8_t currentDbVesion = 1;
|
||||||
@ -39,6 +40,7 @@ Server::Server():
|
|||||||
router.addRoute(std::make_unique<Handler::Info>());
|
router.addRoute(std::make_unique<Handler::Info>());
|
||||||
router.addRoute(std::make_unique<Handler::Env>());
|
router.addRoute(std::make_unique<Handler::Env>());
|
||||||
router.addRoute(std::make_unique<Handler::Register>(this));
|
router.addRoute(std::make_unique<Handler::Register>(this));
|
||||||
|
router.addRoute(std::make_unique<Handler::Login>(this));
|
||||||
}
|
}
|
||||||
|
|
||||||
Server::~Server() {}
|
Server::~Server() {}
|
||||||
@ -63,7 +65,7 @@ void Server::handleRequest(std::unique_ptr<Request> request) {
|
|||||||
std::cout << "received server name " << serverName.value() << std::endl;
|
std::cout << "received server name " << serverName.value() << std::endl;
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
std::cerr << "failed to read server name" << std::endl;
|
std::cerr << "failed to read server name" << std::endl;
|
||||||
Response error(*request.get(), Response::Status::internalError);
|
Response& error = request->createResponse(Response::Status::internalError);
|
||||||
error.send();
|
error.send();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -107,3 +109,19 @@ unsigned int Server::registerAccount(const std::string& login, const std::string
|
|||||||
|
|
||||||
return db->registerAccount(login, hash);
|
return db->registerAccount(login, hash);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Server::validatePassword(const std::string& login, const std::string& password) {
|
||||||
|
std::string hash = db->getAccountHash(login);
|
||||||
|
|
||||||
|
std::string spiced = password + pepper;
|
||||||
|
int result = argon2id_verify(hash.data(), spiced.data(), spiced.size());
|
||||||
|
|
||||||
|
switch (result) {
|
||||||
|
case ARGON2_OK:
|
||||||
|
return true;
|
||||||
|
case ARGON2_VERIFY_MISMATCH:
|
||||||
|
return false;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(std::string("Failed to verify password: ") + argon2_error_message(result));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -32,6 +32,7 @@ public:
|
|||||||
void run(int socketDescriptor);
|
void run(int socketDescriptor);
|
||||||
|
|
||||||
unsigned int registerAccount(const std::string& login, const std::string& password);
|
unsigned int registerAccount(const std::string& login, const std::string& password);
|
||||||
|
bool validatePassword(const std::string& login, const std::string& password);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void handleRequest(std::unique_ptr<Request> request);
|
void handleRequest(std::unique_ptr<Request> request);
|
||||||
|
Loading…
Reference in New Issue
Block a user