From 3b33cba0d4ce6861fac098e834f698845776240c Mon Sep 17 00:00:00 2001 From: Chewing_Bever Date: Tue, 25 Jul 2023 12:45:29 +0200 Subject: [PATCH] feat: add some proper error handling --- rust-toolchain.toml | 2 -- src/api/deploy.rs | 39 ++++++++++++++++------------------ src/error.rs | 51 +++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 11 ++++++---- 4 files changed, 76 insertions(+), 27 deletions(-) delete mode 100644 rust-toolchain.toml create mode 100644 src/error.rs diff --git a/rust-toolchain.toml b/rust-toolchain.toml deleted file mode 100644 index 8f357ce..0000000 --- a/rust-toolchain.toml +++ /dev/null @@ -1,2 +0,0 @@ -[toolchain] -channel = "1.69" diff --git a/src/api/deploy.rs b/src/api/deploy.rs index aa0f28d..c9131db 100644 --- a/src/api/deploy.rs +++ b/src/api/deploy.rs @@ -8,6 +8,7 @@ use axum::{ use flate2::read::GzDecoder; use futures_util::TryStreamExt; use serde::Deserialize; +use std::io; use tar::Archive; use tokio_util::io::StreamReader; @@ -22,15 +23,15 @@ pub async fn post_deploy( Extension(data_dir): Extension, Query(params): Query, res: BodyStream, -) -> impl IntoResponse { +) -> crate::Result<()> { // This converts a stream into something that implements AsyncRead, which we can then use to // asynchronously write the file to disk let mut read = StreamReader::new(res.map_err(|axum_err| std::io::Error::new(ErrorKind::Other, axum_err))); let uuid = uuid::Uuid::new_v4(); let file_path = Path::new(&data_dir).join(uuid.as_hyphenated().to_string()); - let mut file = tokio::fs::File::create(&file_path).await.unwrap(); - tokio::io::copy(&mut read, &mut file).await; + let mut file = tokio::fs::File::create(&file_path).await?; + tokio::io::copy(&mut read, &mut file).await?; // If no dir is provided, we use the default one. Otherwise, use the provided one. let static_path = Path::new(&data_dir) @@ -38,34 +39,30 @@ pub async fn post_deploy( .join(params.dir.unwrap_or(DEFAULT_STATIC_SITE.to_string())); // Make sure the static directory exists - tokio::fs::create_dir_all(&static_path).await; + tokio::fs::create_dir_all(&static_path).await?; let fp_clone = file_path.clone(); // Extract the contents of the tarball synchronously - let res = - match tokio::task::spawn_blocking(move || process_archive(&fp_clone, &static_path)).await { - Ok(_) => StatusCode::OK, - Err(_) => StatusCode::INTERNAL_SERVER_ERROR, - }; + tokio::task::spawn_blocking(move || process_archive(&fp_clone, &static_path)).await??; // Remove archive file after use - tokio::fs::remove_file(&file_path).await; + tokio::fs::remove_file(&file_path).await?; - res + Ok(()) } -fn process_archive(archive_path: &Path, static_dir: &Path) -> Result<(), ()> { - let file = std::fs::File::open(archive_path).map_err(|_| ())?; +fn process_archive(archive_path: &Path, static_dir: &Path) -> io::Result<()> { + let file = std::fs::File::open(archive_path)?; let tar = GzDecoder::new(file); let mut archive = Archive::new(tar); let mut paths = HashSet::new(); - let entries = archive.entries().map_err(|_| ())?; + let entries = archive.entries()?; // Extract each entry into the output directory - for entry_res in entries { - let mut entry = entry_res.map_err(|_| ())?; - entry.unpack_in(static_dir).map_err(|_| ())?; + for entry in entries { + let mut entry = entry?; + entry.unpack_in(static_dir)?; if let Ok(path) = entry.path() { paths.insert(path.into_owned()); @@ -76,20 +73,20 @@ fn process_archive(archive_path: &Path, static_dir: &Path) -> Result<(), ()> { let mut items = vec![]; // Start by populating the vec with the initial files - let iter = static_dir.read_dir().map_err(|_| ())?; + let iter = static_dir.read_dir()?; iter.filter_map(|r| r.ok()) .for_each(|e| items.push(e.path())); // As long as there are still items in the vec, we keep going - while items.len() > 0 { + while !items.is_empty() { let item = items.pop().unwrap(); tracing::debug!("{:?}", item); if !paths.contains(item.strip_prefix(&static_dir).unwrap()) { if item.is_dir() { - std::fs::remove_dir_all(item); + std::fs::remove_dir_all(item)?; } else { - std::fs::remove_file(item); + std::fs::remove_file(item)?; } } else if let Ok(iter) = item.read_dir() { iter.filter_map(|r| r.ok()) diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..fa58962 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,51 @@ +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use std::error::Error; +use std::fmt; +use std::io; + +pub type Result = std::result::Result; + +#[derive(Debug)] +pub enum ServerError { + IO(io::Error), + Axum(axum::Error), +} + +impl fmt::Display for ServerError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ServerError::IO(err) => write!(fmt, "{}", err), + ServerError::Axum(err) => write!(fmt, "{}", err), + } + } +} + +impl Error for ServerError {} + +impl IntoResponse for ServerError { + fn into_response(self) -> Response { + match self { + ServerError::IO(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + ServerError::Axum(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + } + } +} + +impl From for ServerError { + fn from(err: io::Error) -> Self { + ServerError::IO(err) + } +} + +impl From for ServerError { + fn from(err: axum::Error) -> Self { + ServerError::Axum(err) + } +} + +impl From for ServerError { + fn from(err: tokio::task::JoinError) -> Self { + ServerError::IO(err.into()) + } +} diff --git a/src/main.rs b/src/main.rs index fdb0600..50c74bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,10 @@ +mod api; +mod error; +mod matrix; +mod metrics; + +pub use error::Result; + use std::{future::ready, net::SocketAddr}; use axum::{ @@ -12,10 +19,6 @@ use tower_http::{ }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -mod api; -mod matrix; -mod metrics; - /// Name of the directory where static sites are stored inside the data directory const STATIC_DIR_NAME: &str = "static"; /// Name of the subdir of STATIC_DIR_NAME where the default (fallback) site is located