1
0
forked from blue/pica

scheduler canceling, sessiion query, didn't test yet!

This commit is contained in:
Blue 2024-01-03 19:20:01 -03:00
parent 544db92b6e
commit 5d765958e5
Signed by untrusted user: blue
GPG Key ID: 9B203B252A63EE38
15 changed files with 166 additions and 21 deletions

View File

@ -18,3 +18,7 @@ DB::EmptyResult::EmptyResult(const std::string& text):
DB::NoLogin::NoLogin(const std::string& text): DB::NoLogin::NoLogin(const std::string& text):
EmptyResult(text) EmptyResult(text)
{} {}
DB::NoSession::NoSession (const std::string& text):
EmptyResult(text)
{}

View File

@ -26,4 +26,9 @@ class NoLogin : public EmptyResult {
public: public:
explicit NoLogin(const std::string& text); explicit NoLogin(const std::string& text);
}; };
class NoSession : public EmptyResult {
public:
explicit NoSession(const std::string& text);
};
} }

View File

@ -9,6 +9,13 @@
#include <stdint.h> #include <stdint.h>
namespace DB { namespace DB {
struct Session {
unsigned int id;
unsigned int owner;
std::string accessToken;
std::string renewToken;
};
class Interface { class Interface {
public: public:
enum class Type { enum class Type {
@ -40,6 +47,7 @@ public:
virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0; virtual unsigned int registerAccount(const std::string& login, const std::string& hash) = 0;
virtual std::string getAccountHash(const std::string& login) = 0; virtual std::string getAccountHash(const std::string& login) = 0;
virtual unsigned int createSession(const std::string& login, const std::string& access, const std::string& renew) = 0; virtual unsigned int createSession(const std::string& login, const std::string& access, const std::string& renew) = 0;
virtual Session findSession(const std::string& accessToken) = 0;
protected: protected:
Interface(Type type); Interface(Type type);

View File

@ -20,6 +20,7 @@ constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `r
constexpr const char* selectHash = "SELECT password FROM accounts where login = ?"; constexpr const char* selectHash = "SELECT password FROM accounts where login = ?";
constexpr const char* createSessionQuery = "INSERT INTO sessions (`owner`, `access`, `renew`, `persist`, `device`)" constexpr const char* createSessionQuery = "INSERT INTO sessions (`owner`, `access`, `renew`, `persist`, `device`)"
" SELECT accounts.id, ?, ?, true, ? FROM accounts WHERE accounts.login = ?"; " SELECT accounts.id, ?, ?, true, ? FROM accounts WHERE accounts.login = ?";
constexpr const char* selectSession = "SELECT id, owner, renew FROM sessions where access = ?";
static const std::filesystem::path buildSQLPath = "database"; static const std::filesystem::path buildSQLPath = "database";
@ -298,6 +299,24 @@ unsigned int DB::MySQL::lastInsertedId() {
else else
throw std::runtime_error(std::string("Querying last inserted id returned no rows")); throw std::runtime_error(std::string("Querying last inserted id returned no rows"));
} }
DB::Session DB::MySQL::findSession(const std::string& accessToken) {
std::string a = accessToken;
MYSQL* con = &connection;
Statement session(con, selectSession);
session.bind(a.data(), MYSQL_TYPE_STRING);
std::vector<std::vector<std::any>> result = session.fetchResult();
if (result.empty())
throw NoSession("Couldn't find session with token " + a);
DB::Session res;
res.id = std::any_cast<unsigned int>(result[0][0]);
res.owner = std::any_cast<unsigned int>(result[0][1]);
res.renewToken = std::any_cast<const std::string&>(result[0][2]);
res.accessToken = a;
return res;
}

View File

@ -34,6 +34,7 @@ public:
unsigned int registerAccount(const std::string& login, const std::string& hash) override; unsigned int registerAccount(const std::string& login, const std::string& hash) override;
std::string getAccountHash(const std::string& login) override; std::string getAccountHash(const std::string& login) override;
unsigned int createSession(const std::string& login, const std::string& access, const std::string& renew) override; unsigned int createSession(const std::string& login, const std::string& access, const std::string& renew) override;
Session findSession(const std::string& accessToken) override;
private: private:
void executeFile(const std::filesystem::path& relativePath); void executeFile(const std::filesystem::path& relativePath);

View File

@ -7,6 +7,7 @@ set(HEADERS
env.h env.h
register.h register.h
login.h login.h
poll.h
) )
set(SOURCES set(SOURCES
@ -15,6 +16,7 @@ set(SOURCES
env.cpp env.cpp
register.cpp register.cpp
login.cpp login.cpp
poll.cpp
) )
target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES})

View File

@ -6,6 +6,7 @@
#include "response/response.h" #include "response/response.h"
#include "server/server.h" #include "server/server.h"
#include "request/redirect.h" #include "request/redirect.h"
#include "database/exceptions.h"
Handler::Poll::Poll (Server* server): Handler::Poll::Poll (Server* server):
Handler("login", Request::Method::get), Handler("login", Request::Method::get),
@ -25,6 +26,8 @@ void Handler::Poll::handle (Request& request) {
throw Redirect(&session); throw Redirect(&session);
} catch (const Redirect& r) { } catch (const Redirect& r) {
throw r; throw r;
} catch (const DB::NoSession& e) {
return error(request, Result::tokenProblem, Response::Status::unauthorized);
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "Exception on poll:\n\t" << e.what() << std::endl; std::cerr << "Exception on poll:\n\t" << e.what() << std::endl;
return error(request, Result::unknownError, Response::Status::internalError); return error(request, Result::unknownError, Response::Status::internalError);

View File

@ -192,6 +192,22 @@ Session& Server::openSession(const std::string& login) {
if (sessionId == 0) if (sessionId == 0)
throw std::runtime_error("Couldn't create session, ran out of attempts"); throw std::runtime_error("Couldn't create session, ran out of attempts");
std::unique_ptr<Session>& session = sessions[accessToken] = std::make_unique<Session>(sessionId, accessToken, renewToken); std::unique_ptr<Session>& session = sessions[accessToken] = std::make_unique<Session>(scheduler, sessionId, accessToken, renewToken);
return *session.get();
}
Session& Server::getSession (const std::string& accessToken) {
Sessions::const_iterator itr = sessions.find(accessToken);
if (itr != sessions.end())
return *(itr->second);
DB::Resource db = pool->request();
DB::Session s = db->findSession(accessToken);
std::unique_ptr<Session>& session = sessions[accessToken] = std::make_unique<Session>(
scheduler,
s.id,
s.accessToken,
s.renewToken
);
return *session.get(); return *session.get();
} }

View File

@ -15,9 +15,17 @@ Session::Session(
id(id), id(id),
access(access), access(access),
renew(renew), renew(renew),
polling(nullptr) polling(nullptr),
timeoutId(TM::Scheduler::none)
{} {}
Session::~Session () {
if (timeoutId != TM::Scheduler::none) {
if (std::shared_ptr<TM::Scheduler> sch = scheduler.lock())
sch->cancel(timeoutId);
}
}
std::string Session::getAccessToken() const { std::string Session::getAccessToken() const {
return access; return access;
} }
@ -27,18 +35,30 @@ std::string Session::getRenewToken() const {
} }
void Session::accept(std::unique_ptr<Request> request) { void Session::accept(std::unique_ptr<Request> request) {
std::shared_ptr<TM::Scheduler> sch = scheduler.lock();
if (polling) { if (polling) {
Handler::Poll::error(*request.get(), Handler::Poll::Result::replace, Response::Status::ok); Handler::Poll::error(*request.get(), Handler::Poll::Result::replace, Response::Status::ok);
if (timeoutId != TM::Scheduler::none) {
if (sch)
sch->cancel(timeoutId);
timeoutId = TM::Scheduler::none;
}
//TODO unschedule //TODO unschedule
} }
std::shared_ptr<TM::Scheduler> sch = scheduler.lock();
if (!sch) { if (!sch) {
std::cerr << "Was unable to schedule polling timeout, replying with an error" << std::endl; std::cerr << "Was unable to schedule polling timeout, replying with an error" << std::endl;
Handler::Poll::error(*request.get(), Handler::Poll::Result::unknownError, Response::Status::internalError); Handler::Poll::error(*request.get(), Handler::Poll::Result::unknownError, Response::Status::internalError);
return; return;
} }
sch->schedule(std::bind(&Session::onTimeout, this), TM::Scheduler::Delay(5000)); timeoutId = sch->schedule(std::bind(&Session::onTimeout, this), TM::Scheduler::Delay(5000));
polling = std::move(request); polling = std::move(request);
} }
void Session::onTimeout () {
timeoutId = TM::Scheduler::none;
Handler::Poll::error(*polling.get(), Handler::Poll::Result::timeout, Response::Status::ok);
polling.reset();
}

View File

@ -18,6 +18,7 @@ public:
); );
Session(const Session&) = delete; Session(const Session&) = delete;
Session(Session&& other); Session(Session&& other);
~Session();
Session& operator = (const Session&) = delete; Session& operator = (const Session&) = delete;
Session& operator = (Session&& other); Session& operator = (Session&& other);
@ -34,4 +35,5 @@ private:
std::string access; std::string access;
std::string renew; std::string renew;
std::unique_ptr<Request> polling; std::unique_ptr<Request> polling;
TM::Record::ID timeoutId;
}; };

View File

@ -7,6 +7,7 @@ set(HEADERS
route.h route.h
scheduler.h scheduler.h
function.h function.h
record.h
) )
set(SOURCES set(SOURCES
@ -15,6 +16,7 @@ set(SOURCES
route.cpp route.cpp
scheduler.cpp scheduler.cpp
function.cpp function.cpp
record.cpp
) )
target_sources(${PROJECT_NAME} PRIVATE ${SOURCES}) target_sources(${PROJECT_NAME} PRIVATE ${SOURCES})

18
taskmanager/record.cpp Normal file
View File

@ -0,0 +1,18 @@
//SPDX-FileCopyrightText: 2024 Yury Gubich <blue@macaw.me>
//SPDX-License-Identifier: GPL-3.0-or-later
#include "record.h"
TM::Record::Record (ID id, const Task& task, Time time):
id(id),
task(task),
time(time)
{}
bool TM::Record::operator < (const Record& other) const {
return time < other.time;
}
bool TM::Record::operator > (const Record& other) const {
return time > other.time;
}

26
taskmanager/record.h Normal file
View File

@ -0,0 +1,26 @@
//SPDX-FileCopyrightText: 2024 Yury Gubich <blue@macaw.me>
//SPDX-License-Identifier: GPL-3.0-or-later
#pragma once
#include <cstdint>
#include <chrono>
#include <functional>
namespace TM {
class Record {
public:
using Time = std::chrono::time_point<std::chrono::steady_clock>;
using Task = std::function<void()>;
using ID = uint64_t;
Record(ID id, const Task& task, Time time);
ID id;
Task task;
Time time;
bool operator > (const Record& other) const;
bool operator < (const Record& other) const;
};
}

View File

@ -3,13 +3,17 @@
#include "scheduler.h" #include "scheduler.h"
const TM::Record::ID TM::Scheduler::none = 0;
TM::Scheduler::Scheduler (std::weak_ptr<Manager> manager): TM::Scheduler::Scheduler (std::weak_ptr<Manager> manager):
queue(), queue(),
scheduled(),
manager(manager), manager(manager),
mutex(), mutex(),
cond(), cond(),
thread(nullptr), thread(nullptr),
running(false) running(false),
idCounter(TM::Scheduler::none)
{} {}
TM::Scheduler::~Scheduler () { TM::Scheduler::~Scheduler () {
@ -50,28 +54,39 @@ void TM::Scheduler::loop () {
Time currentTime = std::chrono::steady_clock::now(); Time currentTime = std::chrono::steady_clock::now();
while (!queue.empty()) { while (!queue.empty()) {
Time nextScheduledTime = queue.top().first; Time nextScheduledTime = queue.top().time;
if (nextScheduledTime > currentTime) { if (nextScheduledTime > currentTime) {
cond.wait_until(lock, nextScheduledTime); cond.wait_until(lock, nextScheduledTime);
break; break;
} }
Record task = queue.pop(); Record record = queue.pop();
std::size_t count = scheduled.erase(record.id);
if (count == 0) //it means this record has been cancelled, no need to execute it
continue;
lock.unlock(); lock.unlock();
std::shared_ptr<Manager> mngr = manager.lock(); std::shared_ptr<Manager> mngr = manager.lock();
if (mngr) if (mngr)
mngr->schedule(std::move(task.second)); mngr->schedule(std::make_unique<Function>(record.task));
lock.lock(); lock.lock();
} }
} }
} }
void TM::Scheduler::schedule (const std::function<void()>& task, Delay delay) { TM::Record::ID TM::Scheduler::schedule (const Task& task, Delay delay) {
std::unique_lock lock(mutex); std::unique_lock lock(mutex);
Time time = std::chrono::steady_clock::now() + delay; Time time = std::chrono::steady_clock::now() + delay;
queue.emplace(time, std::make_unique<Function>(task)); queue.emplace(++idCounter, task, time);
scheduled.emplace(idCounter);
lock.unlock(); lock.unlock();
cond.notify_one(); cond.notify_one();
return idCounter;
} }
bool TM::Scheduler::cancel (Record::ID id) {
return scheduled.erase(id) != 0; //not to mess with the queue, here we just mark it as not scheduled
} //and when the time comes it will be just discarded

View File

@ -6,43 +6,47 @@
#include <memory> #include <memory>
#include <thread> #include <thread>
#include <chrono> #include <chrono>
#include <functional>
#include <mutex> #include <mutex>
#include <condition_variable> #include <condition_variable>
#include <set>
#include "manager.h" #include "manager.h"
#include "function.h" #include "function.h"
#include "record.h"
#include "utils/helpers.h" #include "utils/helpers.h"
namespace TM { namespace TM {
class Scheduler { class Scheduler {
public: public:
using Delay = std::chrono::milliseconds; using Delay = std::chrono::milliseconds;
using Time = Record::Time;
using Task = Record::Task;
Scheduler (std::weak_ptr<Manager> manager); Scheduler (std::weak_ptr<Manager> manager);
~Scheduler (); ~Scheduler ();
void start(); void start ();
void stop(); void stop ();
void schedule(const std::function<void()>& task, Delay delay); Record::ID schedule (const Task& task, Delay delay);
bool cancel (Record::ID id);
static const Record::ID none;
private: private:
void loop(); void loop ();
private: private:
using Task = std::unique_ptr<Function>;
using Time = std::chrono::time_point<std::chrono::steady_clock>;
using Record = std::pair<Time, Task>;
PriorityQueue< PriorityQueue<
Record, Record,
std::vector<Record>, std::vector<Record>,
FirstGreater<Record> std::greater<>
> queue; > queue;
std::set<Record::ID> scheduled;
std::weak_ptr<Manager> manager; std::weak_ptr<Manager> manager;
std::mutex mutex; std::mutex mutex;
std::condition_variable cond; std::condition_variable cond;
std::unique_ptr<std::thread> thread; std::unique_ptr<std::thread> thread;
bool running; bool running;
Record::ID idCounter;
}; };
} }