diff --git a/src/db/repository/mod.rs b/src/db/repository/mod.rs index dd53e86..f890c88 100644 --- a/src/db/repository/mod.rs +++ b/src/db/repository/mod.rs @@ -1,5 +1,6 @@ mod auth; mod device; +mod subscription; use super::DbPool; diff --git a/src/db/repository/subscription.rs b/src/db/repository/subscription.rs new file mode 100644 index 0000000..82d9e39 --- /dev/null +++ b/src/db/repository/subscription.rs @@ -0,0 +1,258 @@ +use std::collections::HashSet; + +use diesel::prelude::*; + +use super::SqliteRepository; +use crate::{ + db::{self, schema::*}, + gpodder, +}; + +impl gpodder::SubscriptionRepository for SqliteRepository { + fn subscriptions_for_user( + &self, + user: &gpodder::User, + ) -> Result, gpodder::AuthErr> { + Ok(subscriptions::table + .inner_join(devices::table) + .filter(devices::user_id.eq(user.id)) + .select(subscriptions::url) + .distinct() + .get_results(&mut self.pool.get()?)?) + } + + fn subscriptions_for_device( + &self, + user: &gpodder::User, + device_id: &str, + ) -> Result, gpodder::AuthErr> { + Ok(subscriptions::table + .inner_join(devices::table) + .filter( + devices::user_id + .eq(user.id) + .and(devices::device_id.eq(device_id)), + ) + .select(subscriptions::url) + .get_results(&mut self.pool.get()?)?) + } + + fn set_subscriptions_for_device( + &self, + user: &gpodder::User, + device_id: &str, + urls: Vec, + ) -> Result { + // TODO use a better timestamp + let timestamp = chrono::Utc::now().timestamp_millis(); + + self.pool.get()?.transaction(|conn| { + let device = devices::table + .select(db::Device::as_select()) + .filter( + devices::user_id + .eq(user.id) + .and(devices::device_id.eq(device_id)), + ) + .get_result(conn)?; + + // https://github.com/diesel-rs/diesel/discussions/2826 + // SQLite doesn't support default on conflict set values, so we can't handle this using + // on conflict. Therefore, we instead calculate which URLs should be inserted and which + // updated, so we avoid conflicts. + let urls: HashSet = urls.into_iter().collect(); + let urls_in_db: HashSet = subscriptions::table + .select(subscriptions::url) + .filter(subscriptions::device_id.eq(device.id)) + .get_results(conn)? + .into_iter() + .collect(); + + // URLs originally in the database that are no longer in the list + let urls_to_delete = urls_in_db.difference(&urls); + + // URLs not in the database that are in the new list + let urls_to_insert = urls.difference(&urls_in_db); + + // URLs that are in both the database and the new list. For these, those marked as + // "deleted" in the database are updated so they're no longer deleted, with their + // timestamp updated. + let urls_to_update = urls.intersection(&urls_in_db); + + // Mark the URLs to delete as properly deleted + diesel::update( + subscriptions::table.filter( + subscriptions::device_id + .eq(device.id) + .and(subscriptions::url.eq_any(urls_to_delete)), + ), + ) + .set(( + subscriptions::deleted.eq(true), + subscriptions::time_changed.eq(timestamp), + )) + .execute(conn)?; + + // Update the existing deleted URLs that are reinserted as no longer deleted + diesel::update( + subscriptions::table.filter( + subscriptions::device_id + .eq(device.id) + .and(subscriptions::url.eq_any(urls_to_update)) + .and(subscriptions::deleted.eq(true)), + ), + ) + .set(( + subscriptions::deleted.eq(false), + subscriptions::time_changed.eq(timestamp), + )) + .execute(conn)?; + + // Insert the new values into the database + diesel::insert_into(subscriptions::table) + .values( + urls_to_insert + .into_iter() + .map(|url| db::NewSubscription { + device_id: device.id, + url: url.to_string(), + deleted: false, + time_changed: timestamp, + }) + .collect::>(), + ) + .execute(conn)?; + + Ok::<_, diesel::result::Error>(()) + })?; + + Ok(timestamp + 1) + } + + fn update_subscriptions_for_device( + &self, + user: &gpodder::User, + device_id: &str, + add: Vec, + remove: Vec, + ) -> Result { + // TODO use a better timestamp + let timestamp = chrono::Utc::now().timestamp_millis(); + + // TODO URLs that are in both the added and removed lists will currently get "re-added", + // meaning their change timestamp will be updated even though they haven't really changed. + let add: HashSet<_> = add.into_iter().collect(); + let remove: HashSet<_> = remove.into_iter().collect(); + + self.pool.get()?.transaction(|conn| { + let device = devices::table + .select(db::Device::as_select()) + .filter( + devices::user_id + .eq(user.id) + .and(devices::device_id.eq(device_id)), + ) + .get_result(conn)?; + + let urls_in_db: HashSet = subscriptions::table + .select(subscriptions::url) + .filter(subscriptions::device_id.eq(device.id)) + .get_results(conn)? + .into_iter() + .collect(); + + // Subscriptions to remove are those that were already in the database and are now part + // of the removed list. Subscriptions that were never added in the first place don't + // need to be marked as deleted. We also only update those that aren't already marked + // as deleted. + let urls_to_delete = remove.intersection(&urls_in_db); + + diesel::update( + subscriptions::table.filter( + subscriptions::device_id + .eq(device.id) + .and(subscriptions::url.eq_any(urls_to_delete)) + .and(subscriptions::deleted.eq(false)), + ), + ) + .set(( + subscriptions::deleted.eq(true), + subscriptions::time_changed.eq(timestamp), + )) + .execute(conn)?; + + // Subscriptions to update are those that are already in the database, but are also in + // the added list. Only those who were originally marked as deleted get updated. + let urls_to_update = add.intersection(&urls_in_db); + + diesel::update( + subscriptions::table.filter( + subscriptions::device_id + .eq(device.id) + .and(subscriptions::url.eq_any(urls_to_update)) + .and(subscriptions::deleted.eq(true)), + ), + ) + .set(( + subscriptions::deleted.eq(false), + subscriptions::time_changed.eq(timestamp), + )) + .execute(conn)?; + + // Subscriptions to insert are those that aren't in the database and are part of the + // added list + let urls_to_insert = add.difference(&urls_in_db); + + diesel::insert_into(subscriptions::table) + .values( + urls_to_insert + .into_iter() + .map(|url| db::NewSubscription { + device_id: device.id, + url: url.to_string(), + deleted: false, + time_changed: timestamp, + }) + .collect::>(), + ) + .execute(conn)?; + + Ok::<_, diesel::result::Error>(()) + })?; + + Ok(timestamp + 1) + } + + fn subscription_updates_for_device( + &self, + user: &gpodder::User, + device_id: &str, + since: i64, + ) -> Result<(i64, Vec, Vec), gpodder::AuthErr> { + let (mut timestamp, mut added, mut removed) = (0, Vec::new(), Vec::new()); + + let query = subscriptions::table + .inner_join(devices::table) + .filter( + devices::user_id + .eq(user.id) + .and(devices::device_id.eq(device_id)) + .and(subscriptions::time_changed.ge(since)), + ) + .select(db::Subscription::as_select()); + + for sub in query.load_iter(&mut self.pool.get()?)? { + let sub = sub?; + + if sub.deleted { + removed.push(sub.url); + } else { + added.push(sub.url); + } + + timestamp = timestamp.max(sub.time_changed); + } + + Ok((timestamp + 1, added, removed)) + } +} diff --git a/src/gpodder/mod.rs b/src/gpodder/mod.rs index 205bb69..9a887e7 100644 --- a/src/gpodder/mod.rs +++ b/src/gpodder/mod.rs @@ -30,7 +30,44 @@ pub trait DeviceRepository: Send + Sync { fn update_device_info( &self, user: &User, - device: &str, + device_id: &str, patch: DevicePatch, ) -> Result<(), AuthErr>; } + +pub trait SubscriptionRepository: Send + Sync { + /// Return the subscriptions for the given device + fn subscriptions_for_device( + &self, + user: &User, + device_id: &str, + ) -> Result, AuthErr>; + + /// Return all subscriptions for a given user + fn subscriptions_for_user(&self, user: &User) -> Result, AuthErr>; + + /// Replace the list of subscriptions for a device with the given list + fn set_subscriptions_for_device( + &self, + user: &User, + device_id: &str, + urls: Vec, + ) -> Result; + + /// Update the list of subscriptions for a device by adding and removing the given URLs + fn update_subscriptions_for_device( + &self, + user: &User, + device_id: &str, + add: Vec, + remove: Vec, + ) -> Result; + + /// Returns the changes in subscriptions since the given timestamp. + fn subscription_updates_for_device( + &self, + user: &User, + device_id: &str, + since: i64, + ) -> Result<(i64, Vec, Vec), AuthErr>; +} diff --git a/src/server/gpodder/advanced/subscriptions.rs b/src/server/gpodder/advanced/subscriptions.rs index 8969c23..6563ebe 100644 --- a/src/server/gpodder/advanced/subscriptions.rs +++ b/src/server/gpodder/advanced/subscriptions.rs @@ -7,16 +7,13 @@ use axum::{ use serde::Deserialize; use crate::{ - db, + gpodder::{self, SubscriptionRepository}, server::{ error::{AppError, AppResult}, gpodder::{ auth_middleware, format::{Format, StringWithFormat}, - models::{ - DeviceType, SubscriptionChangeResponse, SubscriptionDelta, - SubscriptionDeltaResponse, - }, + models::{SubscriptionChangeResponse, SubscriptionDelta, SubscriptionDeltaResponse}, }, Context, }, @@ -34,7 +31,7 @@ pub fn router(ctx: Context) -> Router { pub async fn post_subscription_changes( State(ctx): State, Path((username, id)): Path<(String, StringWithFormat)>, - Extension(user): Extension, + Extension(user): Extension, Json(delta): Json, ) -> AppResult> { if id.format != Format::Json { @@ -45,37 +42,18 @@ pub async fn post_subscription_changes( return Err(AppError::BadRequest); } - let timestamp = chrono::Utc::now().timestamp_millis(); - - tokio::task::spawn_blocking(move || { - let device = if let Some(device) = db::Device::by_device_id(&ctx.pool, user.id, &id)? { - device - } else { - db::NewDevice::new( - user.id, - id.to_string(), - String::new(), - DeviceType::Other.into(), - ) - .insert(&ctx.pool)? - }; - - db::Subscription::update_for_device( - &ctx.pool, - device.id, - delta.add, - delta.remove, - timestamp, - ) + Ok(tokio::task::spawn_blocking(move || { + ctx.repo + .update_subscriptions_for_device(&user, &id, delta.add, delta.remove) }) .await - .unwrap()?; - - Ok(Json(SubscriptionChangeResponse { - timestamp: timestamp + 1, - // TODO implement URL sanitization - update_urls: vec![], - })) + .unwrap() + .map(|timestamp| { + Json(SubscriptionChangeResponse { + timestamp, + update_urls: Vec::new(), + }) + })?) } #[derive(Deserialize)] @@ -87,7 +65,7 @@ pub struct SinceQuery { pub async fn get_subscription_changes( State(ctx): State, Path((username, id)): Path<(String, StringWithFormat)>, - Extension(user): Extension, + Extension(user): Extension, Query(query): Query, ) -> AppResult> { if id.format != Format::Json { @@ -98,34 +76,17 @@ pub async fn get_subscription_changes( return Err(AppError::BadRequest); } - let subscriptions = tokio::task::spawn_blocking(move || { - let device = - db::Device::by_device_id(&ctx.pool, user.id, &id)?.ok_or(AppError::NotFound)?; - - Ok::<_, AppError>(db::Subscription::updated_since_for_device( - &ctx.pool, - device.id, - query.since, - )?) + Ok(tokio::task::spawn_blocking(move || { + ctx.repo + .subscription_updates_for_device(&user, &id, query.since) }) .await - .unwrap()?; - - let mut delta = SubscriptionDeltaResponse::default(); - delta.timestamp = query.since; - - for sub in subscriptions.into_iter() { - if sub.deleted { - delta.remove.push(sub.url); - } else { - delta.add.push(sub.url); - } - - delta.timestamp = delta.timestamp.max(sub.time_changed); - } - - // Timestamp should reflect the events *after* the last seen change - delta.timestamp += 1; - - Ok(Json(delta)) + .unwrap() + .map(|(timestamp, add, remove)| { + Json(SubscriptionDeltaResponse { + add, + remove, + timestamp, + }) + })?) } diff --git a/src/server/gpodder/simple/subscriptions.rs b/src/server/gpodder/simple/subscriptions.rs index 47ed072..4e6f7c6 100644 --- a/src/server/gpodder/simple/subscriptions.rs +++ b/src/server/gpodder/simple/subscriptions.rs @@ -2,14 +2,14 @@ use axum::{ extract::{Path, State}, middleware, routing::get, - Extension, Form, Json, Router, + Extension, Json, Router, }; use crate::{ - db, + gpodder::{self, SubscriptionRepository}, server::{ error::{AppError, AppResult}, - gpodder::{auth_middleware, format::StringWithFormat, models::DeviceType}, + gpodder::{auth_middleware, format::StringWithFormat}, Context, }, }; @@ -27,73 +27,53 @@ pub fn router(ctx: Context) -> Router { pub async fn get_device_subscriptions( State(ctx): State, Path((username, id)): Path<(String, StringWithFormat)>, - Extension(user): Extension, + Extension(user): Extension, ) -> AppResult>> { if username != user.username { return Err(AppError::BadRequest); } - let subscriptions = tokio::task::spawn_blocking(move || { - let device = - db::Device::by_device_id(&ctx.pool, user.id, &id)?.ok_or(AppError::NotFound)?; - - Ok::<_, AppError>(db::Subscription::for_device(&ctx.pool, device.id)?) - }) - .await - .unwrap()?; - - Ok(Json(subscriptions)) + Ok( + tokio::task::spawn_blocking(move || ctx.repo.subscriptions_for_device(&user, &id)) + .await + .unwrap() + .map(Json)?, + ) } pub async fn get_user_subscriptions( State(ctx): State, Path(username): Path, - Extension(user): Extension, + Extension(user): Extension, ) -> AppResult>> { if *username != user.username { return Err(AppError::BadRequest); } - let subscriptions = - tokio::task::spawn_blocking(move || db::Subscription::for_user(&ctx.pool, user.id)) + Ok( + tokio::task::spawn_blocking(move || ctx.repo.subscriptions_for_user(&user)) .await - .unwrap()?; - - Ok(Json(subscriptions)) + .unwrap() + .map(Json)?, + ) } pub async fn put_device_subscriptions( State(ctx): State, Path((username, id)): Path<(String, StringWithFormat)>, - Extension(user): Extension, + Extension(user): Extension, Json(urls): Json>, ) -> AppResult<()> { if *username != user.username { return Err(AppError::BadRequest); } - tokio::task::spawn_blocking(move || { - let device = if let Some(device) = db::Device::by_device_id(&ctx.pool, user.id, &id)? { - device - } else { - db::NewDevice::new( - user.id, - id.to_string(), - String::new(), - DeviceType::Other.into(), - ) - .insert(&ctx.pool)? - }; - - Ok::<_, AppError>(db::Subscription::set_for_device( - &ctx.pool, - device.id, - urls, - chrono::Utc::now().timestamp(), - )?) - }) - .await - .unwrap()?; - - Ok(()) + Ok( + tokio::task::spawn_blocking(move || { + ctx.repo.set_subscriptions_for_device(&user, &id, urls) + }) + .await + .unwrap() + .map(|_| ())?, + ) }