diff --git a/server/src/db/mod.rs b/server/src/db/mod.rs index 02c4284..1dbaba5 100644 --- a/server/src/db/mod.rs +++ b/server/src/db/mod.rs @@ -2,19 +2,14 @@ mod conn; pub mod entities; mod migrator; -use migrator::Migrator; -use sea_orm::ColumnTrait; -use sea_orm::ConnectOptions; -use sea_orm::Database; -use sea_orm::DatabaseConnection; -use sea_orm::EntityTrait; -use sea_orm::PaginatorTrait; -use sea_orm::QueryFilter; -use sea_orm::QueryOrder; +use sea_orm::{ + ColumnTrait, ConnectOptions, Database, DatabaseConnection, DeleteResult, EntityTrait, + InsertResult, NotSet, PaginatorTrait, QueryFilter, QueryOrder, Set, +}; use sea_orm_migration::MigratorTrait; -pub use entities::prelude::*; -pub use entities::*; +pub use entities::{prelude::*, *}; +use migrator::Migrator; type Result = std::result::Result; @@ -53,6 +48,20 @@ impl RieterDb { .await } + pub async fn insert_repo( + &self, + name: &str, + description: Option<&str>, + ) -> Result> { + let model = repo::ActiveModel { + id: NotSet, + name: Set(String::from(name)), + description: Set(description.map(String::from)), + }; + + Repo::insert(model).exec(&self.conn).await + } + pub async fn packages(&self, per_page: u64, page: u64) -> Result<(u64, Vec)> { let paginator = Package::find() .order_by_asc(package::Column::Id) @@ -66,4 +75,35 @@ impl RieterDb { pub async fn package(&self, id: i32) -> Result> { package::Entity::find_by_id(id).one(&self.conn).await } + + pub async fn package_by_fields( + &self, + repo_id: i32, + name: &str, + version: Option<&str>, + arch: &str, + ) -> Result> { + let mut query = Package::find() + .filter(package::Column::RepoId.eq(repo_id)) + .filter(package::Column::Name.eq(name)) + .filter(package::Column::Arch.eq(arch)); + + if let Some(version) = version { + query = query.filter(package::Column::Version.eq(version)); + } + + query.one(&self.conn).await + } + + pub async fn delete_packages_with_arch( + &self, + repo_id: i32, + arch: &str, + ) -> Result { + Package::delete_many() + .filter(package::Column::RepoId.eq(repo_id)) + .filter(package::Column::Arch.eq(arch)) + .exec(&self.conn) + .await + } } diff --git a/server/src/repo/mod.rs b/server/src/repo/mod.rs index 8ab2b89..f87d572 100644 --- a/server/src/repo/mod.rs +++ b/server/src/repo/mod.rs @@ -14,7 +14,7 @@ use axum::response::IntoResponse; use axum::routing::{delete, post}; use axum::Router; use futures::StreamExt; -use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, ModelTrait, QueryFilter}; +use sea_orm::{ActiveModelTrait, EntityTrait, ModelTrait}; use std::sync::Arc; use tokio::{fs, io::AsyncWriteExt}; use tower::util::ServiceExt; @@ -132,23 +132,13 @@ async fn post_package_archive( let repo_id = if let Some(repo_entity) = res { repo_entity.id } else { - let model = db::repo::ActiveModel { - name: sea_orm::Set(repo.clone()), - ..Default::default() - }; - - db::Repo::insert(model) - .exec(&global.db) - .await? - .last_insert_id + global.db.insert_repo(&repo, None).await?.last_insert_id }; // If the package already exists in the database, we remove it first - let res = db::Package::find() - .filter(db::package::Column::RepoId.eq(repo_id)) - .filter(db::package::Column::Name.eq(&pkg.info.name)) - .filter(db::package::Column::Arch.eq(&pkg.info.arch)) - .one(&global.db) + let res = global + .db + .package_by_fields(repo_id, &pkg.info.name, None, &pkg.info.arch) .await?; if let Some(entry) = res { @@ -226,11 +216,9 @@ async fn delete_arch_repo( let res = global.db.repo_by_name(&repo).await?; if let Some(repo_entry) = res { - // Also remove all packages for that architecture from database - db::Package::delete_many() - .filter(db::package::Column::RepoId.eq(repo_entry.id)) - .filter(db::package::Column::Arch.eq(&arch)) - .exec(&global.db) + global + .db + .delete_packages_with_arch(repo_entry.id, &arch) .await?; } tracing::info!("Removed architecture '{}' from repository '{}'", arch, repo); @@ -257,13 +245,14 @@ async fn delete_package( let res = global.db.repo_by_name(&repo).await?; if let Some(repo_entry) = res { - // Also remove entry from database - let res = db::Package::find() - .filter(db::package::Column::RepoId.eq(repo_entry.id)) - .filter(db::package::Column::Name.eq(name)) - .filter(db::package::Column::Version.eq(format!("{}-{}", version, release))) - .filter(db::package::Column::Arch.eq(arch)) - .one(&global.db) + let res = global + .db + .package_by_fields( + repo_entry.id, + &name, + Some(&format!("{}-{}", version, release)), + &arch, + ) .await?; if let Some(entry) = res {