diff --git a/Rb.yaml b/Rb.yaml index 6b40537..c944758 100644 --- a/Rb.yaml +++ b/Rb.yaml @@ -16,7 +16,7 @@ debug: key: "secret" refresh_token_size: 64 # Just 5 seconds for debugging - refresh_token_expire: 5 + refresh_token_expire: 60 databases: postgres_rb: diff --git a/rustfmt.toml b/rustfmt.toml index 03acab0..5e52857 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -49,7 +49,7 @@ reorder_imports = true reorder_modules = true report_fixme = "Always" report_todo = "Always" -required_version = "1.4.36" +required_version = "1.4.37" skip_children = false space_after_colon = true space_before_colon = false diff --git a/src/errors.rs b/src/errors.rs index 2257fa5..bb7856a 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -16,6 +16,7 @@ pub enum RbError AuthRefreshTokenExpired, AuthInvalidRefreshToken, AuthDuplicateRefreshToken, + AuthMissingHeader, // UM = User Management UMDuplicateUser, @@ -39,6 +40,7 @@ impl RbError RbError::AuthRefreshTokenExpired => Status::Unauthorized, RbError::AuthInvalidRefreshToken => Status::Unauthorized, RbError::AuthDuplicateRefreshToken => Status::Unauthorized, + RbError::AuthMissingHeader => Status::BadRequest, RbError::UMDuplicateUser => Status::Conflict, @@ -60,6 +62,7 @@ impl RbError RbError::AuthDuplicateRefreshToken => { "This refresh token has already been used. The user has been blocked." } + RbError::AuthMissingHeader => "Missing Authorization header.", RbError::UMDuplicateUser => "This user already exists.", diff --git a/src/guards.rs b/src/guards.rs index 532ba97..7b40bdd 100644 --- a/src/guards.rs +++ b/src/guards.rs @@ -4,10 +4,11 @@ use rocket::{ http::Status, outcome::try_outcome, request::{FromRequest, Outcome, Request}, + State, }; use sha2::Sha256; -use crate::auth::jwt::Claims; +use crate::{auth::jwt::Claims, errors::RbError, RbConfig}; /// Extracts a "Authorization: Bearer" string from the headers. pub struct Bearer<'a>(&'a str); @@ -21,7 +22,7 @@ impl<'r> FromRequest<'r> for Bearer<'r> { // If the header isn't present, just forward to the next route let header = match req.headers().get_one("Authorization") { - None => return Outcome::Forward(()), + None => return Outcome::Failure((Status::BadRequest, Self::Error::AuthMissingHeader)), Some(val) => val, }; @@ -32,7 +33,7 @@ impl<'r> FromRequest<'r> for Bearer<'r> // Extract the jwt token from the header let auth_string = match header.get(7..) { Some(s) => s, - None => return Outcome::Forward(()), + None => return Outcome::Failure((Status::Unauthorized, Self::Error::AuthUnauthorized)), }; Outcome::Success(Self(auth_string)) @@ -45,23 +46,17 @@ pub struct Jwt(Claims); #[rocket::async_trait] impl<'r> FromRequest<'r> for Jwt { - type Error = crate::errors::RbError; + type Error = RbError; async fn from_request(req: &'r Request<'_>) -> Outcome { let bearer = try_outcome!(req.guard::().await).0; + let config = try_outcome!(req.guard::<&State>().await.map_failure(|_| ( + Status::InternalServerError, + RbError::Custom("Couldn't get config guard.") + ))); - // Get secret & key - let secret = match std::env::var("JWT_KEY") { - Ok(key) => key, - Err(_) => { - return Outcome::Failure(( - Status::InternalServerError, - Self::Error::AuthUnauthorized, - )) - } - }; - let key: Hmac = match Hmac::new_from_slice(secret.as_bytes()) { + let key: Hmac = match Hmac::new_from_slice(&config.jwt.key.as_bytes()) { Ok(key) => key, Err(_) => { return Outcome::Failure(( @@ -118,7 +113,7 @@ impl<'r> FromRequest<'r> for Admin if user.admin { Outcome::Success(Self(user)) } else { - Outcome::Forward(()) + Outcome::Failure((Status::Unauthorized, RbError::AuthUnauthorized)) } } } diff --git a/src/main.rs b/src/main.rs index fa147e9..d4ee778 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,12 @@ use figment::{ providers::{Env, Format, Yaml}, Figment, }; -use rocket::{fairing::AdHoc, Build, Rocket}; +use rocket::{ + fairing::AdHoc, + http::Status, + serde::json::{json, Value}, + Build, Request, Rocket, +}; use rocket_sync_db_pools::database; use serde::{Deserialize, Serialize}; @@ -26,6 +31,12 @@ pub(crate) mod schema; #[database("postgres_rb")] pub struct RbDbConn(diesel::PgConnection); +#[catch(default)] +fn default_catcher(status: Status, _: &Request) -> Value +{ + json!({"status": status.code, "message": ""}) +} + embed_migrations!(); async fn run_db_migrations(rocket: Rocket) -> Result, Rocket> @@ -88,6 +99,7 @@ fn rocket() -> _ )) .attach(AdHoc::try_on_ignite("Create admin user", create_admin_user)) .attach(AdHoc::config::()) + .register("/", catchers![default_catcher]) .mount( "/api/auth", routes![auth::already_logged_in, auth::login, auth::refresh_token,], diff --git a/tests/admin.py b/tests/admin.py new file mode 100644 index 0000000..19a1c5b --- /dev/null +++ b/tests/admin.py @@ -0,0 +1,64 @@ +import requests + + +class RbClient: + def __init__(self, username, password, base_url = "http://localhost:8000/api"): + self.username = username + self.password = password + self.base_url = base_url + + self.jwt = None + self.refresh_token = None + + def _login(self): + r = requests.post(f"{self.base_url}/auth/login", json={ + "username": self.username, + "password": self.password, + }) + + if r.status_code != 200: + raise Exception("Couldn't login") + + res = r.json() + self.jwt = res["token"] + self.refresh_token = res["refreshToken"] + + def _refresh(self): + r = requests.post(f"{self.base_url}/auth/refresh", json={"refreshToken": self.refresh_token}) + + if r.status_code != 200: + raise Exception("Couldn't refresh") + + res = r.json() + self.jwt = res["token"] + self.refresh_token = res["refreshToken"] + + def _request(self, type_, url, retry=2, *args, **kwargs): + if self.jwt: + headers = kwargs.get("headers", {}) + headers["Authorization"] = f"Bearer {self.jwt}" + kwargs["headers"] = headers + print(kwargs["headers"]) + + r = requests.request(type_, url, *args, **kwargs) + + if r.status_code != 200 and retry > 0: + if self.refresh_token: + self._refresh() + + else: + self._login() + + r = self._request(type_, url, *args, **kwargs, retry=retry - 1) + + return r + + def get(self, url, *args, **kwargs): + return self._request("GET", f"{self.base_url}{url}", *args, **kwargs) + + + +if __name__ == "__main__": + client = RbClient("admin", "password") + + print(client.get("/admin/users").json())