diff --git a/database/exceptions.cpp b/database/exceptions.cpp index 5ba8058..b48ef51 100644 --- a/database/exceptions.cpp +++ b/database/exceptions.cpp @@ -18,3 +18,7 @@ DB::EmptyResult::EmptyResult(const std::string& text): DB::NoLogin::NoLogin(const std::string& text): EmptyResult(text) {} + +DB::NoSession::NoSession (const std::string& text): + EmptyResult(text) +{} diff --git a/database/exceptions.h b/database/exceptions.h index 19da281..81b4f7b 100644 --- a/database/exceptions.h +++ b/database/exceptions.h @@ -26,4 +26,9 @@ class NoLogin : public EmptyResult { public: explicit NoLogin(const std::string& text); }; + +class NoSession : public EmptyResult { +public: + explicit NoSession(const std::string& text); +}; } diff --git a/database/interface.h b/database/interface.h index 7b6e204..4d0e651 100644 --- a/database/interface.h +++ b/database/interface.h @@ -9,6 +9,13 @@ #include namespace DB { +struct Session { + unsigned int id; + unsigned int owner; + std::string accessToken; + std::string renewToken; +}; + class Interface { public: enum class Type { @@ -40,6 +47,7 @@ public: virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0; virtual std::string getAccountHash(const std::string& login) = 0; virtual unsigned int createSession(const std::string& login, const std::string& access, const std::string& renew) = 0; + virtual Session findSession(const std::string& accessToken) = 0; protected: Interface(Type type); diff --git a/database/mysql/mysql.cpp b/database/mysql/mysql.cpp index 151fa9b..33f2f75 100644 --- a/database/mysql/mysql.cpp +++ b/database/mysql/mysql.cpp @@ -20,6 +20,7 @@ constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `r constexpr const char* selectHash = "SELECT password FROM accounts where login = ?"; constexpr const char* createSessionQuery = "INSERT INTO sessions (`owner`, `access`, `renew`, `persist`, `device`)" " SELECT accounts.id, ?, ?, true, ? FROM accounts WHERE accounts.login = ?"; +constexpr const char* selectSession = "SELECT id, owner, renew FROM sessions where access = ?"; static const std::filesystem::path buildSQLPath = "database"; @@ -298,6 +299,24 @@ unsigned int DB::MySQL::lastInsertedId() { else throw std::runtime_error(std::string("Querying last inserted id returned no rows")); } - +DB::Session DB::MySQL::findSession(const std::string& accessToken) { + std::string a = accessToken; + MYSQL* con = &connection; + + Statement session(con, selectSession); + session.bind(a.data(), MYSQL_TYPE_STRING); + + std::vector> result = session.fetchResult(); + if (result.empty()) + throw NoSession("Couldn't find session with token " + a); + + DB::Session res; + res.id = std::any_cast(result[0][0]); + res.owner = std::any_cast(result[0][1]); + res.renewToken = std::any_cast(result[0][2]); + res.accessToken = a; + + return res; +} diff --git a/database/mysql/mysql.h b/database/mysql/mysql.h index 7bd9775..517ed1c 100644 --- a/database/mysql/mysql.h +++ b/database/mysql/mysql.h @@ -34,6 +34,7 @@ public: unsigned int registerAccount(const std::string& login, const std::string& hash) override; std::string getAccountHash(const std::string& login) override; unsigned int createSession(const std::string& login, const std::string& access, const std::string& renew) override; + Session findSession(const std::string& accessToken) override; private: void executeFile(const std::filesystem::path& relativePath); diff --git a/handler/CMakeLists.txt b/handler/CMakeLists.txt index 536d1d0..e6c3995 100644 --- a/handler/CMakeLists.txt +++ b/handler/CMakeLists.txt @@ -7,6 +7,7 @@ set(HEADERS env.h register.h login.h + poll.h ) set(SOURCES @@ -15,6 +16,7 @@ set(SOURCES env.cpp register.cpp login.cpp + poll.cpp ) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) diff --git a/handler/poll.cpp b/handler/poll.cpp index c0e4d82..0ffd138 100644 --- a/handler/poll.cpp +++ b/handler/poll.cpp @@ -6,6 +6,7 @@ #include "response/response.h" #include "server/server.h" #include "request/redirect.h" +#include "database/exceptions.h" Handler::Poll::Poll (Server* server): Handler("login", Request::Method::get), @@ -25,6 +26,8 @@ void Handler::Poll::handle (Request& request) { throw Redirect(&session); } catch (const Redirect& r) { throw r; + } catch (const DB::NoSession& e) { + return error(request, Result::tokenProblem, Response::Status::unauthorized); } catch (const std::exception& e) { std::cerr << "Exception on poll:\n\t" << e.what() << std::endl; return error(request, Result::unknownError, Response::Status::internalError); diff --git a/server/server.cpp b/server/server.cpp index c4a2f60..bc258ed 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -192,6 +192,22 @@ Session& Server::openSession(const std::string& login) { if (sessionId == 0) throw std::runtime_error("Couldn't create session, ran out of attempts"); - std::unique_ptr& session = sessions[accessToken] = std::make_unique(sessionId, accessToken, renewToken); + std::unique_ptr& session = sessions[accessToken] = std::make_unique(scheduler, sessionId, accessToken, renewToken); + return *session.get(); +} + +Session& Server::getSession (const std::string& accessToken) { + Sessions::const_iterator itr = sessions.find(accessToken); + if (itr != sessions.end()) + return *(itr->second); + + DB::Resource db = pool->request(); + DB::Session s = db->findSession(accessToken); + std::unique_ptr& session = sessions[accessToken] = std::make_unique( + scheduler, + s.id, + s.accessToken, + s.renewToken + ); return *session.get(); } diff --git a/server/session.cpp b/server/session.cpp index efafd89..6a1aeda 100644 --- a/server/session.cpp +++ b/server/session.cpp @@ -15,9 +15,17 @@ Session::Session( id(id), access(access), renew(renew), - polling(nullptr) + polling(nullptr), + timeoutId(TM::Scheduler::none) {} +Session::~Session () { + if (timeoutId != TM::Scheduler::none) { + if (std::shared_ptr sch = scheduler.lock()) + sch->cancel(timeoutId); + } +} + std::string Session::getAccessToken() const { return access; } @@ -27,18 +35,30 @@ std::string Session::getRenewToken() const { } void Session::accept(std::unique_ptr request) { + std::shared_ptr sch = scheduler.lock(); if (polling) { Handler::Poll::error(*request.get(), Handler::Poll::Result::replace, Response::Status::ok); + if (timeoutId != TM::Scheduler::none) { + if (sch) + sch->cancel(timeoutId); + + timeoutId = TM::Scheduler::none; + } //TODO unschedule } - std::shared_ptr sch = scheduler.lock(); if (!sch) { std::cerr << "Was unable to schedule polling timeout, replying with an error" << std::endl; Handler::Poll::error(*request.get(), Handler::Poll::Result::unknownError, Response::Status::internalError); return; } - sch->schedule(std::bind(&Session::onTimeout, this), TM::Scheduler::Delay(5000)); + timeoutId = sch->schedule(std::bind(&Session::onTimeout, this), TM::Scheduler::Delay(5000)); polling = std::move(request); } + +void Session::onTimeout () { + timeoutId = TM::Scheduler::none; + Handler::Poll::error(*polling.get(), Handler::Poll::Result::timeout, Response::Status::ok); + polling.reset(); +} diff --git a/server/session.h b/server/session.h index 14c43e9..7581352 100644 --- a/server/session.h +++ b/server/session.h @@ -18,6 +18,7 @@ public: ); Session(const Session&) = delete; Session(Session&& other); + ~Session(); Session& operator = (const Session&) = delete; Session& operator = (Session&& other); @@ -34,4 +35,5 @@ private: std::string access; std::string renew; std::unique_ptr polling; + TM::Record::ID timeoutId; }; diff --git a/taskmanager/CMakeLists.txt b/taskmanager/CMakeLists.txt index f45b98f..70a8210 100644 --- a/taskmanager/CMakeLists.txt +++ b/taskmanager/CMakeLists.txt @@ -7,6 +7,7 @@ set(HEADERS route.h scheduler.h function.h + record.h ) set(SOURCES @@ -15,6 +16,7 @@ set(SOURCES route.cpp scheduler.cpp function.cpp + record.cpp ) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) diff --git a/taskmanager/record.cpp b/taskmanager/record.cpp new file mode 100644 index 0000000..77c8850 --- /dev/null +++ b/taskmanager/record.cpp @@ -0,0 +1,18 @@ +//SPDX-FileCopyrightText: 2024 Yury Gubich +//SPDX-License-Identifier: GPL-3.0-or-later + +#include "record.h" + +TM::Record::Record (ID id, const Task& task, Time time): + id(id), + task(task), + time(time) +{} + +bool TM::Record::operator < (const Record& other) const { + return time < other.time; +} + +bool TM::Record::operator > (const Record& other) const { + return time > other.time; +} diff --git a/taskmanager/record.h b/taskmanager/record.h new file mode 100644 index 0000000..fe2b7c5 --- /dev/null +++ b/taskmanager/record.h @@ -0,0 +1,26 @@ +//SPDX-FileCopyrightText: 2024 Yury Gubich +//SPDX-License-Identifier: GPL-3.0-or-later + +#pragma once + +#include +#include +#include + +namespace TM { +class Record { +public: + using Time = std::chrono::time_point; + using Task = std::function; + using ID = uint64_t; + + Record(ID id, const Task& task, Time time); + + ID id; + Task task; + Time time; + + bool operator > (const Record& other) const; + bool operator < (const Record& other) const; +}; +} diff --git a/taskmanager/scheduler.cpp b/taskmanager/scheduler.cpp index db781e3..fcae841 100644 --- a/taskmanager/scheduler.cpp +++ b/taskmanager/scheduler.cpp @@ -3,13 +3,17 @@ #include "scheduler.h" +const TM::Record::ID TM::Scheduler::none = 0; + TM::Scheduler::Scheduler (std::weak_ptr manager): queue(), + scheduled(), manager(manager), mutex(), cond(), thread(nullptr), - running(false) + running(false), + idCounter(TM::Scheduler::none) {} TM::Scheduler::~Scheduler () { @@ -50,28 +54,39 @@ void TM::Scheduler::loop () { Time currentTime = std::chrono::steady_clock::now(); while (!queue.empty()) { - Time nextScheduledTime = queue.top().first; + Time nextScheduledTime = queue.top().time; if (nextScheduledTime > currentTime) { cond.wait_until(lock, nextScheduledTime); break; } - Record task = queue.pop(); + Record record = queue.pop(); + std::size_t count = scheduled.erase(record.id); + if (count == 0) //it means this record has been cancelled, no need to execute it + continue; + lock.unlock(); std::shared_ptr mngr = manager.lock(); if (mngr) - mngr->schedule(std::move(task.second)); + mngr->schedule(std::make_unique(record.task)); lock.lock(); } } } -void TM::Scheduler::schedule (const std::function& task, Delay delay) { +TM::Record::ID TM::Scheduler::schedule (const Task& task, Delay delay) { std::unique_lock lock(mutex); Time time = std::chrono::steady_clock::now() + delay; - queue.emplace(time, std::make_unique(task)); + queue.emplace(++idCounter, task, time); + scheduled.emplace(idCounter); lock.unlock(); cond.notify_one(); + + return idCounter; } + +bool TM::Scheduler::cancel (Record::ID id) { + return scheduled.erase(id) != 0; //not to mess with the queue, here we just mark it as not scheduled +} //and when the time comes it will be just discarded diff --git a/taskmanager/scheduler.h b/taskmanager/scheduler.h index a952858..6b3e609 100644 --- a/taskmanager/scheduler.h +++ b/taskmanager/scheduler.h @@ -6,43 +6,47 @@ #include #include #include -#include #include #include +#include #include "manager.h" #include "function.h" +#include "record.h" #include "utils/helpers.h" namespace TM { class Scheduler { public: using Delay = std::chrono::milliseconds; + using Time = Record::Time; + using Task = Record::Task; Scheduler (std::weak_ptr manager); ~Scheduler (); - void start(); - void stop(); - void schedule(const std::function& task, Delay delay); + void start (); + void stop (); + Record::ID schedule (const Task& task, Delay delay); + bool cancel (Record::ID id); + + static const Record::ID none; private: - void loop(); + void loop (); private: - using Task = std::unique_ptr; - using Time = std::chrono::time_point; - using Record = std::pair; - PriorityQueue< Record, std::vector, - FirstGreater + std::greater<> > queue; + std::set scheduled; std::weak_ptr manager; std::mutex mutex; std::condition_variable cond; std::unique_ptr thread; bool running; + Record::ID idCounter; }; }