diff --git a/src/cli/serve.rs b/src/cli/serve.rs index f329858..1ee2c7d 100644 --- a/src/cli/serve.rs +++ b/src/cli/serve.rs @@ -32,7 +32,10 @@ impl ServeCommand { let pool = db::initialize_db(cli.data_dir.join(crate::DB_FILENAME), true).unwrap(); - let ctx = server::Context { pool }; + let ctx = server::Context { + pool: pool.clone(), + repo: db::SqliteRepository::from(pool.clone()), + }; let app = server::app(ctx); let rt = tokio::runtime::Builder::new_multi_thread() diff --git a/src/db/mod.rs b/src/db/mod.rs index 8aa2823..3f2ec04 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,4 +1,5 @@ pub mod models; +mod repository; mod schema; pub use models::device::{Device, DeviceType, NewDevice}; @@ -6,6 +7,8 @@ pub use models::session::Session; pub use models::subscription::{NewSubscription, Subscription}; pub use models::user::{NewUser, User}; +pub use repository::SqliteRepository; + use diesel::{ r2d2::{ConnectionManager, Pool}, SqliteConnection, diff --git a/src/db/repository/auth.rs b/src/db/repository/auth.rs new file mode 100644 index 0000000..72a3956 --- /dev/null +++ b/src/db/repository/auth.rs @@ -0,0 +1,75 @@ +use diesel::prelude::*; +use rand::Rng; + +use super::SqliteRepository; +use crate::{ + db::{self, schema::*}, + gpodder, +}; + +impl From for gpodder::AuthErr { + fn from(value: diesel::r2d2::PoolError) -> Self { + Self::Other(Box::new(value)) + } +} + +impl From for gpodder::AuthErr { + fn from(value: diesel::result::Error) -> Self { + Self::Other(Box::new(value)) + } +} + +impl gpodder::AuthRepository for SqliteRepository { + fn validate_session(&self, session_id: i64) -> Result { + match sessions::dsl::sessions + .inner_join(users::table) + .filter(sessions::id.eq(session_id)) + .select(db::User::as_select()) + .get_result(&mut self.pool.get()?) + { + Ok(user) => Ok(gpodder::User { + id: user.id, + username: user.username, + }), + Err(diesel::result::Error::NotFound) => Err(gpodder::AuthErr::UnknownSession), + Err(err) => Err(gpodder::AuthErr::Other(Box::new(err))), + } + } + + fn create_session( + &self, + username: &str, + password: &str, + ) -> Result<(i64, gpodder::models::User), gpodder::AuthErr> { + if let Some(user) = users::table + .select(db::User::as_select()) + .filter(users::username.eq(username)) + .first(&mut self.pool.get()?) + .optional()? + { + if user.verify_password(password) { + let id: i64 = rand::thread_rng().gen(); + + let session_id = db::Session { + id, + user_id: user.id, + } + .insert_into(sessions::table) + .returning(sessions::id) + .get_result(&mut self.pool.get()?)?; + + Ok(( + session_id, + gpodder::User { + id: user.id, + username: user.username, + }, + )) + } else { + Err(gpodder::AuthErr::InvalidPassword) + } + } else { + Err(gpodder::AuthErr::UnknownUser) + } + } +} diff --git a/src/db/repository/mod.rs b/src/db/repository/mod.rs new file mode 100644 index 0000000..84d2415 --- /dev/null +++ b/src/db/repository/mod.rs @@ -0,0 +1,14 @@ +mod auth; + +use super::DbPool; + +#[derive(Clone)] +pub struct SqliteRepository { + pool: DbPool, +} + +impl From for SqliteRepository { + fn from(value: DbPool) -> Self { + Self { pool: value } + } +} diff --git a/src/gpodder/mod.rs b/src/gpodder/mod.rs new file mode 100644 index 0000000..6e5e702 --- /dev/null +++ b/src/gpodder/mod.rs @@ -0,0 +1,26 @@ +pub mod models; + +pub use models::*; + +pub enum AuthErr { + UnknownSession, + UnknownUser, + InvalidPassword, + Other(Box), +} + +pub trait AuthRepository: Send + Sync { + /// Validate the given session ID and return its user. + fn validate_session(&self, session_id: i64) -> Result; + + /// Create a new session for the given user. + fn create_session( + &self, + username: &str, + password: &str, + ) -> Result<(i64, models::User), AuthErr>; +} + +// pub trait DeviceRepository: Send + Sync { +// fn devices_for_user(&self, ) +// } diff --git a/src/gpodder/models.rs b/src/gpodder/models.rs new file mode 100644 index 0000000..764b33f --- /dev/null +++ b/src/gpodder/models.rs @@ -0,0 +1,4 @@ +pub struct User { + pub id: i64, + pub username: String, +} diff --git a/src/main.rs b/src/main.rs index b6136c2..b418992 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ mod cli; mod db; +mod gpodder; mod server; use clap::Parser; diff --git a/src/server/mod.rs b/src/server/mod.rs index a1cbabb..2804e7a 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -7,6 +7,7 @@ use tower_http::trace::TraceLayer; #[derive(Clone)] pub struct Context { pub pool: crate::db::DbPool, + pub repo: crate::db::SqliteRepository, } pub fn app(ctx: Context) -> Router {