feat: added error handling and login POST route
This commit is contained in:
parent
67ad8c2b64
commit
2f8181491a
7 changed files with 245 additions and 13 deletions
|
|
@ -1,6 +1,9 @@
|
|||
mod models;
|
||||
pub mod models;
|
||||
mod schema;
|
||||
|
||||
pub use models::session::Session;
|
||||
pub use models::user::{NewUser, User};
|
||||
|
||||
use diesel::{
|
||||
r2d2::{ConnectionManager, Pool},
|
||||
SqliteConnection,
|
||||
|
|
|
|||
86
src/server/error.rs
Normal file
86
src/server/error.rs
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
use std::fmt::{self, Write};
|
||||
|
||||
use axum::{http::StatusCode, response::IntoResponse};
|
||||
|
||||
use crate::db;
|
||||
|
||||
pub type AppResult<T> = Result<T, AppError>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AppError {
|
||||
Db(db::DbError),
|
||||
IO(std::io::Error),
|
||||
Other(Box<dyn std::error::Error + 'static + Send + Sync>),
|
||||
BadRequest,
|
||||
Unauthorized,
|
||||
NotFound,
|
||||
}
|
||||
|
||||
impl fmt::Display for AppError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Db(_) => write!(f, "database error"),
|
||||
Self::IO(_) => write!(f, "io error"),
|
||||
Self::Other(_) => write!(f, "other error"),
|
||||
Self::BadRequest => write!(f, "bad request"),
|
||||
Self::Unauthorized => write!(f, "unauthorized"),
|
||||
Self::NotFound => write!(f, "not found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for AppError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Db(err) => Some(err),
|
||||
Self::IO(err) => Some(err),
|
||||
Self::Other(err) => Some(err.as_ref()),
|
||||
Self::NotFound | Self::Unauthorized | Self::BadRequest => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ErrorExt: std::error::Error {
|
||||
/// Return the full chain of error messages
|
||||
fn stack(&self) -> String {
|
||||
let mut msg = format!("{}", self);
|
||||
let mut err = self.source();
|
||||
|
||||
while let Some(src) = err {
|
||||
write!(msg, " - {}", src).unwrap();
|
||||
|
||||
err = src.source();
|
||||
}
|
||||
|
||||
msg
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: std::error::Error> ErrorExt for E {}
|
||||
|
||||
impl From<db::DbError> for AppError {
|
||||
fn from(value: db::DbError) -> Self {
|
||||
Self::Db(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for AppError {
|
||||
fn from(value: std::io::Error) -> Self {
|
||||
Self::IO(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::NotFound => StatusCode::NOT_FOUND.into_response(),
|
||||
Self::Unauthorized => StatusCode::UNAUTHORIZED.into_response(),
|
||||
Self::BadRequest => StatusCode::BAD_REQUEST.into_response(),
|
||||
_ => {
|
||||
tracing::error!("{}", self.stack());
|
||||
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
51
src/server/gpodder/auth.rs
Normal file
51
src/server/gpodder/auth.rs
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
use axum::{
|
||||
extract::{Path, State},
|
||||
routing::post,
|
||||
Router,
|
||||
};
|
||||
use axum_extra::{
|
||||
extract::{
|
||||
cookie::{Cookie, Expiration},
|
||||
CookieJar,
|
||||
},
|
||||
headers::{authorization::Basic, Authorization},
|
||||
TypedHeader,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
db::{Session, User},
|
||||
server::{
|
||||
error::{AppError, AppResult},
|
||||
Context,
|
||||
},
|
||||
};
|
||||
|
||||
pub fn router() -> Router<Context> {
|
||||
Router::new().route("/{username}/login.json", post(post_login))
|
||||
}
|
||||
|
||||
async fn post_login(
|
||||
State(ctx): State<Context>,
|
||||
Path(username): Path<String>,
|
||||
jar: CookieJar,
|
||||
TypedHeader(auth): TypedHeader<Authorization<Basic>>,
|
||||
) -> AppResult<CookieJar> {
|
||||
// These should be the same according to the spec
|
||||
if username != auth.username() {
|
||||
return Err(AppError::BadRequest);
|
||||
}
|
||||
|
||||
let session = tokio::task::spawn_blocking(move || {
|
||||
let user = User::by_username(&ctx.pool, auth.username())?.ok_or(AppError::NotFound)?;
|
||||
|
||||
if user.verify_password(auth.password()) {
|
||||
Ok(Session::new_for_user(&ctx.pool, user.id)?)
|
||||
} else {
|
||||
Err(AppError::Unauthorized)
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap()?;
|
||||
|
||||
Ok(jar.add(Cookie::build(("sessionid", session.id.to_string())).expires(Expiration::Session)))
|
||||
}
|
||||
21
src/server/gpodder/mod.rs
Normal file
21
src/server/gpodder/mod.rs
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
mod auth;
|
||||
|
||||
use axum::{
|
||||
http::{HeaderName, HeaderValue},
|
||||
Router,
|
||||
};
|
||||
use tower_http::set_header::SetResponseHeaderLayer;
|
||||
|
||||
use super::Context;
|
||||
|
||||
pub fn router() -> Router<Context> {
|
||||
Router::new()
|
||||
.nest("/auth", auth::router())
|
||||
// https://gpoddernet.readthedocs.io/en/latest/api/reference/general.html#cors
|
||||
// All endpoints should send this CORS header value so the endpoints can be used from web
|
||||
// applications
|
||||
.layer(SetResponseHeaderLayer::overriding(
|
||||
HeaderName::from_static("access-control-allow-origin"),
|
||||
HeaderValue::from_static("*"),
|
||||
))
|
||||
}
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
use axum::{
|
||||
http::{HeaderName, HeaderValue},
|
||||
Router,
|
||||
};
|
||||
use tower_http::{set_header::SetResponseHeaderLayer, trace::TraceLayer};
|
||||
mod error;
|
||||
mod gpodder;
|
||||
|
||||
use axum::Router;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Context {
|
||||
|
|
@ -11,12 +11,6 @@ pub struct Context {
|
|||
|
||||
pub fn app() -> Router<Context> {
|
||||
Router::new()
|
||||
.nest("/api/2", gpodder::router())
|
||||
.layer(TraceLayer::new_for_http())
|
||||
// https://gpoddernet.readthedocs.io/en/latest/api/reference/general.html#cors
|
||||
// All endpoints should send this CORS header value so the endpoints can be used from web
|
||||
// applications
|
||||
.layer(SetResponseHeaderLayer::overriding(
|
||||
HeaderName::from_static("access-control-allow-origin"),
|
||||
HeaderValue::from_static("*"),
|
||||
))
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue