111 lines
3.1 KiB
C++
111 lines
3.1 KiB
C++
#include "server.h"
|
|
|
|
/**
 * Constructs the server.
 *
 * Steps, in order:
 *  1. Creates a MySQL DBInterface, sets credentials/database name to "pica"
 *     and tries to connect through the local socket /run/mysqld/mysqld.sock.
 *  2. If the connection succeeded, reads the schema version; when it is 0,
 *     runs database/migrations/m0.sql and records version 1.
 *  3. Registers the built-in routes "info" and "env".
 *
 * A connection failure reported as std::runtime_error is logged to stderr
 * and is NOT fatal: migration is skipped but the routes are still
 * registered. Exceptions thrown while reading the version or migrating are
 * not caught here and propagate out of the constructor.
 */
Server::Server():
    terminating(false),
    requestCount(0),
    serverName(std::nullopt),
    router(),
    db()
{
    std::cout << "Starting pica..." << std::endl;

    // db is created here (not in the init list) so the startup banner is
    // printed before any database work happens.
    db = DBInterface::create(DBInterface::Type::mysql);
    std::cout << "Database type: MySQL" << std::endl;

    db->setCredentials("pica", "pica");
    db->setDatabase("pica");

    bool connected = false;
    try {
        db->connect("/run/mysqld/mysqld.sock");
        connected = true;
        std::cout << "Successfully connected to the database" << std::endl;
    } catch (const std::runtime_error& e) {
        // Best effort: keep starting up without a database connection.
        std::cerr << "Couldn't connect to the database: " << e.what() << std::endl;
    }

    if (connected) {
        uint8_t version = db->getVersion();
        // to_string: uint8_t would otherwise be streamed as a character.
        std::cout << "Database version is " << std::to_string(version) << std::endl;

        if (version == 0) {
            db->executeFile("database/migrations/m0.sql");
            db->setVersion(1);
            // Report success only once the new version has been recorded.
            std::cout << "Successfully migrated to version 1" << std::endl;
            std::cout << "Database version is " << std::to_string(db->getVersion()) << " now" << std::endl;
        }
    }

    router.addRoute("info", Server::info);
    router.addRoute("env", Server::printEnvironment);
}
|
|
|
|
// The server owns no resources that need manual release; member destructors
// do all of the cleanup.
Server::~Server() = default;
|
|
|
|
void Server::run(int socketDescriptor) {
|
|
while (!terminating) {
|
|
std::unique_ptr<Request> request = std::make_unique<Request>();
|
|
bool result = request->wait(socketDescriptor);
|
|
if (!result) {
|
|
std::cerr << "Error accepting a request" << std::endl;
|
|
return;
|
|
}
|
|
handleRequest(std::move(request));
|
|
}
|
|
}
|
|
|
|
void Server::handleRequest(std::unique_ptr<Request> request) {
|
|
++requestCount;
|
|
if (!serverName) {
|
|
try {
|
|
serverName = request->getServerName();
|
|
std::cout << "received server name " << serverName.value() << std::endl;
|
|
} catch (...) {
|
|
std::cerr << "failed to read server name" << std::endl;
|
|
Response error(Response::Status::internalError);
|
|
error.replyTo(*request.get());
|
|
return;
|
|
}
|
|
}
|
|
|
|
if (!request->isGet()) {
|
|
static const Response methodNotAllowed(Response::Status::methodNotAllowed);
|
|
methodNotAllowed.replyTo(*request.get());
|
|
return;
|
|
}
|
|
|
|
try {
|
|
std::string path = request->getPath(serverName.value());
|
|
router.route(path.data(), std::move(request), this);
|
|
} catch (const std::exception e) {
|
|
Response error(Response::Status::internalError);
|
|
error.setBody(std::string(e.what()));
|
|
error.replyTo(*request.get());
|
|
}
|
|
}
|
|
|
|
bool Server::printEnvironment(Request* request, Server* server) {
|
|
(void)server;
|
|
nlohmann::json body = nlohmann::json::object();
|
|
request->printEnvironment(body);
|
|
|
|
Response res;
|
|
res.setBody(body);
|
|
res.replyTo(*request);
|
|
|
|
return true;
|
|
}
|
|
|
|
bool Server::info(Request* request, Server* server) {
|
|
(void)server;
|
|
Response res;
|
|
nlohmann::json body = nlohmann::json::object();
|
|
body["type"] = "Pica";
|
|
body["version"] = "0.0.1";
|
|
|
|
res.setBody(body);
|
|
res.replyTo(*request);
|
|
|
|
return true;
|
|
}
|