diff --git a/server/rieterd.toml b/server/rieterd.toml index 781a055..9cc56bf 100644 --- a/server/rieterd.toml +++ b/server/rieterd.toml @@ -1,11 +1,17 @@ api_key = "test" -port = 8000 -log_level = "tower_http=debug,rieterd=debug" +pkg_workers = 2 +log_level = "rieterd=debug" [fs] -type = "locl" +type = "local" data_dir = "./data" [db] type = "sqlite" db_dir = "./data" +# [db] +# type = "postgres" +# host = "localhost" +# db = "rieter" +# user = "rieter" +# password = "rieter" diff --git a/server/src/cli.rs b/server/src/cli.rs index 1ceaf27..73dc9f2 100644 --- a/server/src/cli.rs +++ b/server/src/cli.rs @@ -1,4 +1,4 @@ -use crate::{distro::MetaDistroMgr, Config, Global}; +use crate::{distro::MetaDistroMgr, Config, FsConfig, Global}; use std::{io, path::PathBuf, sync::Arc}; @@ -12,13 +12,6 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[derive(Parser)] #[command(author, version, about, long_about = None)] pub struct Cli { - /// Directory where repository metadata & SQLite database is stored - #[arg(env = "RIETER_DATA_DIR")] - pub data_dir: PathBuf, - /// API key to authenticate private routes with - #[arg(env = "RIETER_API_KEY")] - pub api_key: String, - #[arg( short, long, @@ -26,89 +19,54 @@ pub struct Cli { default_value = "./rieterd.toml" )] pub config_file: PathBuf, - - /// Database connection URL; either sqlite:// or postgres://. Defaults to rieter.sqlite in the - /// data directory - #[arg(short, long, env = "RIETER_DATABASE_URL")] - pub database_url: Option, - /// Port the server will listen on - #[arg( - short, - long, - value_name = "PORT", - default_value_t = 8000, - env = "RIETER_PORT" - )] - pub port: u16, - /// Log levels for the tracing - #[arg( - long, - value_name = "LOG_LEVEL", - default_value = "tower_http=debug,rieterd=debug,sea_orm=debug", - env = "RIETER_LOG" - )] - pub log: String, } impl Cli { - pub fn init_tracing(&self) { + pub async fn run(&self) -> crate::Result<()> { + let config: Config = Config::figment(&self.config_file) + .extract() + .inspect_err(|e| tracing::error!("{}", e))?; + tracing_subscriber::registry() - .with(tracing_subscriber::EnvFilter::new(self.log.clone())) + .with(tracing_subscriber::EnvFilter::new(config.log_level.clone())) .with(tracing_subscriber::fmt::layer()) .init(); - } - pub async fn run(&self) -> crate::Result<()> { - self.init_tracing(); + tracing::info!("Connecting to database"); + let db = crate::db::connect(&config.db).await?; - //tracing::debug!("{:?}", &self.config_file); - //let new_config: crate::config::Config = crate::config::Config::figment(&self.config_file).extract().inspect_err( - // |e| tracing::error!("{}", e) - //)?; - //tracing::debug!("{:?}", new_config); - - let db_url = if let Some(url) = &self.database_url { - url.clone() - } else { - format!( - "sqlite://{}?mode=rwc", - self.data_dir.join("rieter.sqlite").to_string_lossy() - ) - }; - - debug!("Connecting to database with URL {}", db_url); - - let mut options = sea_orm::ConnectOptions::new(db_url); - options.max_connections(16); - - let db = sea_orm::Database::connect(options).await?; crate::db::Migrator::up(&db, None).await?; - debug!("Successfully applied migrations"); - - let config = Config { - data_dir: self.data_dir.clone(), + let mgr = match &config.fs { + FsConfig::Local { data_dir } => { + crate::repo::RepoMgr::new(data_dir.join("repos"), db.clone()).await? + } }; - let mgr = - Arc::new(crate::repo::RepoMgr::new(&self.data_dir.join("repos"), db.clone()).await?); + let mgr = Arc::new(mgr); - for _ in 0..1 { + for _ in 0..config.pkg_workers { let clone = Arc::clone(&mgr); tokio::spawn(async move { clone.pkg_parse_task().await }); } - let global = Global { config, mgr, db }; + let global = Global { + config: config.clone(), + mgr, + db, + }; // build our application with a single route let app = Router::new() .nest("/api", crate::api::router()) - .merge(crate::repo::router(&self.api_key)) + .merge(crate::repo::router(&config.api_key)) .with_state(global) .layer(TraceLayer::new_for_http()); - let domain: String = format!("0.0.0.0:{}", self.port).parse().unwrap(); + let domain: String = format!("{}:{}", config.domain, config.port) + .parse() + .unwrap(); let listener = tokio::net::TcpListener::bind(domain).await?; // run it with hyper on localhost:3000 Ok(axum::serve(listener, app.into_make_service()) diff --git a/server/src/config.rs b/server/src/config.rs index a639362..e165fdc 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -6,34 +6,49 @@ use figment::{ }; use serde::Deserialize; -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Clone)] #[serde(rename_all = "lowercase")] #[serde(tag = "type")] pub enum FsConfig { Local { data_dir: PathBuf }, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Clone)] #[serde(rename_all = "lowercase")] #[serde(tag = "type")] pub enum DbConfig { Sqlite { db_dir: PathBuf, + #[serde(default = "default_db_sqlite_max_connections")] + max_connections: u32, }, Postgres { host: String, + #[serde(default = "default_db_postgres_port")] + port: u16, user: String, password: String, + db: String, + #[serde(default)] + schema: String, + #[serde(default = "default_db_postgres_max_connections")] + max_connections: u32, }, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Clone)] pub struct Config { - api_key: String, - port: u16, - log_level: String, - fs: FsConfig, - db: DbConfig, + pub api_key: String, + #[serde(default = "default_domain")] + pub domain: String, + #[serde(default = "default_port")] + pub port: u16, + #[serde(default = "default_log_level")] + pub log_level: String, + pub fs: FsConfig, + pub db: DbConfig, + #[serde(default = "default_pkg_workers")] + pub pkg_workers: u32, } impl Config { @@ -43,3 +58,31 @@ impl Config { .merge(Env::prefixed("RIETER_")) } } + +fn default_domain() -> String { + String::from("0.0.0.0") +} + +fn default_port() -> u16 { + 8000 +} + +fn default_log_level() -> String { + String::from("tower_http=debug,rieterd=debug,sea_orm=debug") +} + +fn default_db_sqlite_max_connections() -> u32 { + 16 +} + +fn default_db_postgres_port() -> u16 { + 5432 +} + +fn default_db_postgres_max_connections() -> u32 { + 16 +} + +fn default_pkg_workers() -> u32 { + 1 +} diff --git a/server/src/db/mod.rs b/server/src/db/mod.rs index 98f42a4..a1b7476 100644 --- a/server/src/db/mod.rs +++ b/server/src/db/mod.rs @@ -2,10 +2,12 @@ pub mod entities; mod migrator; pub mod query; +use crate::config::DbConfig; + pub use entities::{prelude::*, *}; pub use migrator::Migrator; -use sea_orm::{DeriveActiveEnum, EnumIter}; +use sea_orm::{ConnectionTrait, Database, DbConn, DeriveActiveEnum, EnumIter}; use serde::{Deserialize, Serialize}; type Result = std::result::Result; @@ -50,3 +52,50 @@ pub struct FullPackage { related: Vec<(PackageRelatedEnum, String)>, files: Vec, } + +pub async fn connect(conn: &DbConfig) -> crate::Result { + match conn { + DbConfig::Sqlite { + db_dir, + max_connections, + } => { + let url = format!( + "sqlite://{}?mode=rwc", + db_dir.join("rieter.sqlite").to_string_lossy() + ); + let options = sea_orm::ConnectOptions::new(url) + .max_connections(*max_connections) + .to_owned(); + + let conn = Database::connect(options).await?; + + // synchronous=NORMAL still ensures database consistency with WAL mode, as per the docs + // https://www.sqlite.org/pragma.html#pragma_synchronous + conn.execute_unprepared("PRAGMA journal_mode=WAL;").await?; + conn.execute_unprepared("PRAGMA synchronous=NORMAL;") + .await?; + + Ok(conn) + } + DbConfig::Postgres { + host, + port, + db, + user, + password, + schema, + max_connections, + } => { + let mut url = format!("postgres://{}:{}@{}:{}/{}", user, password, host, port, db); + + if schema != "" { + url = format!("{url}?currentSchema={schema}"); + } + + let options = sea_orm::ConnectOptions::new(url) + .max_connections(*max_connections) + .to_owned(); + Ok(Database::connect(options).await?) + } + } +} diff --git a/server/src/main.rs b/server/src/main.rs index c3237cf..f7e1a95 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -6,6 +6,7 @@ mod distro; mod error; mod repo; +pub use config::{Config, DbConfig, FsConfig}; pub use error::{Result, ServerError}; use repo::DistroMgr; @@ -14,14 +15,9 @@ use std::{path::PathBuf, sync::Arc}; pub const ANY_ARCH: &'static str = "any"; -#[derive(Clone)] -pub struct Config { - data_dir: PathBuf, -} - #[derive(Clone)] pub struct Global { - config: Config, + config: crate::config::Config, mgr: Arc, db: sea_orm::DbConn, } diff --git a/server/src/repo/manager2.rs b/server/src/repo/manager2.rs index 266eeee..2f66cfe 100644 --- a/server/src/repo/manager2.rs +++ b/server/src/repo/manager2.rs @@ -248,7 +248,7 @@ impl RepoMgr { } pub async fn queue_pkg(&self, repo: i32, path: PathBuf) { - let _ = self.pkg_queue.0.send(PkgQueueMsg { path, repo }); + self.pkg_queue.0.send(PkgQueueMsg { path, repo }).unwrap(); self.repos.read().await.get(&repo).inspect(|n| { n.0.fetch_add(1, Ordering::SeqCst); }); @@ -291,6 +291,7 @@ impl RepoMgr { }; let repo_id: Option = db::Repo::find() + .filter(db::repo::Column::DistroId.eq(distro_id)) .filter(db::repo::Column::Name.eq(repo)) .select_only() .column(db::repo::Column::Id) diff --git a/server/src/repo/mod.rs b/server/src/repo/mod.rs index d088095..16c62a5 100644 --- a/server/src/repo/mod.rs +++ b/server/src/repo/mod.rs @@ -6,6 +6,8 @@ pub mod package; pub use manager::DistroMgr; pub use manager2::RepoMgr; +use crate::FsConfig; + use axum::{ body::Body, extract::{Path, State}, @@ -50,25 +52,26 @@ async fn get_file( req: Request, ) -> crate::Result { if let Some(repo_id) = global.mgr.get_repo(&distro, &repo).await? { - let repo_dir = global - .config - .data_dir - .join("repos") - .join(repo_id.to_string()); + match global.config.fs { + FsConfig::Local { data_dir } => { + let repo_dir = data_dir.join("repos").join(repo_id.to_string()); - let file_name = - if file_name == format!("{}.db", repo) || file_name == format!("{}.db.tar.gz", repo) { - format!("{}.db.tar.gz", arch) - } else if file_name == format!("{}.files", repo) - || file_name == format!("{}.files.tar.gz", repo) - { - format!("{}.files.tar.gz", arch) - } else { - file_name - }; + let file_name = if file_name == format!("{}.db", repo) + || file_name == format!("{}.db.tar.gz", repo) + { + format!("{}.db.tar.gz", arch) + } else if file_name == format!("{}.files", repo) + || file_name == format!("{}.files.tar.gz", repo) + { + format!("{}.files.tar.gz", arch) + } else { + file_name + }; - let path = repo_dir.join(file_name); - Ok(ServeFile::new(path).oneshot(req).await) + let path = repo_dir.join(file_name); + Ok(ServeFile::new(path).oneshot(req).await) + } + } } else { Err(StatusCode::NOT_FOUND.into()) }