From 3c4af12fa13c8e37940616b6e4be230f7b3c3989 Mon Sep 17 00:00:00 2001 From: Jef Roosens Date: Mon, 24 Feb 2025 13:24:23 +0100 Subject: [PATCH] feat: implement custom deserializer for path segments with format extension --- src/server/gpodder/advanced/devices.rs | 26 +++++------ src/server/gpodder/format.rs | 65 ++++++++++++++++++++++++++ src/server/gpodder/mod.rs | 1 + 3 files changed, 79 insertions(+), 13 deletions(-) create mode 100644 src/server/gpodder/format.rs diff --git a/src/server/gpodder/advanced/devices.rs b/src/server/gpodder/advanced/devices.rs index 7be52ef..7df6ec1 100644 --- a/src/server/gpodder/advanced/devices.rs +++ b/src/server/gpodder/advanced/devices.rs @@ -6,11 +6,12 @@ use axum::{ }; use crate::{ - db::{self, User}, + db, server::{ error::{AppError, AppResult}, gpodder::{ auth_middleware, + format::{Format, StringWithFormat}, models::{Device, DevicePatch, DeviceType}, }, Context, @@ -26,14 +27,14 @@ pub fn router(ctx: Context) -> Router { async fn get_devices( State(ctx): State, - Path(username): Path, - Extension(user): Extension, + Path(username): Path, + Extension(user): Extension, ) -> AppResult>> { - // Check suffix is present and return 404 otherwise; axum doesn't support matching part of a - // route segment - let username = username.strip_suffix(".json").ok_or(AppError::NotFound)?; + if username.format != Format::Json { + return Err(AppError::NotFound); + } - if username != user.username { + if *username != user.username { return Err(AppError::BadRequest); } @@ -55,14 +56,13 @@ async fn get_devices( async fn post_device( State(ctx): State, - Path((_username, id)): Path<(String, String)>, - Extension(user): Extension, + Path((_username, id)): Path<(String, StringWithFormat)>, + Extension(user): Extension, Json(patch): Json, ) -> AppResult<()> { - let id = id - .strip_suffix(".json") - .ok_or(AppError::NotFound)? - .to_string(); + if id.format != Format::Json { + return Err(AppError::NotFound); + } tokio::task::spawn_blocking(move || { if let Some(mut device) = db::Device::by_device_id(&ctx.pool, user.id, &id)? { diff --git a/src/server/gpodder/format.rs b/src/server/gpodder/format.rs new file mode 100644 index 0000000..4c2a5de --- /dev/null +++ b/src/server/gpodder/format.rs @@ -0,0 +1,65 @@ +use std::ops::Deref; + +use serde::{ + de::{value::StrDeserializer, Visitor}, + Deserialize, +}; + +#[derive(Deserialize, Debug, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum Format { + Json, + OPML, + Plaintext, +} + +#[derive(Debug)] +pub struct StringWithFormat { + pub s: String, + pub format: Format, +} + +impl Deref for StringWithFormat { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.s + } +} + +impl<'de> Deserialize<'de> for StringWithFormat { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct StrVisitor; + + impl<'de> Visitor<'de> for StrVisitor { + type Value = StringWithFormat; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str( + "`text.ext` format, with `ext` being one of `json`, `opml` or `plaintext`", + ) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + if let Some((text, ext)) = v.rsplit_once('.') { + let format = Format::deserialize(StrDeserializer::new(ext))?; + + Ok(StringWithFormat { + s: text.to_string(), + format, + }) + } else { + Err(E::custom(format!("invalid format '{}'", v))) + } + } + } + + deserializer.deserialize_str(StrVisitor) + } +} diff --git a/src/server/gpodder/mod.rs b/src/server/gpodder/mod.rs index 7b2ba67..00f1d6b 100644 --- a/src/server/gpodder/mod.rs +++ b/src/server/gpodder/mod.rs @@ -1,4 +1,5 @@ mod advanced; +mod format; mod models; mod simple;