diff --git a/Cargo.lock b/Cargo.lock index 3ab3d68..7ebc587 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,7 +106,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.4.5", "axum-macros", "bytes", "futures-util", @@ -116,7 +116,7 @@ dependencies = [ "hyper", "hyper-util", "itoa", - "matchit", + "matchit 0.7.3", "memchr", "mime", "percent-encoding", @@ -134,6 +134,33 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" +dependencies = [ + "axum-core 0.5.0", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-core" version = "0.4.5" @@ -155,6 +182,48 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-core" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" +dependencies = [ + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-extra" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fc6f625a1f7705c6cf62d0d070794e94668988b1c38111baeec177c715f7b" +dependencies = [ + "axum 0.8.1", + "axum-core 0.5.0", + "bytes", + "cookie", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "serde", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "axum-macros" version = "0.4.2" @@ -249,7 +318,8 @@ checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" name = "calathea" version = "0.1.0" dependencies = [ - "axum", + "axum 0.7.9", + "axum-extra", "chrono", "r2d2", "r2d2_sqlite", @@ -314,6 +384,17 @@ dependencies = [ "phf_codegen", ] +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -373,6 +454,15 @@ dependencies = [ "typenum", ] +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + [[package]] name = "deunicode" version = "1.6.0" @@ -736,6 +826,12 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "memchr" version = "2.7.4" @@ -788,6 +884,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-traits" version = "0.2.19" @@ -957,6 +1059,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -1321,6 +1429,37 @@ dependencies = [ "once_cell", ] +[[package]] +name = "time" +version = "0.3.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tokio" version = "1.42.0" diff --git a/Cargo.toml b/Cargo.toml index b4079ae..09669e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ name = "calathea" [dependencies] axum = { version = "0.7.9", features = ["macros"] } +axum-extra = { version = "0.10.0", features = ["cookie"] } chrono = { version = "0.4.39", features = ["serde"] } r2d2 = "0.8.10" r2d2_sqlite = "0.25.0" diff --git a/src/db/mod.rs b/src/db/mod.rs index 1db92fe..c51a94c 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,6 +1,7 @@ mod comment; mod event; mod plant; +mod session; mod user; use r2d2_sqlite::{rusqlite, SqliteConnectionManager}; @@ -10,6 +11,8 @@ use std::{error::Error, fmt}; pub use comment::{Comment, NewComment}; pub use event::{Event, EventType, NewEvent, EVENT_TYPES}; pub use plant::{NewPlant, Plant}; +pub use session::Session; +pub use user::{NewUser, User}; pub type DbPool = r2d2::Pool; diff --git a/src/db/session.rs b/src/db/session.rs new file mode 100644 index 0000000..5264407 --- /dev/null +++ b/src/db/session.rs @@ -0,0 +1,19 @@ +use super::{DbError, DbPool, User}; + +pub struct Session { + id: u64, + user_id: i32, +} + +impl Session { + pub fn user_from_id(pool: &DbPool, id: u64) -> Result, DbError> { + let conn = pool.get()?; + + let mut stmt = conn.prepare("select users.* from sessions inner join users on sessions.user_id = users.id where sessions.id = $1")?; + match stmt.query_row((id,), User::from_row) { + Ok(user) => Ok(Some(user)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(err) => Err(DbError::Db(err)), + } + } +} diff --git a/src/db/user.rs b/src/db/user.rs index 4a8c333..949ea79 100644 --- a/src/db/user.rs +++ b/src/db/user.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use super::{DbError, DbPool}; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub struct User { id: i32, username: String, diff --git a/src/main.rs b/src/main.rs index 62fda66..f44edef 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,11 +7,12 @@ use r2d2_sqlite::SqliteConnectionManager; use tera::Tera; use tower_http::compression::CompressionLayer; -const MIGRATIONS: [&str; 4] = [ +const MIGRATIONS: [&str; 5] = [ include_str!("migrations/000_initial.sql"), include_str!("migrations/001_plants.sql"), include_str!("migrations/002_comments.sql"), include_str!("migrations/003_events.sql"), + include_str!("migrations/004_auth.sql"), ]; #[derive(Clone)] diff --git a/src/server/auth.rs b/src/server/auth.rs new file mode 100644 index 0000000..8ff8151 --- /dev/null +++ b/src/server/auth.rs @@ -0,0 +1,36 @@ +use axum::{ + extract::{Request, State}, + http::StatusCode, + middleware::Next, + response::{IntoResponse, Response}, +}; +use axum_extra::extract::CookieJar; + +use crate::db::Session; + +use super::error::AppError; + +pub async fn auth_middleware( + State(ctx): State, + mut req: Request, + next: Next, +) -> Response { + let jar = CookieJar::from_headers(req.headers()); + + if let Some(session_id) = jar + .get("session_id") + .and_then(|c| c.value().parse::().ok()) + { + match Session::user_from_id(&ctx.pool, session_id) { + Ok(Some(user)) => { + req.extensions_mut().insert(user); + + next.run(req).await + } + Ok(None) => StatusCode::UNAUTHORIZED.into_response(), + Err(err) => AppError::Db(err).into_response(), + } + } else { + StatusCode::UNAUTHORIZED.into_response() + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 56e156e..606ae95 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,3 +1,4 @@ +mod auth; mod comments; mod error; mod events; @@ -6,6 +7,7 @@ mod plants; use axum::{ extract::State, http::{header::VARY, HeaderMap, HeaderValue}, + middleware, response::Html, routing::get, Router, @@ -50,10 +52,14 @@ pub fn render_view( pub fn app(ctx: crate::Context, static_dir: &str) -> axum::Router { let router = Router::new() - .route("/", get(get_index)) .nest("/plants", plants::app()) .nest("/comments", comments::app()) .nest("/events", events::app()) + .layer(middleware::from_fn_with_state( + ctx.clone(), + auth::auth_middleware, + )) + .route("/", get(get_index)) .nest_service("/static", ServeDir::new(static_dir)) .with_state(ctx.clone());