From 0d4d96d7614af9920e4800fd2e1de13bd72f9bcb Mon Sep 17 00:00:00 2001 From: Jef Roosens Date: Sat, 21 Aug 2021 18:05:16 +0200 Subject: [PATCH] Added very basic admin user creation --- src/rb/auth.rs | 75 +++++++++++++++++++++++++++++++++++++----------- src/rb/errors.rs | 12 ++++++-- src/rb/models.rs | 17 +++++++---- src/rbs/auth.rs | 2 +- src/rbs/main.rs | 23 +++++++++++++++ 5 files changed, 104 insertions(+), 25 deletions(-) diff --git a/src/rb/auth.rs b/src/rb/auth.rs index 11ed395..3d70159 100644 --- a/src/rb/auth.rs +++ b/src/rb/auth.rs @@ -1,17 +1,17 @@ use crate::errors::RBError; -use crate::models::{User, NewRefreshToken}; -use crate::schema::users::dsl as users; +use crate::models::{NewRefreshToken, User, NewUser}; use crate::schema::refresh_tokens::dsl as refresh_tokens; +use crate::schema::users::dsl as users; use argon2::verify_encoded; +use chrono::Utc; use diesel::prelude::*; -use diesel::{PgConnection, insert_into}; +use diesel::{insert_into, PgConnection}; use hmac::{Hmac, NewMac}; use jwt::SignWithKey; +use rand::{thread_rng, Rng}; +use serde::Serialize; use sha2::Sha256; use std::collections::HashMap; -use chrono::Utc; -use serde::Serialize; -use rand::{thread_rng, Rng}; /// Expire time for the JWT tokens in seconds. const JWT_EXP_SECONDS: i64 = 900; @@ -25,6 +25,11 @@ pub fn verify_user(conn: &PgConnection, username: &str, password: &str) -> crate .first::(conn) .map_err(|_| RBError::UnknownUser)?; + // Check if a user is blocked + if user.blocked { + return Err(RBError::BlockedUser); + } + match verify_encoded(user.password.as_str(), password.as_bytes()) { Ok(true) => Ok(user), _ => Err(RBError::InvalidPassword), @@ -35,35 +40,73 @@ pub fn verify_user(conn: &PgConnection, username: &str, password: &str) -> crate #[serde(rename_all = "camelCase")] pub struct JWTResponse { token: String, - refresh_token: String + refresh_token: String, } pub fn generate_jwt_token(conn: &PgConnection, user: &User) -> crate::Result { // TODO actually use proper secret here - let key: Hmac = Hmac::new_from_slice(b"some-secret").map_err(|_| RBError::JWTCreationError)?; + let key: Hmac = + Hmac::new_from_slice(b"some-secret").map_err(|_| RBError::JWTCreationError)?; // Create the claims let mut claims = HashMap::new(); claims.insert("id", user.id.to_string()); claims.insert("username", user.username.clone()); claims.insert("admin", user.admin.to_string()); - claims.insert("exp", (Utc::now().timestamp() + JWT_EXP_SECONDS).to_string()); + claims.insert( + "exp", + (Utc::now().timestamp() + JWT_EXP_SECONDS).to_string(), + ); // Sign the claims into a new token - let token = claims.sign_with_key(&key).map_err(|_| RBError::JWTCreationError)?; + let token = claims + .sign_with_key(&key) + .map_err(|_| RBError::JWTCreationError)?; // Generate a random refresh token let mut refresh_token = [0u8; REFRESH_TOKEN_N_BYTES]; thread_rng().fill(&mut refresh_token[..]); // Store refresh token in database - insert_into(refresh_tokens::refresh_tokens).values(NewRefreshToken { - token: refresh_token.to_vec(), - user_id: user.id - }).execute(conn).map_err(|_| RBError::JWTCreationError)?; + insert_into(refresh_tokens::refresh_tokens) + .values(NewRefreshToken { + token: refresh_token.to_vec(), + user_id: user.id, + }) + .execute(conn) + .map_err(|_| RBError::JWTCreationError)?; Ok(JWTResponse { - token: token, - refresh_token: base64::encode(refresh_token) + token, + refresh_token: base64::encode(refresh_token), }) } + +pub fn hash_password(password: &str) -> crate::Result { + // Generate a random salt + let mut salt = [0u8; 64]; + thread_rng().fill(&mut salt[..]); + + // Encode the actual password + let config = argon2::Config::default(); + argon2::hash_encoded(password.as_bytes(), &salt, &config).map_err(|_| RBError::PWSaltError) +} + +pub fn create_admin_user(conn: &PgConnection, username: &str, password: &str) -> crate::Result { + let pass_hashed = hash_password(password)?; + println!("{}", pass_hashed); + let new_user = NewUser { + username: username.to_string(), + password: pass_hashed, + admin: true, + }; + + insert_into(users::users) + .values(&new_user) + // .on_conflict((users::username, users::password, users::admin)) + // .do_update() + // .set(&new_user) + .execute(conn).map_err(|_| RBError::AdminCreationError)?; + + Ok(true) +} diff --git a/src/rb/errors.rs b/src/rb/errors.rs index adfe40b..96118e6 100644 --- a/src/rb/errors.rs +++ b/src/rb/errors.rs @@ -1,11 +1,13 @@ -use rocket::request::Request; -use rocket::response::{self, Response, Responder}; use rocket::http::Status; +use rocket::request::Request; +use rocket::response::{self, Responder, Response}; use std::io; +#[derive(Debug)] pub enum RBError { /// When the login requests an unknown user UnknownUser, + BlockedUser, /// Invalid login password. InvalidPassword, /// When a non-admin user tries to use an admin endpoint @@ -13,17 +15,21 @@ pub enum RBError { /// When an expired JWT token is used for auth. JWTTokenExpired, /// Umbrella error for when something goes wrong whilst creating a JWT token pair - JWTCreationError + JWTCreationError, + PWSaltError, + AdminCreationError, } impl<'r> Responder<'r, 'static> for RBError { fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { let (status, message): (Status, &str) = match self { RBError::UnknownUser => (Status::NotFound, "Unknown user"), + RBError::BlockedUser => (Status::Unauthorized, "This user is blocked"), RBError::InvalidPassword => (Status::Unauthorized, "Invalid password"), RBError::Unauthorized => (Status::Unauthorized, "Unauthorized"), RBError::JWTTokenExpired => (Status::Unauthorized, "Token expired"), RBError::JWTCreationError => (Status::InternalServerError, "Failed to create tokens."), + _ => (Status::InternalServerError, "Internal server error") }; let mut res = Response::new(); diff --git a/src/rb/models.rs b/src/rb/models.rs index 1143e9b..858cab4 100644 --- a/src/rb/models.rs +++ b/src/rb/models.rs @@ -1,7 +1,7 @@ -use diesel::{Queryable, Insertable}; -use uuid::Uuid; +use crate::schema::{refresh_tokens, users}; +use diesel::{Insertable, Queryable, AsChangeset}; use serde::Serialize; -use crate::schema::refresh_tokens; +use uuid::Uuid; #[derive(Queryable, Serialize)] pub struct User { @@ -10,14 +10,21 @@ pub struct User { #[serde(skip_serializing)] pub password: String, #[serde(skip_serializing)] - blocked: bool, + pub blocked: bool, pub admin: bool, } +#[derive(Insertable, AsChangeset)] +#[table_name = "users"] +pub struct NewUser { + pub username: String, + pub password: String, + pub admin: bool, +} #[derive(Insertable)] #[table_name = "refresh_tokens"] pub struct NewRefreshToken { pub token: Vec, - pub user_id: Uuid + pub user_id: Uuid, } diff --git a/src/rbs/auth.rs b/src/rbs/auth.rs index 0de053a..39e3b88 100644 --- a/src/rbs/auth.rs +++ b/src/rbs/auth.rs @@ -1,5 +1,5 @@ use crate::RbDbConn; -use rb::auth::{verify_user, JWTResponse, generate_jwt_token}; +use rb::auth::{generate_jwt_token, verify_user, JWTResponse}; use rocket::serde::json::Json; use serde::Deserialize; diff --git a/src/rbs/main.rs b/src/rbs/main.rs index 0acc0a7..1792d9d 100644 --- a/src/rbs/main.rs +++ b/src/rbs/main.rs @@ -28,6 +28,25 @@ async fn run_db_migrations(rocket: Rocket) -> Result, Rocke .await } +async fn create_admin_user(rocket: Rocket) -> Result, Rocket> { + // In debug mode, the admin user is just a test user + let (admin_user, admin_password): (String, String); + + // if rocket.config().profile == "debug" { + admin_user = String::from("test"); + admin_password = String::from("test"); + // }else{ + // admin_user = std::env::var("ADMIN_USER").expect("no admin user provided"); + // admin_password = std::env::var("ADMIN_PASSWORD").expect("no admin password provided"); + // } + let conn = RbDbConn::get_one(&rocket) + .await + .expect("database connection"); + conn.run(move |c| rb::auth::create_admin_user(c, &admin_user, &admin_password).expect("failed to create admin user")).await; + + Ok(rocket) +} + #[launch] fn rocket() -> _ { rocket::build() @@ -36,5 +55,9 @@ fn rocket() -> _ { "Run database migrations", run_db_migrations, )) + .attach(AdHoc::try_on_ignite( + "Create admin user", + create_admin_user + )) .mount("/auth", auth::routes()) }