1
0
forked from blue/pica
pica/database/mysql/mysql.cpp

268 lines
7.6 KiB
C++

// SPDX-FileCopyrightText: 2023 Yury Gubich <blue@macaw.me>
// SPDX-License-Identifier: GPL-3.0-or-later
#include "mysql.h"
#include <fstream>
#include <iostream>
#include "mysqld_error.h"
#include "statement.h"
#include "transaction.h"
// Schema-version bookkeeping: read and overwrite the 'version' row of `system`.
static constexpr const char* versionQuery = "SELECT value FROM system WHERE `key` = 'version'";
static constexpr const char* updateQuery = "UPDATE system SET `value` = ? WHERE `key` = 'version'";

// Account registration: insert the account, fetch its id, then bind it to a role by name.
static constexpr const char* registerQuery = "INSERT INTO accounts (`login`, `type`, `password`) VALUES (?, 1, ?)";
static constexpr const char* lastIdQuery = "SELECT LAST_INSERT_ID() AS id";
static constexpr const char* assignRoleQuery = "INSERT INTO roleBindings (`account`, `role`) SELECT ?, roles.id FROM roles WHERE roles.name = ?";

// Relative directory name for SQL build files (not referenced in this part of the file).
static const std::filesystem::path buildSQLPath{"database"};
// Deleter that lets std::unique_ptr own a MYSQL_RES result set and release it
// with mysql_free_result() when the pointer goes out of scope.
struct ResDeleter {
    void operator()(MYSQL_RES* result) const {
        mysql_free_result(result);
    }
};
MySQL::MySQL():
DBInterface(Type::mysql),
connection(),
login(),
password(),
database()
{
mysql_init(&connection);
}
// Closes the client handle that this object embeds by value.
MySQL::~MySQL() {
    MYSQL* const handle = &connection;
    mysql_close(handle);
}
void MySQL::connect(const std::string& path) {
if (state != State::disconnected)
return;
MYSQL* con = &connection;
MYSQL* res = mysql_real_connect(
con,
NULL,
login.c_str(),
password.c_str(),
database.empty() ? NULL : database.c_str(),
0,
path.c_str(),
0
);
if (res != con)
throw std::runtime_error(std::string("Error changing connecting: ") + mysql_error(con));
state = State::connected;
}
// Sets the login/password used for connecting. When already connected, the
// session is switched to the new user via mysql_change_user().
//
// The stored credentials are only updated once the server-side change has
// succeeded; otherwise a failed change would leave the object believing it
// uses the new credentials, and a retry with the same values would be a no-op.
//
// throws std::runtime_error if mysql_change_user() fails.
void MySQL::setCredentials(const std::string& login, const std::string& password) {
    if (MySQL::login == login && MySQL::password == password)
        return;

    if (state != State::disconnected) {
        MYSQL* con = &connection;
        int result = mysql_change_user(
            con,
            login.c_str(),
            password.c_str(),
            database.empty() ? nullptr : database.c_str()
        );
        if (result != 0)
            throw std::runtime_error(std::string("Error changing credetials: ") + mysql_error(con));
    }

    MySQL::login = login;
    MySQL::password = password;
}
// Sets the database to use. When already connected, it is selected on the
// live session with mysql_select_db().
//
// The stored name is only updated after a successful switch, so a failed
// switch does not leave the object out of sync with the session.
//
// throws std::runtime_error if mysql_select_db() fails.
void MySQL::setDatabase(const std::string& database) {
    if (MySQL::database == database)
        return;

    if (state != State::disconnected) {
        MYSQL* con = &connection;
        int result = mysql_select_db(con, database.c_str());
        if (result != 0)
            throw std::runtime_error(std::string("Error changing db: ") + mysql_error(con));
    }

    MySQL::database = database;
}
// Closes the current connection (if any) and returns to the disconnected
// state so that connect() can be called again.
void MySQL::disconnect() {
    if (state == State::disconnected)
        return;

    MYSQL* con = &connection;
    mysql_close(con);
    // A closed handle must not be reused, so re-initialise the embedded
    // struct; this lets a later connect() reuse it and keeps the destructor's
    // mysql_close() from operating on an already-closed handle.
    mysql_init(con);
    // Without this, connect() would keep returning early after a disconnect.
    state = State::disconnected;
}
// Executes every ';'-terminated statement of a SQL file located at
// sharedPath() / relativePath, echoing leading comments to stdout.
//
// NOTE: statements are split naively on ';' — a ';' inside a string literal
// or a comment will split a statement in two.
//
// throws std::runtime_error if the file is missing, cannot be opened or read,
// or if any statement fails (empty statements are skipped).
void MySQL::executeFile(const std::filesystem::path& relativePath) {
    MYSQL* con = &connection;
    std::filesystem::path path = sharedPath() / relativePath;
    if (!std::filesystem::exists(path))
        throw std::runtime_error("Error executing file "
            + std::filesystem::absolute(path).string()
            + ": file doesn't exist");

    std::cout << "Executing file " << path << std::endl;
    std::ifstream inputFile(path);
    // exists() does not guarantee readability (permissions, directories, ...).
    if (!inputFile.is_open())
        throw std::runtime_error("Error executing file "
            + std::filesystem::absolute(path).string()
            + ": file couldn't be opened");

    std::string query;
    while (std::getline(inputFile, query, ';')) {
        // Strip and print any comments preceding the statement text.
        std::optional<std::string> comment = getComment(query);
        while (comment) {
            std::cout << '\t' << comment.value() << std::endl;
            comment = getComment(query);
        }
        if (query.empty())
            continue;

        int result = mysql_query(con, query.c_str());
        if (result != 0) {
            unsigned int errcode = mysql_errno(con);
            if (errcode == ER_EMPTY_QUERY)
                continue;
            throw std::runtime_error("Error executing file " + path.string() + ": " + mysql_error(con));
        }
    }

    // getline() stopping because of an I/O error must not look like a clean EOF.
    if (inputFile.bad())
        throw std::runtime_error("Error executing file " + path.string() + ": read error");
}
// Reads the schema version from the `system` table.
//
// returns 0 when the `system` table does not exist yet or has no version row.
// throws std::runtime_error on query failure, a NULL stored value, or a value
//        that does not fit in uint8_t; std::invalid_argument (from std::stoi)
//        if the stored value is not numeric.
uint8_t MySQL::getVersion() {
    MYSQL* con = &connection;
    int result = mysql_query(con, versionQuery);
    if (result != 0) {
        unsigned int errcode = mysql_errno(con);
        if (errcode == ER_NO_SUCH_TABLE)
            return 0;   // fresh database: nothing has been migrated yet
        throw std::runtime_error(std::string("Error executing retreiving version: ") + mysql_error(con));
    }

    std::unique_ptr<MYSQL_RES, ResDeleter> res(mysql_store_result(con));
    if (!res)
        throw std::runtime_error(std::string("Querying version returned no result: ") + mysql_error(con));

    MYSQL_ROW row = mysql_fetch_row(res.get());
    if (!row)
        return 0;

    // SQL NULL arrives as a null pointer; building a std::string from it is UB.
    if (!row[0])
        throw std::runtime_error("Stored database version is NULL");

    int parsed = std::stoi(row[0]);
    uint8_t version = static_cast<uint8_t>(parsed);
    // Round-trip check rejects negatives and values above 255 instead of wrapping.
    if (version != parsed)
        throw std::runtime_error(std::string("Stored database version is out of range: ") + row[0]);

    return version;
}
void MySQL::setVersion(uint8_t version) {
std::string strVersion = std::to_string(version);
Statement statement(&connection, updateQuery);
statement.bind(strVersion.data(), MYSQL_TYPE_VAR_STRING);
statement.execute();
}
void MySQL::migrate(uint8_t targetVersion) {
uint8_t currentVersion = getVersion();
while (currentVersion < targetVersion) {
if (currentVersion == 255)
throw std::runtime_error("Maximum possible database version reached");
uint8_t nextVersion = currentVersion + 1;
std::string fileName = "migrations/m" + std::to_string(currentVersion) + ".sql";
std::cout << "Performing migration "
<< std::to_string(currentVersion)
<< " -> "
<< std::to_string(nextVersion)
<< std::endl;
executeFile(fileName);
setVersion(nextVersion);
currentVersion = nextVersion;
}
std::cout << "Database is now on actual version " << std::to_string(targetVersion) << std::endl;
}
// Removes one leading SQL comment from `string` and returns its text.
//
// Leading whitespace is trimmed first (via ltrim). Recognised forms:
//   -- line comment   : text up to the next '\n' (or end of string)
//   /* block comment */: text up to the closing "*/", which is also removed;
//                        an unterminated block comment consumes the rest.
// returns std::nullopt (leaving `string` trimmed) if no comment starts it.
std::optional<std::string> MySQL::getComment(std::string& string) {
    ltrim(string);
    if (string.length() < 2)
        return std::nullopt;

    if (string[0] == '-' && string[1] == '-') {
        string.erase(0, 2);
        std::string::size_type eol = string.find('\n');
        // presumably extract() cuts [0, eol) out of `string` and returns it
        return extract(string, 0, eol);
    }

    if (string[0] == '/' && string[1] == '*') {
        string.erase(0, 2);
        const std::string::size_type end = string.find("*/");
        if (end == std::string::npos) {
            // Unterminated block comment: everything left is comment text.
            std::string comment;
            comment.swap(string);
            return comment;
        }
        std::string comment = string.substr(0, end);
        string.erase(0, end + 2);   // drop the comment body and its "*/"
        return comment;
    }

    return std::nullopt;
}
unsigned int MySQL::registerAccount(const std::string& login, const std::string& hash) {
//TODO validate filed lengths!
MYSQL* con = &connection;
MySQL::Transaction txn(con);
Statement addAcc(con, registerQuery);
std::string l = login; //I hate copying just to please this horible API
std::string h = hash;
addAcc.bind(l.data(), MYSQL_TYPE_STRING);
addAcc.bind(h.data(), MYSQL_TYPE_STRING);
addAcc.execute();
unsigned int id = lastInsertedId();
static std::string defaultRole("default");
Statement addRole(con, assignRoleQuery);
addRole.bind(&id, MYSQL_TYPE_LONG, true);
addRole.bind(defaultRole.data(), MYSQL_TYPE_STRING);
addRole.execute();
txn.commit();
return id;
}
// Returns the AUTO_INCREMENT id produced by the most recent INSERT on this
// connection, as reported by SELECT LAST_INSERT_ID().
//
// throws std::runtime_error if the query fails, yields no row / a NULL value,
//        or the id does not fit in unsigned int.
unsigned int MySQL::lastInsertedId() {
    MYSQL* con = &connection;
    int result = mysql_query(con, lastIdQuery);
    if (result != 0)
        throw std::runtime_error(std::string("Error executing last inserted id: ") + mysql_error(con));

    std::unique_ptr<MYSQL_RES, ResDeleter> res(mysql_store_result(con));
    if (!res)
        throw std::runtime_error(std::string("Querying last inserted id returned no result: ") + mysql_error(con));

    MYSQL_ROW row = mysql_fetch_row(res.get());
    if (!row)
        throw std::runtime_error(std::string("Querying last inserted id returned no rows"));
    if (!row[0])
        throw std::runtime_error(std::string("Querying last inserted id returned NULL"));

    // std::stoi would throw for ids above INT_MAX even though the return type
    // is unsigned; parse as unsigned long and verify it round-trips.
    unsigned long parsed = std::stoul(row[0]);
    unsigned int id = static_cast<unsigned int>(parsed);
    if (id != parsed)
        throw std::runtime_error(std::string("Last inserted id is out of range: ") + row[0]);

    return id;
}