diff --git a/Cargo.lock b/Cargo.lock index e78c074..88e0186 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -89,6 +89,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fc6f625a1f7705c6cf62d0d070794e94668988b1c38111baeec177c715f7b" +dependencies = [ + "axum", + "axum-core", + "bytes", + "cookie", + "futures-util", + "headers", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "serde", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -104,6 +127,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64ct" version = "1.6.0" @@ -161,6 +190,17 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -387,6 +427,30 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +[[package]] +name = "headers" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9" +dependencies = [ + "base64", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http", +] + [[package]] name = "heck" version = "0.5.0" @@ -631,6 +695,7 @@ version = "0.1.0" dependencies = [ "argon2", "axum", + "axum-extra", "diesel", "diesel_migrations", "libsqlite3-sys", @@ -885,6 +950,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" diff --git a/Cargo.toml b/Cargo.toml index d9889d2..560e28b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] argon2 = "0.5.3" axum = "0.8.1" +axum-extra = { version = "0.10", features = ["cookie", "typed-header"] } diesel = { version = "2.2.7", features = ["r2d2", "sqlite", "returning_clauses_for_sqlite_3_35"] } diesel_migrations = { version = "2.2.0", features = ["sqlite"] } libsqlite3-sys = { version = "0.31.0", features = ["bundled"] } diff --git a/src/db/mod.rs b/src/db/mod.rs index b41edf9..63d6ef0 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,6 +1,9 @@ -mod models; +pub mod models; mod schema; +pub use models::session::Session; +pub use models::user::{NewUser, User}; + use diesel::{ r2d2::{ConnectionManager, Pool}, SqliteConnection, diff --git a/src/server/error.rs b/src/server/error.rs new file mode 100644 index 0000000..4e84dc9 --- /dev/null +++ b/src/server/error.rs @@ -0,0 +1,86 @@ +use std::fmt::{self, Write}; + +use axum::{http::StatusCode, response::IntoResponse}; + +use crate::db; + +pub type AppResult = Result; + +#[derive(Debug)] +pub enum AppError { + Db(db::DbError), + IO(std::io::Error), + Other(Box), + BadRequest, + Unauthorized, + NotFound, +} + +impl fmt::Display for AppError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Db(_) => write!(f, "database error"), + Self::IO(_) => write!(f, "io error"), + Self::Other(_) => write!(f, "other error"), + Self::BadRequest => write!(f, "bad request"), + Self::Unauthorized => write!(f, "unauthorized"), + Self::NotFound => write!(f, "not found"), + } + } +} + +impl std::error::Error for AppError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Db(err) => Some(err), + Self::IO(err) => Some(err), + Self::Other(err) => Some(err.as_ref()), + Self::NotFound | Self::Unauthorized | Self::BadRequest => None, + } + } +} + +pub trait ErrorExt: std::error::Error { + /// Return the full chain of error messages + fn stack(&self) -> String { + let mut msg = format!("{}", self); + let mut err = self.source(); + + while let Some(src) = err { + write!(msg, " - {}", src).unwrap(); + + err = src.source(); + } + + msg + } +} + +impl ErrorExt for E {} + +impl From for AppError { + fn from(value: db::DbError) -> Self { + Self::Db(value) + } +} + +impl From for AppError { + fn from(value: std::io::Error) -> Self { + Self::IO(value) + } +} + +impl IntoResponse for AppError { + fn into_response(self) -> axum::response::Response { + match self { + Self::NotFound => StatusCode::NOT_FOUND.into_response(), + Self::Unauthorized => StatusCode::UNAUTHORIZED.into_response(), + Self::BadRequest => StatusCode::BAD_REQUEST.into_response(), + _ => { + tracing::error!("{}", self.stack()); + + StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } + } +} diff --git a/src/server/gpodder/auth.rs b/src/server/gpodder/auth.rs new file mode 100644 index 0000000..c907a77 --- /dev/null +++ b/src/server/gpodder/auth.rs @@ -0,0 +1,51 @@ +use axum::{ + extract::{Path, State}, + routing::post, + Router, +}; +use axum_extra::{ + extract::{ + cookie::{Cookie, Expiration}, + CookieJar, + }, + headers::{authorization::Basic, Authorization}, + TypedHeader, +}; + +use crate::{ + db::{Session, User}, + server::{ + error::{AppError, AppResult}, + Context, + }, +}; + +pub fn router() -> Router { + Router::new().route("/{username}/login.json", post(post_login)) +} + +async fn post_login( + State(ctx): State, + Path(username): Path, + jar: CookieJar, + TypedHeader(auth): TypedHeader>, +) -> AppResult { + // These should be the same according to the spec + if username != auth.username() { + return Err(AppError::BadRequest); + } + + let session = tokio::task::spawn_blocking(move || { + let user = User::by_username(&ctx.pool, auth.username())?.ok_or(AppError::NotFound)?; + + if user.verify_password(auth.password()) { + Ok(Session::new_for_user(&ctx.pool, user.id)?) + } else { + Err(AppError::Unauthorized) + } + }) + .await + .unwrap()?; + + Ok(jar.add(Cookie::build(("sessionid", session.id.to_string())).expires(Expiration::Session))) +} diff --git a/src/server/gpodder/mod.rs b/src/server/gpodder/mod.rs new file mode 100644 index 0000000..690780b --- /dev/null +++ b/src/server/gpodder/mod.rs @@ -0,0 +1,21 @@ +mod auth; + +use axum::{ + http::{HeaderName, HeaderValue}, + Router, +}; +use tower_http::set_header::SetResponseHeaderLayer; + +use super::Context; + +pub fn router() -> Router { + Router::new() + .nest("/auth", auth::router()) + // https://gpoddernet.readthedocs.io/en/latest/api/reference/general.html#cors + // All endpoints should send this CORS header value so the endpoints can be used from web + // applications + .layer(SetResponseHeaderLayer::overriding( + HeaderName::from_static("access-control-allow-origin"), + HeaderValue::from_static("*"), + )) +} diff --git a/src/server/mod.rs b/src/server/mod.rs index a285a62..2ebbc5f 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,8 +1,8 @@ -use axum::{ - http::{HeaderName, HeaderValue}, - Router, -}; -use tower_http::{set_header::SetResponseHeaderLayer, trace::TraceLayer}; +mod error; +mod gpodder; + +use axum::Router; +use tower_http::trace::TraceLayer; #[derive(Clone)] pub struct Context { @@ -11,12 +11,6 @@ pub struct Context { pub fn app() -> Router { Router::new() + .nest("/api/2", gpodder::router()) .layer(TraceLayer::new_for_http()) - // https://gpoddernet.readthedocs.io/en/latest/api/reference/general.html#cors - // All endpoints should send this CORS header value so the endpoints can be used from web - // applications - .layer(SetResponseHeaderLayer::overriding( - HeaderName::from_static("access-control-allow-origin"), - HeaderValue::from_static("*"), - )) }