diff --git a/database/CMakeLists.txt b/database/CMakeLists.txt index 42e02ef..36e1ad5 100644 --- a/database/CMakeLists.txt +++ b/database/CMakeLists.txt @@ -19,3 +19,4 @@ target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) add_subdirectory(mysql) add_subdirectory(migrations) +add_subdirectory(schema) diff --git a/database/interface.h b/database/interface.h index b2cc592..e96f97b 100644 --- a/database/interface.h +++ b/database/interface.h @@ -9,26 +9,10 @@ #include #include +#include "schema/session.h" +#include "schema/asset.h" + namespace DB { -struct Session { - unsigned int id; - unsigned int owner; - std::string accessToken; - std::string renewToken; -}; - -struct Asset { - unsigned int id; - unsigned int owner; - unsigned int currency; - std::string title; - std::string icon; - // `color` INTEGER UNSIGNED DEFAULT 0, - // `balance` DECIMAL (20, 5) DEFAULT 0, - // `type` INTEGER UNSIGNED NOT NULL, - bool archived; -}; - class Interface { public: enum class Type { @@ -62,6 +46,7 @@ public: virtual Session createSession(const std::string& login, const std::string& access, const std::string& renew) = 0; virtual Session findSession(const std::string& accessToken) = 0; virtual std::vector listAssets(unsigned int owner) = 0; + virtual Asset addAsset(const Asset& asset) = 0; protected: Interface(Type type); diff --git a/database/mysql/mysql.cpp b/database/mysql/mysql.cpp index e6b3067..6a807fb 100644 --- a/database/mysql/mysql.cpp +++ b/database/mysql/mysql.cpp @@ -21,12 +21,14 @@ constexpr const char* selectHash = "SELECT password FROM accounts where login = constexpr const char* createSessionQuery = "INSERT INTO sessions (`owner`, `access`, `renew`, `persist`, `device`)" " SELECT accounts.id, ?, ?, true, ? FROM accounts WHERE accounts.login = ?" " RETURNING id, owner"; -constexpr const char* selectSession = "SELECT id, owner, renew FROM sessions where access = ?"; +constexpr const char* selectSession = "SELECT id, owner, access, renew FROM sessions where access = ?"; constexpr const char* selectAssets = "SELECT id, owner, currency, title, icon, archived FROM assets where owner = ?"; +constexpr const char* insertAsset = "INSERT INTO assets (`owner`, `currency`, `title`, `icon`, `archived`)" + " VALUES (?, ?, ?, ?, ?)"; static const std::filesystem::path buildSQLPath = "database"; -DB::MySQL::MySQL(): +DB::MySQL::MySQL (): Interface(Type::mysql), connection(), login(), @@ -40,7 +42,7 @@ DB::MySQL::~MySQL() { mysql_close(&connection); } -void DB::MySQL::connect(const std::string& path) { +void DB::MySQL::connect (const std::string& path) { if (state != State::disconnected) return; @@ -62,7 +64,7 @@ void DB::MySQL::connect(const std::string& path) { state = State::connected; } -void DB::MySQL::setCredentials(const std::string& login, const std::string& password) { +void DB::MySQL::setCredentials (const std::string& login, const std::string& password) { if (MySQL::login == login && MySQL::password == password) return; @@ -84,7 +86,7 @@ void DB::MySQL::setCredentials(const std::string& login, const std::string& pass throw std::runtime_error(std::string("Error changing credetials: ") + mysql_error(con)); } -void DB::MySQL::setDatabase(const std::string& database) { +void DB::MySQL::setDatabase (const std::string& database) { if (MySQL::database == database) return; @@ -100,7 +102,7 @@ void DB::MySQL::setDatabase(const std::string& database) { throw std::runtime_error(std::string("Error changing db: ") + mysql_error(con)); } -void DB::MySQL::disconnect() { +void DB::MySQL::disconnect () { if (state == State::disconnected) return; @@ -109,7 +111,7 @@ void DB::MySQL::disconnect() { mysql_init(con); //this is ridiculous! } -void DB::MySQL::executeFile(const std::filesystem::path& relativePath) { +void DB::MySQL::executeFile (const std::filesystem::path& relativePath) { MYSQL* con = &connection; std::filesystem::path path = sharedPath() / relativePath; if (!std::filesystem::exists(path)) @@ -141,7 +143,7 @@ void DB::MySQL::executeFile(const std::filesystem::path& relativePath) { } } -uint8_t DB::MySQL::getVersion() { +uint8_t DB::MySQL::getVersion () { MYSQL* con = &connection; int result = mysql_query(con, versionQuery); @@ -164,14 +166,14 @@ uint8_t DB::MySQL::getVersion() { return 0; } -void DB::MySQL::setVersion(uint8_t version) { +void DB::MySQL::setVersion (uint8_t version) { std::string strVersion = std::to_string(version); Statement statement(&connection, updateQuery); statement.bind(strVersion.data(), MYSQL_TYPE_VAR_STRING); statement.execute(); } -void DB::MySQL::migrate(uint8_t targetVersion) { +void DB::MySQL::migrate (uint8_t targetVersion) { uint8_t currentVersion = getVersion(); while (currentVersion < targetVersion) { @@ -193,7 +195,7 @@ void DB::MySQL::migrate(uint8_t targetVersion) { std::cout << "Database is now on actual version " << std::to_string(targetVersion) << std::endl; } -std::optional DB::MySQL::getComment(std::string& string) { +std::optional DB::MySQL::getComment (std::string& string) { ltrim(string); if (string.length() < 2) return std::nullopt; @@ -221,7 +223,7 @@ std::optional DB::MySQL::getComment(std::string& string) { return std::nullopt; } -unsigned int DB::MySQL::registerAccount(const std::string& login, const std::string& hash) { +unsigned int DB::MySQL::registerAccount (const std::string& login, const std::string& hash) { //TODO validate filed lengths! MYSQL* con = &connection; MySQL::Transaction txn(con); @@ -250,7 +252,7 @@ unsigned int DB::MySQL::registerAccount(const std::string& login, const std::str return id; } -std::string DB::MySQL::getAccountHash(const std::string& login) { +std::string DB::MySQL::getAccountHash (const std::string& login) { std::string l = login; MYSQL* con = &connection; @@ -268,7 +270,7 @@ std::string DB::MySQL::getAccountHash(const std::string& login) { return std::any_cast(result[0][0]); } -DB::Session DB::MySQL::createSession(const std::string& login, const std::string& access, const std::string& renew) { +DB::Session DB::MySQL::createSession (const std::string& login, const std::string& access, const std::string& renew) { std::string l = login; DB::Session res; res.accessToken = access; @@ -294,7 +296,7 @@ DB::Session DB::MySQL::createSession(const std::string& login, const std::string return res; } -unsigned int DB::MySQL::lastInsertedId() { +unsigned int DB::MySQL::lastInsertedId () { MYSQL* con = &connection; int result = mysql_query(con, lastIdQuery); @@ -311,7 +313,7 @@ unsigned int DB::MySQL::lastInsertedId() { else throw std::runtime_error(std::string("Querying last inserted id returned no rows")); } -DB::Session DB::MySQL::findSession(const std::string& accessToken) { +DB::Session DB::MySQL::findSession (const std::string& accessToken) { std::string a = accessToken; MYSQL* con = &connection; @@ -323,36 +325,38 @@ DB::Session DB::MySQL::findSession(const std::string& accessToken) { if (result.empty()) throw NoSession("Couldn't find session with token " + a); - DB::Session res; - res.id = std::any_cast(result[0][0]); - res.owner = std::any_cast(result[0][1]); - res.renewToken = std::any_cast(result[0][2]); - res.accessToken = a; - - return res; + return DB::Session(result[0]); } -std::vector DB::MySQL::listAssets(unsigned int owner) { +std::vector DB::MySQL::listAssets (unsigned int owner) { MYSQL* con = &connection; - Statement st(con, selectSession); + Statement st(con, selectAssets); st.bind(&owner, MYSQL_TYPE_LONG, true); st.execute(); std::vector> res = st.fetchResult(); std::size_t size = res.size(); std::vector result(size); - for (std::size_t i = 0; i < size; ++i) { - const std::vector& proto = res[i]; - DB::Asset& asset = result[i]; - asset.id = std::any_cast(proto[0]); - asset.owner = std::any_cast(proto[1]); - asset.currency = std::any_cast(proto[2]); - asset.title = std::any_cast(proto[3]); - asset.icon = std::any_cast(proto[4]); - asset.archived = std::any_cast(proto[5]); //TODO - } + for (std::size_t i = 0; i < size; ++i) + result[i].parse(res[i]); return result; } +DB::Asset DB::MySQL::addAsset(const Asset& asset) { + MYSQL* con = &connection; + Asset result = asset; + + Statement session(con, insertAsset); + session.bind(&result.owner, MYSQL_TYPE_LONG, true); + session.bind(&result.currency, MYSQL_TYPE_LONG, true); + session.bind(result.title.data(), MYSQL_TYPE_STRING); + session.bind(result.icon.data(), MYSQL_TYPE_STRING); + session.bind(&result.archived, MYSQL_TYPE_TINY); + session.execute(); + + result.id = lastInsertedId(); + + return asset; +} diff --git a/database/mysql/mysql.h b/database/mysql/mysql.h index a768b44..e30d309 100644 --- a/database/mysql/mysql.h +++ b/database/mysql/mysql.h @@ -36,6 +36,7 @@ public: Session createSession (const std::string& login, const std::string& access, const std::string& renew) override; Session findSession (const std::string& accessToken) override; std::vector listAssets (unsigned int owner) override; + Asset addAsset (const Asset& asset) override; private: void executeFile (const std::filesystem::path& relativePath); diff --git a/database/schema/CMakeLists.txt b/database/schema/CMakeLists.txt new file mode 100644 index 0000000..b3538da --- /dev/null +++ b/database/schema/CMakeLists.txt @@ -0,0 +1,14 @@ +#SPDX-FileCopyrightText: 2023 Yury Gubich +#SPDX-License-Identifier: GPL-3.0-or-later + +set(HEADERS + session.h + asset.h +) + +set(SOURCES + session.cpp + asset.cpp +) + +target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) diff --git a/database/schema/asset.cpp b/database/schema/asset.cpp new file mode 100644 index 0000000..5d496e3 --- /dev/null +++ b/database/schema/asset.cpp @@ -0,0 +1,44 @@ +//SPDX-FileCopyrightText: 2024 Yury Gubich +//SPDX-License-Identifier: GPL-3.0-or-later + +#include "asset.h" + +DB::Asset::Asset (): + id(), + owner(), + currency(), + title(), + icon(), + archived() +{} + +DB::Asset::Asset (const std::vector& vec): + id(std::any_cast(vec[0])), + owner(std::any_cast(vec[1])), + currency(std::any_cast(vec[2])), + title(std::any_cast(vec[3])), + icon(std::any_cast(vec[4])), + archived(std::any_cast(vec[5])) +{} + +void DB::Asset::parse (const std::vector& vec) { + id = std::any_cast(vec[0]); + owner = std::any_cast(vec[1]); + currency = std::any_cast(vec[2]); + title = std::any_cast(vec[3]); + icon = std::any_cast(vec[4]); + archived = std::any_cast(vec[5]); +} + +nlohmann::json DB::Asset::toJSON () const { + nlohmann::json result = nlohmann::json::object(); + + result["id"] = id; + //result["owner"] = owner; + //result["currency"] = currency; + result["title"] = title; + result["icon"] = icon; + result["archived"] = archived; + + return result; +} diff --git a/database/schema/asset.h b/database/schema/asset.h new file mode 100644 index 0000000..005f308 --- /dev/null +++ b/database/schema/asset.h @@ -0,0 +1,33 @@ +//SPDX-FileCopyrightText: 2024 Yury Gubich +//SPDX-License-Identifier: GPL-3.0-or-later + +#pragma once + +#include +#include +#include +#include + +#include + +namespace DB { +class Asset { +public: + Asset (); + Asset (const std::vector& vec); + + void parse (const std::vector& vec); + nlohmann::json toJSON () const; + +public: + unsigned int id; + unsigned int owner; + unsigned int currency; + std::string title; + std::string icon; + // `color` INTEGER UNSIGNED DEFAULT 0, + // `balance` DECIMAL (20, 5) DEFAULT 0, + // `type` INTEGER UNSIGNED NOT NULL, + bool archived; +}; +} diff --git a/database/schema/session.cpp b/database/schema/session.cpp new file mode 100644 index 0000000..7bedaa5 --- /dev/null +++ b/database/schema/session.cpp @@ -0,0 +1,18 @@ +//SPDX-FileCopyrightText: 2024 Yury Gubich +//SPDX-License-Identifier: GPL-3.0-or-later + +#include "session.h" + +DB::Session::Session (): + id(), + owner(), + accessToken(), + renewToken() +{} + +DB::Session::Session (const std::vector& vec): + id(std::any_cast(vec[0])), + owner(std::any_cast(vec[1])), + accessToken(std::any_cast(vec[2])), + renewToken(std::any_cast(vec[3])) +{} diff --git a/database/schema/session.h b/database/schema/session.h new file mode 100644 index 0000000..55615c3 --- /dev/null +++ b/database/schema/session.h @@ -0,0 +1,23 @@ +//SPDX-FileCopyrightText: 2024 Yury Gubich +//SPDX-License-Identifier: GPL-3.0-or-later + +#pragma once + +#include +#include +#include +#include + +namespace DB { +class Session { +public: + Session (); + Session (const std::vector& vec); + +public: + unsigned int id; + unsigned int owner; + std::string accessToken; + std::string renewToken; +}; +} diff --git a/handler/CMakeLists.txt b/handler/CMakeLists.txt index 2449784..6791d53 100644 --- a/handler/CMakeLists.txt +++ b/handler/CMakeLists.txt @@ -9,6 +9,7 @@ set(HEADERS login.h poll.h listassets.h + addasset.h ) set(SOURCES @@ -19,6 +20,7 @@ set(SOURCES login.cpp poll.cpp listassets.cpp + addasset.cpp ) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) diff --git a/handler/addasset.cpp b/handler/addasset.cpp new file mode 100644 index 0000000..fd63886 --- /dev/null +++ b/handler/addasset.cpp @@ -0,0 +1,77 @@ +//SPDX-FileCopyrightText: 2024 Yury Gubich +//SPDX-License-Identifier: GPL-3.0-or-later + +#include "addasset.h" + +#include + +#include "server/server.h" +#include "server/session.h" +#include "database/exceptions.h" + +Handler::AddAsset::AddAsset (std::weak_ptr server): + Handler("addAsset", Request::Method::post), + server(server) +{} + +void Handler::AddAsset::handle (Request& request) { + std::string access = request.getAuthorizationToken(); + if (access.empty()) + return error(request, Response::Status::unauthorized); + + if (access.size() != 32) + return error(request, Response::Status::badRequest); + + std::shared_ptr srv = server.lock(); + if (!srv) + return error(request, Response::Status::internalError); + + std::map form = request.getForm(); + std::map::const_iterator itr = form.find("currency"); + if (itr == form.end()) + return error(request, Response::Status::badRequest); + + DB::Asset asset; + asset.currency = std::stoi(itr->second); + //TODO validate the currency + + itr = form.find("title"); + if (itr == form.end()) + return error(request, Response::Status::badRequest); + + asset.title = itr->second; + + itr = form.find("icon"); + if (itr == form.end()) + return error(request, Response::Status::badRequest); + + asset.icon = itr->second; + + try { + Session& session = srv->getSession(access); + + asset.owner = session.owner; + asset = srv->getDatabase()->addAsset(asset); + + nlohmann::json body = nlohmann::json::object(); + body["asset"] = asset.toJSON(); + + Response& res = request.createResponse(Response::Status::ok); + res.setBody(body); + res.send(); + + } catch (const DB::NoSession& e) { + return error(request, Response::Status::unauthorized); + } catch (const std::exception& e) { + std::cerr << "Exception on poll:\n\t" << e.what() << std::endl; + return error(request, Response::Status::internalError); + } catch (...) { + std::cerr << "Unknown exception on poll" << std::endl; + return error(request, Response::Status::internalError); + } +} + +void Handler::AddAsset::error (Request& request, Response::Status status) { + Response& res = request.createResponse(status); + res.send(); +} diff --git a/handler/addasset.h b/handler/addasset.h new file mode 100644 index 0000000..5ebcbeb --- /dev/null +++ b/handler/addasset.h @@ -0,0 +1,21 @@ +//SPDX-FileCopyrightText: 2024 Yury Gubich +//SPDX-License-Identifier: GPL-3.0-or-later + +#pragma once + +#include + +#include "handler.h" + +class Server; +namespace Handler { +class AddAsset : public Handler { +public: + AddAsset (std::weak_ptr server); + virtual void handle (Request& request) override; + + static void error (Request& request, Response::Status status); +private: + std::weak_ptr server; +}; +} diff --git a/handler/listassets.cpp b/handler/listassets.cpp index b8a7ab8..ed0af09 100644 --- a/handler/listassets.cpp +++ b/handler/listassets.cpp @@ -26,7 +26,18 @@ void Handler::ListAssets::handle (Request& request) { try { Session& session = srv->getSession(access); + std::vector assets = srv->getDatabase()->listAssets(session.owner); + nlohmann::json arr = nlohmann::json::array(); + for (const DB::Asset& asset : assets) + arr.push_back(asset.toJSON()); + + nlohmann::json body = nlohmann::json::object(); + body["assets"] = arr; + + Response& res = request.createResponse(Response::Status::ok); + res.setBody(body); + res.send(); } catch (const DB::NoSession& e) { return error(request, Response::Status::unauthorized); diff --git a/server/server.cpp b/server/server.cpp index 089f374..fb93424 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -12,6 +12,8 @@ #include "handler/register.h" #include "handler/login.h" #include "handler/poll.h" +#include "handler/listassets.h" +#include "handler/addasset.h" #include "taskmanager/route.h" @@ -32,7 +34,7 @@ constexpr uint32_t hashMemoryCost = 65536; constexpr uint8_t currentDbVesion = 1; -Server::Server(): +Server::Server (): std::enable_shared_from_this(), terminating(false), requestCount(0), @@ -59,14 +61,16 @@ Server::Server(): db->migrate(currentDbVesion); } -Server::~Server() {} +Server::~Server () {} -void Server::run(int socketDescriptor) { +void Server::run (int socketDescriptor) { router->addRoute(std::make_unique()); router->addRoute(std::make_unique()); router->addRoute(std::make_unique(shared_from_this())); router->addRoute(std::make_unique(shared_from_this())); router->addRoute(std::make_unique(shared_from_this())); + router->addRoute(std::make_unique(shared_from_this())); + router->addRoute(std::make_unique(shared_from_this())); taskManager->start(); scheduler->start(); @@ -82,7 +86,7 @@ void Server::run(int socketDescriptor) { } } -void Server::handleRequest(std::unique_ptr request) { +void Server::handleRequest (std::unique_ptr request) { ++requestCount; if (!serverName) { try { @@ -102,7 +106,7 @@ void Server::handleRequest(std::unique_ptr request) { taskManager->schedule(std::move(route)); } -std::string Server::generateRandomString(std::size_t length) { +std::string Server::generateRandomString (std::size_t length) { std::random_device rd; std::mt19937 gen(rd()); std::uniform_int_distribution distribution(0, std::strlen(randomChars) - 1); @@ -114,7 +118,7 @@ std::string Server::generateRandomString(std::size_t length) { return result; } -unsigned int Server::registerAccount(const std::string& login, const std::string& password) { +unsigned int Server::registerAccount (const std::string& login, const std::string& password) { std::size_t encSize = argon2_encodedlen( hashIterations, hashMemoryCost, hashParallel, saltSize, hashSize, Argon2_id @@ -138,7 +142,7 @@ unsigned int Server::registerAccount(const std::string& login, const std::string return db->registerAccount(login, hash); } -bool Server::validatePassword(const std::string& login, const std::string& password) { +bool Server::validatePassword (const std::string& login, const std::string& password) { DB::Resource db = pool->request(); std::string hash = db->getAccountHash(login); @@ -155,7 +159,7 @@ bool Server::validatePassword(const std::string& login, const std::string& passw } } -Session& Server::openSession(const std::string& login) { +Session& Server::openSession (const std::string& login) { std::string accessToken, renewToken; DB::Session s; s.id = 0; @@ -197,3 +201,8 @@ Session& Server::getSession (const std::string& accessToken) { ); return *session.get(); } + + +DB::Resource Server::getDatabase () { + return pool->request(); +} diff --git a/server/server.h b/server/server.h index 4402cfb..eaa9004 100644 --- a/server/server.h +++ b/server/server.h @@ -39,6 +39,7 @@ public: bool validatePassword(const std::string& login, const std::string& password); Session& openSession(const std::string& login); Session& getSession(const std::string& accessToken); + DB::Resource getDatabase(); private: void handleRequest(std::unique_ptr request);