diff --git a/src/server/auth.rs b/src/server/auth.rs index eb31aea..61a68f9 100644 --- a/src/server/auth.rs +++ b/src/server/auth.rs @@ -13,37 +13,41 @@ use axum_extra::extract::{ use serde::Deserialize; use tera::Context; -use crate::db::{Session, User}; +use crate::db::{DbError, DbPool, Session, User}; use super::{error::AppError, render_view}; +pub fn logged_in_user(pool: &DbPool, headers: &HeaderMap) -> Result, DbError> { + let jar = CookieJar::from_headers(headers); + + if let Some(session_id) = jar + .get("session_id") + .and_then(|c| c.value().parse::().ok()) + { + Session::user_from_id(pool, session_id) + } else { + Ok(None) + } +} + pub async fn auth_middleware( State(ctx): State, mut req: Request, next: Next, ) -> Response { - let jar = CookieJar::from_headers(req.headers()); + match logged_in_user(&ctx.pool, req.headers()) { + Ok(Some(user)) => { + req.extensions_mut().insert(user); - if let Some(session_id) = jar - .get("session_id") - .and_then(|c| c.value().parse::().ok()) - { - match Session::user_from_id(&ctx.pool, session_id) { - Ok(Some(user)) => { - req.extensions_mut().insert(user); - - next.run(req).await - } - Ok(None) => StatusCode::UNAUTHORIZED.into_response(), - Err(err) => AppError::Db(err).into_response(), + next.run(req).await } - } else { - StatusCode::UNAUTHORIZED.into_response() + Ok(None) => StatusCode::UNAUTHORIZED.into_response(), + Err(err) => AppError::Db(err).into_response(), } } pub fn app() -> Router { - Router::new().route("/login", get(get_login).post(post_login)) + Router::new().route("/login", post(post_login)) } #[derive(Deserialize)] @@ -52,23 +56,10 @@ struct Login { password: String, } -async fn get_login( - State(ctx): State, - headers: HeaderMap, -) -> Result, AppError> { - let context = Context::new(); - - Ok(Html(render_view( - &ctx.tera, - "views/login.html", - &context, - &headers, - )?)) -} - async fn post_login( State(ctx): State, jar: CookieJar, + headers: HeaderMap, Form(login): Form, ) -> Result<(CookieJar, Html), AppError> { if let Some(user) = User::by_username(&ctx.pool, &login.username)? { @@ -84,7 +75,7 @@ async fn post_login( .same_site(SameSite::Lax) .build(), ), - Html(String::new()), + super::render_home(ctx, &headers).await?, )) } else { Err(AppError::Unauthorized) diff --git a/src/server/mod.rs b/src/server/mod.rs index c7251bc..70af08f 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -74,7 +74,7 @@ pub fn app(ctx: crate::Context, static_dir: &str) -> axum::Router { )) } -async fn get_index(State(ctx): State, headers: HeaderMap) -> Result> { +pub async fn render_home(ctx: crate::Context, headers: &HeaderMap) -> Result> { let plants = tokio::task::spawn_blocking(move || Plant::all(&ctx.pool)) .await .unwrap()?; @@ -85,6 +85,25 @@ async fn get_index(State(ctx): State, headers: HeaderMap) -> Res &ctx.tera, "views/index.html", &context, + headers, + )?)) +} + +pub fn render_login(ctx: crate::Context, headers: &HeaderMap) -> Result> { + let context = Context::new(); + + Ok(Html(render_view( + &ctx.tera, + "views/login.html", + &context, &headers, )?)) } + +async fn get_index(State(ctx): State, headers: HeaderMap) -> Result> { + if auth::logged_in_user(&ctx.pool, &headers)?.is_some() { + render_home(ctx, &headers).await + } else { + render_login(ctx, &headers) + } +}