diff --git a/Rb.yaml b/Rb.yaml index c944758..6b40537 100644 --- a/Rb.yaml +++ b/Rb.yaml @@ -16,7 +16,7 @@ debug: key: "secret" refresh_token_size: 64 # Just 5 seconds for debugging - refresh_token_expire: 60 + refresh_token_expire: 5 databases: postgres_rb: diff --git a/rustfmt.toml b/rustfmt.toml index 5e52857..03acab0 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -49,7 +49,7 @@ reorder_imports = true reorder_modules = true report_fixme = "Always" report_todo = "Always" -required_version = "1.4.37" +required_version = "1.4.36" skip_children = false space_after_colon = true space_before_colon = false diff --git a/src/errors.rs b/src/errors.rs index bb7856a..2257fa5 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -16,7 +16,6 @@ pub enum RbError AuthRefreshTokenExpired, AuthInvalidRefreshToken, AuthDuplicateRefreshToken, - AuthMissingHeader, // UM = User Management UMDuplicateUser, @@ -40,7 +39,6 @@ impl RbError RbError::AuthRefreshTokenExpired => Status::Unauthorized, RbError::AuthInvalidRefreshToken => Status::Unauthorized, RbError::AuthDuplicateRefreshToken => Status::Unauthorized, - RbError::AuthMissingHeader => Status::BadRequest, RbError::UMDuplicateUser => Status::Conflict, @@ -62,7 +60,6 @@ impl RbError RbError::AuthDuplicateRefreshToken => { "This refresh token has already been used. The user has been blocked." } - RbError::AuthMissingHeader => "Missing Authorization header.", RbError::UMDuplicateUser => "This user already exists.", diff --git a/src/guards.rs b/src/guards.rs index 7b40bdd..532ba97 100644 --- a/src/guards.rs +++ b/src/guards.rs @@ -4,11 +4,10 @@ use rocket::{ http::Status, outcome::try_outcome, request::{FromRequest, Outcome, Request}, - State, }; use sha2::Sha256; -use crate::{auth::jwt::Claims, errors::RbError, RbConfig}; +use crate::auth::jwt::Claims; /// Extracts a "Authorization: Bearer" string from the headers. pub struct Bearer<'a>(&'a str); @@ -22,7 +21,7 @@ impl<'r> FromRequest<'r> for Bearer<'r> { // If the header isn't present, just forward to the next route let header = match req.headers().get_one("Authorization") { - None => return Outcome::Failure((Status::BadRequest, Self::Error::AuthMissingHeader)), + None => return Outcome::Forward(()), Some(val) => val, }; @@ -33,7 +32,7 @@ impl<'r> FromRequest<'r> for Bearer<'r> // Extract the jwt token from the header let auth_string = match header.get(7..) { Some(s) => s, - None => return Outcome::Failure((Status::Unauthorized, Self::Error::AuthUnauthorized)), + None => return Outcome::Forward(()), }; Outcome::Success(Self(auth_string)) @@ -46,17 +45,23 @@ pub struct Jwt(Claims); #[rocket::async_trait] impl<'r> FromRequest<'r> for Jwt { - type Error = RbError; + type Error = crate::errors::RbError; async fn from_request(req: &'r Request<'_>) -> Outcome { let bearer = try_outcome!(req.guard::().await).0; - let config = try_outcome!(req.guard::<&State>().await.map_failure(|_| ( - Status::InternalServerError, - RbError::Custom("Couldn't get config guard.") - ))); - let key: Hmac = match Hmac::new_from_slice(&config.jwt.key.as_bytes()) { + // Get secret & key + let secret = match std::env::var("JWT_KEY") { + Ok(key) => key, + Err(_) => { + return Outcome::Failure(( + Status::InternalServerError, + Self::Error::AuthUnauthorized, + )) + } + }; + let key: Hmac = match Hmac::new_from_slice(secret.as_bytes()) { Ok(key) => key, Err(_) => { return Outcome::Failure(( @@ -113,7 +118,7 @@ impl<'r> FromRequest<'r> for Admin if user.admin { Outcome::Success(Self(user)) } else { - Outcome::Failure((Status::Unauthorized, RbError::AuthUnauthorized)) + Outcome::Forward(()) } } } diff --git a/src/main.rs b/src/main.rs index d4ee778..fa147e9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,12 +12,7 @@ use figment::{ providers::{Env, Format, Yaml}, Figment, }; -use rocket::{ - fairing::AdHoc, - http::Status, - serde::json::{json, Value}, - Build, Request, Rocket, -}; +use rocket::{fairing::AdHoc, Build, Rocket}; use rocket_sync_db_pools::database; use serde::{Deserialize, Serialize}; @@ -31,12 +26,6 @@ pub(crate) mod schema; #[database("postgres_rb")] pub struct RbDbConn(diesel::PgConnection); -#[catch(default)] -fn default_catcher(status: Status, _: &Request) -> Value -{ - json!({"status": status.code, "message": ""}) -} - embed_migrations!(); async fn run_db_migrations(rocket: Rocket) -> Result, Rocket> @@ -99,7 +88,6 @@ fn rocket() -> _ )) .attach(AdHoc::try_on_ignite("Create admin user", create_admin_user)) .attach(AdHoc::config::()) - .register("/", catchers![default_catcher]) .mount( "/api/auth", routes![auth::already_logged_in, auth::login, auth::refresh_token,], diff --git a/tests/admin.py b/tests/admin.py deleted file mode 100644 index 19a1c5b..0000000 --- a/tests/admin.py +++ /dev/null @@ -1,64 +0,0 @@ -import requests - - -class RbClient: - def __init__(self, username, password, base_url = "http://localhost:8000/api"): - self.username = username - self.password = password - self.base_url = base_url - - self.jwt = None - self.refresh_token = None - - def _login(self): - r = requests.post(f"{self.base_url}/auth/login", json={ - "username": self.username, - "password": self.password, - }) - - if r.status_code != 200: - raise Exception("Couldn't login") - - res = r.json() - self.jwt = res["token"] - self.refresh_token = res["refreshToken"] - - def _refresh(self): - r = requests.post(f"{self.base_url}/auth/refresh", json={"refreshToken": self.refresh_token}) - - if r.status_code != 200: - raise Exception("Couldn't refresh") - - res = r.json() - self.jwt = res["token"] - self.refresh_token = res["refreshToken"] - - def _request(self, type_, url, retry=2, *args, **kwargs): - if self.jwt: - headers = kwargs.get("headers", {}) - headers["Authorization"] = f"Bearer {self.jwt}" - kwargs["headers"] = headers - print(kwargs["headers"]) - - r = requests.request(type_, url, *args, **kwargs) - - if r.status_code != 200 and retry > 0: - if self.refresh_token: - self._refresh() - - else: - self._login() - - r = self._request(type_, url, *args, **kwargs, retry=retry - 1) - - return r - - def get(self, url, *args, **kwargs): - return self._request("GET", f"{self.base_url}{url}", *args, **kwargs) - - - -if __name__ == "__main__": - client = RbClient("admin", "password") - - print(client.get("/admin/users").json())