diff --git a/common/src/db/mod.rs b/common/src/db/mod.rs index f86e5387c..c125a386e 100644 --- a/common/src/db/mod.rs +++ b/common/src/db/mod.rs @@ -11,97 +11,13 @@ pub use func::*; use anyhow::{ensure, Context}; use migration::{Migrator, MigratorTrait}; use sea_orm::{ - prelude::async_trait, ConnectOptions, ConnectionTrait, DatabaseConnection, DatabaseTransaction, - DbBackend, DbErr, ExecResult, QueryResult, RuntimeErr, Statement, + prelude::async_trait, ConnectOptions, ConnectionTrait, DatabaseConnection, DbBackend, DbErr, + ExecResult, QueryResult, RuntimeErr, Statement, }; use sqlx::error::ErrorKind; use std::ops::{Deref, DerefMut}; use tracing::instrument; -pub enum Transactional { - None, - Some(DatabaseTransaction), -} - -impl Transactional { - /// Commit the database transaction. - /// - /// If there's no underlying database transaction, then this becomes a no-op. - #[instrument(skip_all, fields(transactional=matches!(self, Transactional::Some(_))), ret)] - pub async fn commit(self) -> Result<(), DbErr> { - match self { - Transactional::None => {} - Transactional::Some(inner) => { - inner.commit().await?; - } - } - - Ok(()) - } -} - -impl AsRef for Transactional { - fn as_ref(&self) -> &Transactional { - self - } -} - -impl AsRef for () { - fn as_ref(&self) -> &Transactional { - &Transactional::None - } -} - -#[derive(Clone)] -pub enum ConnectionOrTransaction<'db> { - Connection(&'db DatabaseConnection), - Transaction(&'db DatabaseTransaction), -} - -impl<'db> From<&'db DatabaseTransaction> for ConnectionOrTransaction<'db> { - fn from(value: &'db DatabaseTransaction) -> Self { - Self::Transaction(value) - } -} - -#[async_trait::async_trait] -impl ConnectionTrait for ConnectionOrTransaction<'_> { - fn get_database_backend(&self) -> DbBackend { - match self { - ConnectionOrTransaction::Connection(inner) => inner.get_database_backend(), - ConnectionOrTransaction::Transaction(inner) => inner.get_database_backend(), - } - } - - async fn execute(&self, stmt: Statement) -> Result { - match self { - ConnectionOrTransaction::Connection(inner) => inner.execute(stmt).await, - ConnectionOrTransaction::Transaction(inner) => inner.execute(stmt).await, - } - } - - async fn execute_unprepared(&self, sql: &str) -> Result { - match self { - ConnectionOrTransaction::Connection(inner) => inner.execute_unprepared(sql).await, - ConnectionOrTransaction::Transaction(inner) => inner.execute_unprepared(sql).await, - } - } - - async fn query_one(&self, stmt: Statement) -> Result, DbErr> { - match self { - ConnectionOrTransaction::Connection(inner) => inner.query_one(stmt).await, - ConnectionOrTransaction::Transaction(inner) => inner.query_one(stmt).await, - } - } - - async fn query_all(&self, stmt: Statement) -> Result, DbErr> { - match self { - ConnectionOrTransaction::Connection(inner) => inner.query_all(stmt).await, - ConnectionOrTransaction::Transaction(inner) => inner.query_all(stmt).await, - } - } -} - #[derive(Clone, Debug)] pub struct Database { /// the database connection @@ -111,16 +27,6 @@ pub struct Database { } impl Database { - pub fn connection<'db, TX: AsRef>( - &'db self, - tx: &'db TX, - ) -> ConnectionOrTransaction<'db> { - match tx.as_ref() { - Transactional::None => ConnectionOrTransaction::Connection(&self.db), - Transactional::Some(tx) => ConnectionOrTransaction::Transaction(tx), - } - } - #[instrument(err)] pub async fn new(database: &crate::config::Database) -> Result { let url = database.to_url(); @@ -260,6 +166,37 @@ impl ConnectionTrait for Database { } } +/// Implementation of the connection trait for our database struct. +/// +/// **NOTE**: We lack the implementations for the `mock` feature. However, the mock feature would +/// require us to have the `Database` struct to be non-clone, which we don't support anyway. +#[async_trait::async_trait] +impl ConnectionTrait for &Database { + fn get_database_backend(&self) -> DbBackend { + self.db.get_database_backend() + } + + async fn execute(&self, stmt: Statement) -> Result { + self.db.execute(stmt).await + } + + async fn execute_unprepared(&self, sql: &str) -> Result { + self.db.execute_unprepared(sql).await + } + + async fn query_one(&self, stmt: Statement) -> Result, DbErr> { + self.db.query_one(stmt).await + } + + async fn query_all(&self, stmt: Statement) -> Result, DbErr> { + self.db.query_all(stmt).await + } + + fn support_returning(&self) -> bool { + self.db.support_returning() + } +} + /// A trait to help working with database errors pub trait DatabaseErrors { /// return `true` if the error is a duplicate key error diff --git a/entity/src/advisory.rs b/entity/src/advisory.rs index 5b5e03ee5..a09a454a2 100644 --- a/entity/src/advisory.rs +++ b/entity/src/advisory.rs @@ -38,7 +38,7 @@ impl Model { let db = ctx.data::>()?; if let Some(found) = self .find_related(organization::Entity) - .one(&db.connection(&db::Transactional::None)) + .one(db.as_ref()) .await? { Ok(found) @@ -51,7 +51,7 @@ impl Model { let db = ctx.data::>()?; Ok(self .find_related(vulnerability::Entity) - .all(&db.connection(&db::Transactional::None)) + .all(db.as_ref()) .await?) } } diff --git a/modules/analysis/src/endpoints.rs b/modules/analysis/src/endpoints.rs index aa93110b7..70a0544cb 100644 --- a/modules/analysis/src/endpoints.rs +++ b/modules/analysis/src/endpoints.rs @@ -13,10 +13,11 @@ use trustify_auth::{ use trustify_common::{db::query::Query, db::Database, model::Paginated, purl::Purl}; pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, db: Database) { - let analysis = AnalysisService::new(db); + let analysis = AnalysisService::new(); config .app_data(web::Data::new(analysis)) + .app_data(web::Data::new(db)) .service(search_component_root_components) .service(get_component_root_components) .service(analysis_status) @@ -34,12 +35,13 @@ pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, d #[get("/v1/analysis/status")] pub async fn analysis_status( service: web::Data, + db: web::Data, user: UserInformation, authorizer: web::Data, _: Require, ) -> actix_web::Result { authorizer.require(&user, Permission::ReadSbom)?; - Ok(HttpResponse::Ok().json(service.status(()).await?)) + Ok(HttpResponse::Ok().json(service.status(db.as_ref()).await?)) } #[utoipa::path( @@ -56,13 +58,14 @@ pub async fn analysis_status( #[get("/v1/analysis/root-component")] pub async fn search_component_root_components( service: web::Data, + db: web::Data, web::Query(search): web::Query, web::Query(paginated): web::Query, _: Require, ) -> actix_web::Result { Ok(HttpResponse::Ok().json( service - .retrieve_root_components(search, paginated, ()) + .retrieve_root_components(search, paginated, db.as_ref()) .await?, )) } @@ -80,6 +83,7 @@ pub async fn search_component_root_components( #[get("/v1/analysis/root-component/{key}")] pub async fn get_component_root_components( service: web::Data, + db: web::Data, key: web::Path, web::Query(paginated): web::Query, _: Require, @@ -88,13 +92,13 @@ pub async fn get_component_root_components( let purl: Purl = Purl::from_str(&key).map_err(Error::Purl)?; Ok(HttpResponse::Ok().json( service - .retrieve_root_components_by_purl(purl, paginated, ()) + .retrieve_root_components_by_purl(purl, paginated, db.as_ref()) .await?, )) } else { Ok(HttpResponse::Ok().json( service - .retrieve_root_components_by_name(key.to_string(), paginated, ()) + .retrieve_root_components_by_name(key.to_string(), paginated, db.as_ref()) .await?, )) } @@ -114,11 +118,16 @@ pub async fn get_component_root_components( #[get("/v1/analysis/dep")] pub async fn search_component_deps( service: web::Data, + db: web::Data, web::Query(search): web::Query, web::Query(paginated): web::Query, _: Require, ) -> actix_web::Result { - Ok(HttpResponse::Ok().json(service.retrieve_deps(search, paginated, ()).await?)) + Ok(HttpResponse::Ok().json( + service + .retrieve_deps(search, paginated, db.as_ref()) + .await?, + )) } #[utoipa::path( @@ -134,17 +143,22 @@ pub async fn search_component_deps( #[get("/v1/analysis/dep/{key}")] pub async fn get_component_deps( service: web::Data, + db: web::Data, key: web::Path, web::Query(paginated): web::Query, _: Require, ) -> actix_web::Result { if key.starts_with("pkg:") { let purl: Purl = Purl::from_str(&key).map_err(Error::Purl)?; - Ok(HttpResponse::Ok().json(service.retrieve_deps_by_purl(purl, paginated, ()).await?)) + Ok(HttpResponse::Ok().json( + service + .retrieve_deps_by_purl(purl, paginated, db.as_ref()) + .await?, + )) } else { Ok(HttpResponse::Ok().json( service - .retrieve_deps_by_name(key.to_string(), paginated, ()) + .retrieve_deps_by_name(key.to_string(), paginated, db.as_ref()) .await?, )) } diff --git a/modules/analysis/src/service.rs b/modules/analysis/src/service.rs index 3c500334a..90268e930 100644 --- a/modules/analysis/src/service.rs +++ b/modules/analysis/src/service.rs @@ -1,47 +1,42 @@ -// use crate::Error; +use crate::{ + model::{AnalysisStatus, AncNode, AncestorSummary, DepNode, DepSummary, GraphMap, PackageNode}, + Error, +}; +use petgraph::{ + algo::is_cyclic_directed, + graph::{Graph, NodeIndex}, + visit::{NodeIndexable, VisitMap, Visitable}, + Direction, +}; use sea_orm::{ prelude::ConnectionTrait, ColumnTrait, DatabaseBackend, DbErr, EntityOrSelect, EntityTrait, QueryFilter, QueryOrder, QueryResult, QuerySelect, QueryTrait, Statement, }; -use std::collections::{HashMap, HashSet}; +use sea_query::Order; +use std::{ + collections::{HashMap, HashSet}, + str::FromStr, +}; use tracing::instrument; use trustify_common::{ - db::{ - query::{Query, Value}, - Database, Transactional, - }, + db::query::{Filtering, Query, Value}, model::{Paginated, PaginatedResults}, + purl::Purl, }; - -use crate::model::{ - AnalysisStatus, AncNode, AncestorSummary, DepNode, DepSummary, GraphMap, PackageNode, -}; -use crate::Error; -use petgraph::algo::is_cyclic_directed; -use petgraph::graph::{Graph, NodeIndex}; -use petgraph::visit::{NodeIndexable, VisitMap, Visitable}; -use petgraph::Direction; -use sea_query::Order; -use std::str::FromStr; -use trustify_common::db::query::Filtering; -use trustify_common::db::ConnectionOrTransaction; -use trustify_common::purl::Purl; -use trustify_entity::relationship::Relationship; -use trustify_entity::{sbom, sbom_node}; +use trustify_entity::{relationship::Relationship, sbom, sbom_node}; use uuid::Uuid; -pub struct AnalysisService { - db: Database, -} +#[derive(Default)] +pub struct AnalysisService {} pub fn dep_nodes( - graph: &petgraph::Graph, + graph: &Graph, node: NodeIndex, visited: &mut HashSet, ) -> Vec { let mut depnodes = Vec::new(); fn dfs( - graph: &petgraph::Graph, + graph: &Graph, node: NodeIndex, depnodes: &mut Vec, visited: &mut HashSet, @@ -82,7 +77,7 @@ pub fn dep_nodes( } pub fn ancestor_nodes( - graph: &petgraph::Graph, + graph: &Graph, node: NodeIndex, ) -> Vec { let mut discovered = graph.visit_map(); @@ -131,8 +126,8 @@ pub fn ancestor_nodes( ancestor_nodes } -pub async fn get_implicit_relationships( - connection: &ConnectionOrTransaction<'_>, +pub async fn get_implicit_relationships( + connection: &C, distinct_sbom_id: &str, ) -> Result, DbErr> { let sql = r#" @@ -176,8 +171,8 @@ pub async fn get_implicit_relationships( Ok(results) } -pub async fn get_relationships( - connection: &ConnectionOrTransaction<'_>, +pub async fn get_relationships( + connection: &C, distinct_sbom_id: &str, ) -> Result, DbErr> { // Retrieve all SBOM components that have defined relationships @@ -232,10 +227,7 @@ pub async fn get_relationships( Ok(results) } -pub async fn load_graphs( - connection: &ConnectionOrTransaction<'_>, - distinct_sbom_ids: &Vec, -) { +pub async fn load_graphs(connection: &C, distinct_sbom_ids: &Vec) { let graph_map = GraphMap::get_instance(); { for distinct_sbom_id in distinct_sbom_ids { @@ -420,53 +412,53 @@ pub async fn load_graphs( } impl AnalysisService { - pub fn new(db: Database) -> Self { - GraphMap::get_instance(); - Self { db } + pub fn new() -> Self { + let _ = GraphMap::get_instance(); + Self {} } - pub async fn load_graphs>( + pub async fn load_graphs( &self, distinct_sbom_ids: Vec, - tx: TX, + connection: &C, ) -> Result<(), Error> { - let connection = self.db.connection(&tx); - load_graphs(&connection, &distinct_sbom_ids).await; + load_graphs(connection, &distinct_sbom_ids).await; Ok(()) } - pub async fn load_all_graphs>(&self, tx: TX) -> Result<(), Error> { - let connection = self.db.connection(&tx); + pub async fn load_all_graphs(&self, connection: &C) -> Result<(), Error> { // retrieve all sboms in trustify let distinct_sbom_ids = sbom::Entity::find() .select() .order_by(sbom::Column::DocumentId, Order::Asc) .order_by(sbom::Column::Published, Order::Desc) - .all(&connection) + .all(connection) .await? .into_iter() .map(|record| record.sbom_id.to_string()) // Assuming sbom_id is of type String .collect(); - load_graphs(&connection, &distinct_sbom_ids).await; + load_graphs(connection, &distinct_sbom_ids).await; Ok(()) } - pub async fn clear_all_graphs>(&self, _tx: TX) -> Result<(), Error> { + pub async fn clear_all_graphs(&self) -> Result<(), Error> { let graph_manager = GraphMap::get_instance(); let mut manager = graph_manager.write(); manager.clear(); Ok(()) } - pub async fn status>(&self, tx: TX) -> Result { - let connection = self.db.connection(&tx); + pub async fn status( + &self, + connection: &C, + ) -> Result { let distinct_sbom_ids = sbom::Entity::find() .select() .order_by(sbom::Column::DocumentId, Order::Asc) .order_by(sbom::Column::Published, Order::Desc) - .all(&connection) + .all(connection) .await?; let graph_manager = GraphMap::get_instance(); @@ -477,7 +469,7 @@ impl AnalysisService { }) } - pub async fn query_ancestor_graph>( + pub async fn query_ancestor_graph( component_name: Option, component_purl: Option, query: Option, @@ -574,7 +566,7 @@ impl AnalysisService { components } - pub async fn query_deps_graph>( + pub async fn query_deps_graph( component_name: Option, component_purl: Option, query: Option, @@ -671,15 +663,13 @@ impl AnalysisService { components } - #[instrument(skip(self, tx), err)] - pub async fn retrieve_root_components>( + #[instrument(skip(self, connection), err)] + pub async fn retrieve_root_components( &self, query: Query, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let search_sbom_node_name_subquery = sbom_node::Entity::find() .filtering(query.clone())? .select_only() @@ -691,15 +681,15 @@ impl AnalysisService { .select() .order_by(sbom::Column::DocumentId, Order::Asc) .order_by(sbom::Column::Published, Order::Desc) - .all(&connection) + .all(connection) .await? .into_iter() .map(|record| record.sbom_id.to_string()) // Assuming sbom_id is of type String .collect(); - load_graphs(&connection, &distinct_sbom_ids).await; + load_graphs(connection, &distinct_sbom_ids).await; - let components = AnalysisService::query_ancestor_graph::( + let components = AnalysisService::query_ancestor_graph( None, None, Option::from(query), @@ -710,20 +700,19 @@ impl AnalysisService { Ok(paginated.paginate_array(&components)) } - pub async fn retrieve_all_sbom_roots_by_name>( + pub async fn retrieve_all_sbom_roots_by_name( &self, sbom_id: Uuid, component_name: String, - tx: TX, + connection: &C, ) -> Result, Error> { // This function searches for a component(s) by name in a specific sbom, then returns that components // root components. - let connection = self.db.connection(&tx); let distinct_sbom_ids = vec![sbom_id.to_string()]; - load_graphs(&connection, &distinct_sbom_ids).await; + load_graphs(connection, &distinct_sbom_ids).await; - let components = AnalysisService::query_ancestor_graph::( + let components = AnalysisService::query_ancestor_graph( Option::from(component_name), None, None, @@ -744,15 +733,13 @@ impl AnalysisService { Ok(root_components) } - #[instrument(skip(self, tx), err)] - pub async fn retrieve_root_components_by_name>( + #[instrument(skip(self, connection), err)] + pub async fn retrieve_root_components_by_name( &self, component_name: String, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let search_sbom_node_exact_name_subquery = sbom_node::Entity::find() .filter(sbom_node::Column::Name.eq(component_name.as_str())) .select_only() @@ -764,15 +751,15 @@ impl AnalysisService { .select() .order_by(sbom::Column::DocumentId, Order::Asc) .order_by(sbom::Column::Published, Order::Desc) - .all(&connection) + .all(connection) .await? .into_iter() .map(|record| record.sbom_id.to_string()) // Assuming sbom_id is of type String .collect(); - load_graphs(&connection, &distinct_sbom_ids).await; + load_graphs(connection, &distinct_sbom_ids).await; - let components = AnalysisService::query_ancestor_graph::( + let components = AnalysisService::query_ancestor_graph( Option::from(component_name), None, None, @@ -783,15 +770,13 @@ impl AnalysisService { Ok(paginated.paginate_array(&components)) } - #[instrument(skip(self, tx), err)] - pub async fn retrieve_root_components_by_purl>( + #[instrument(skip(self, connection), err)] + pub async fn retrieve_root_components_by_purl( &self, component_purl: Purl, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let search_sbom_node_exact_name_subquery = sbom_node::Entity::find() .filter(sbom_node::Column::Name.eq(component_purl.name.as_str())) .select_only() @@ -803,15 +788,15 @@ impl AnalysisService { .select() .order_by(sbom::Column::DocumentId, Order::Asc) .order_by(sbom::Column::Published, Order::Desc) - .all(&connection) + .all(connection) .await? .into_iter() .map(|record| record.sbom_id.to_string()) // Assuming sbom_id is of type String .collect(); - load_graphs(&connection, &distinct_sbom_ids).await; + load_graphs(connection, &distinct_sbom_ids).await; - let components = AnalysisService::query_ancestor_graph::( + let components = AnalysisService::query_ancestor_graph( None, Option::from(component_purl), None, @@ -822,15 +807,13 @@ impl AnalysisService { Ok(paginated.paginate_array(&components)) } - #[instrument(skip(self, tx), err)] - pub async fn retrieve_deps>( + #[instrument(skip(self, connection), err)] + pub async fn retrieve_deps( &self, query: Query, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let search_sbom_node_name_subquery = sbom_node::Entity::find() .filtering(query.clone())? .select_only() @@ -842,33 +825,27 @@ impl AnalysisService { .select() .order_by(sbom::Column::DocumentId, Order::Asc) .order_by(sbom::Column::Published, Order::Desc) - .all(&connection) + .all(connection) .await? .into_iter() .map(|record| record.sbom_id.to_string()) // Assuming sbom_id is of type String .collect(); - load_graphs(&connection, &distinct_sbom_ids).await; + load_graphs(connection, &distinct_sbom_ids).await; - let components = AnalysisService::query_deps_graph::( - None, - None, - Option::from(query), - distinct_sbom_ids, - ) - .await; + let components = + AnalysisService::query_deps_graph(None, None, Option::from(query), distinct_sbom_ids) + .await; Ok(paginated.paginate_array(&components)) } - pub async fn retrieve_deps_by_name>( + pub async fn retrieve_deps_by_name( &self, component_name: String, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let search_sbom_node_exact_name_subquery = sbom_node::Entity::find() .filter(sbom_node::Column::Name.eq(component_name.as_str())) .select_only() @@ -880,15 +857,15 @@ impl AnalysisService { .select() .order_by(sbom::Column::DocumentId, Order::Asc) .order_by(sbom::Column::Published, Order::Desc) - .all(&connection) + .all(connection) .await? .into_iter() .map(|record| record.sbom_id.to_string()) // Assuming sbom_id is of type String .collect(); - load_graphs(&connection, &distinct_sbom_ids).await; + load_graphs(connection, &distinct_sbom_ids).await; - let components = AnalysisService::query_deps_graph::( + let components = AnalysisService::query_deps_graph( Option::from(component_name), None, None, @@ -899,14 +876,12 @@ impl AnalysisService { Ok(paginated.paginate_array(&components)) } - pub async fn retrieve_deps_by_purl>( + pub async fn retrieve_deps_by_purl( &self, component_purl: Purl, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let search_sbom_node_exact_name_subquery = sbom_node::Entity::find() .filter(sbom_node::Column::Name.eq(component_purl.name.as_str())) .select_only() @@ -918,15 +893,15 @@ impl AnalysisService { .select() .order_by(sbom::Column::DocumentId, Order::Asc) .order_by(sbom::Column::Published, Order::Desc) - .all(&connection) + .all(connection) .await? .into_iter() .map(|record| record.sbom_id.to_string()) // Assuming sbom_id is of type String .collect(); - load_graphs(&connection, &distinct_sbom_ids).await; + load_graphs(connection, &distinct_sbom_ids).await; - let components = AnalysisService::query_deps_graph::( + let components = AnalysisService::query_deps_graph( None, Option::from(component_purl), None, @@ -953,10 +928,10 @@ mod test { ctx.ingest_documents(["spdx/simple.json", "spdx/simple.json"]) .await?; //double ingestion intended - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let analysis_graph = service - .retrieve_root_components(Query::q("DD"), Paginated::default(), ()) + .retrieve_root_components(Query::q("DD"), Paginated::default(), &ctx.db) .await .unwrap(); @@ -986,7 +961,7 @@ mod test { // ensure we set implicit relationship on component with no defined relationships let analysis_graph = service - .retrieve_root_components(Query::q("EE"), Paginated::default(), ()) + .retrieve_root_components(Query::q("EE"), Paginated::default(), &ctx.db) .await .unwrap(); Ok(assert_eq!(analysis_graph.total, 1)) @@ -1000,10 +975,10 @@ mod test { ctx.ingest_documents(["cyclonedx/simple.json", "cyclonedx/simple.json"]) .await?; //double ingestion intended - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let analysis_graph = service - .retrieve_root_components(Query::q("DD"), Paginated::default(), ()) + .retrieve_root_components(Query::q("DD"), Paginated::default(), &ctx.db) .await .unwrap(); @@ -1031,7 +1006,7 @@ mod test { // ensure we set implicit relationship on component with no defined relationships let analysis_graph = service - .retrieve_root_components(Query::q("EE"), Paginated::default(), ()) + .retrieve_root_components(Query::q("EE"), Paginated::default(), &ctx.db) .await .unwrap(); Ok(assert_eq!(analysis_graph.total, 1)) @@ -1044,10 +1019,10 @@ mod test { ) -> Result<(), anyhow::Error> { ctx.ingest_documents(["spdx/simple.json"]).await?; - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let analysis_graph = service - .retrieve_root_components_by_name("B".to_string(), Paginated::default(), ()) + .retrieve_root_components_by_name("B".to_string(), Paginated::default(), &ctx.db) .await .unwrap(); @@ -1083,12 +1058,12 @@ mod test { ) -> Result<(), anyhow::Error> { ctx.ingest_documents(["spdx/simple.json"]).await?; - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let component_purl: Purl = Purl::from_str("pkg:rpm/redhat/B@0.0.0").map_err(Error::Purl)?; let analysis_graph = service - .retrieve_root_components_by_purl(component_purl, Paginated::default(), ()) + .retrieve_root_components_by_purl(component_purl, Paginated::default(), &ctx.db) .await .unwrap(); @@ -1126,10 +1101,10 @@ mod test { ]) .await?; - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let analysis_graph = service - .retrieve_root_components(Query::q("spymemcached"), Paginated::default(), ()) + .retrieve_root_components(Query::q("spymemcached"), Paginated::default(), &ctx.db) .await .unwrap(); @@ -1158,14 +1133,14 @@ mod test { async fn test_status_service(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { ctx.ingest_documents(["spdx/simple.json"]).await?; - let service = AnalysisService::new(ctx.db.clone()); - let _load_all_graphs = service.load_all_graphs(()).await; - let analysis_status = service.status(()).await.unwrap(); + let service = AnalysisService::new(); + let _load_all_graphs = service.load_all_graphs(&ctx.db).await; + let analysis_status = service.status(&ctx.db).await.unwrap(); assert_eq!(analysis_status.sbom_count, 1); assert_eq!(analysis_status.graph_count, 1); - let _clear_all_graphs = service.clear_all_graphs(()).await; + let _clear_all_graphs = service.clear_all_graphs().await; ctx.ingest_documents([ "spdx/quarkus-bom-3.2.11.Final-redhat-00001.json", @@ -1173,7 +1148,7 @@ mod test { ]) .await?; - let analysis_status = service.status(()).await.unwrap(); + let analysis_status = service.status(&ctx.db).await.unwrap(); assert_eq!(analysis_status.sbom_count, 3); assert_eq!(analysis_status.graph_count, 0); @@ -1186,10 +1161,10 @@ mod test { async fn test_simple_deps_service(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { ctx.ingest_documents(["spdx/simple.json"]).await?; - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let analysis_graph = service - .retrieve_deps(Query::q("AA"), Paginated::default(), ()) + .retrieve_deps(Query::q("AA"), Paginated::default(), &ctx.db) .await .unwrap(); @@ -1197,7 +1172,7 @@ mod test { // ensure we set implicit relationship on component with no defined relationships let analysis_graph = service - .retrieve_root_components(Query::q("EE"), Paginated::default(), ()) + .retrieve_root_components(Query::q("EE"), Paginated::default(), &ctx.db) .await .unwrap(); Ok(assert_eq!(analysis_graph.total, 1)) @@ -1210,10 +1185,10 @@ mod test { ) -> Result<(), anyhow::Error> { ctx.ingest_documents(["cyclonedx/simple.json"]).await?; - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let analysis_graph = service - .retrieve_deps(Query::q("AA"), Paginated::default(), ()) + .retrieve_deps(Query::q("AA"), Paginated::default(), &ctx.db) .await .unwrap(); @@ -1221,7 +1196,7 @@ mod test { // ensure we set implicit relationship on component with no defined relationships let analysis_graph = service - .retrieve_root_components(Query::q("EE"), Paginated::default(), ()) + .retrieve_root_components(Query::q("EE"), Paginated::default(), &ctx.db) .await .unwrap(); Ok(assert_eq!(analysis_graph.total, 1)) @@ -1232,10 +1207,10 @@ mod test { async fn test_simple_by_name_deps_service(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { ctx.ingest_documents(["spdx/simple.json"]).await?; - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let analysis_graph = service - .retrieve_deps_by_name("A".to_string(), Paginated::default(), ()) + .retrieve_deps_by_name("A".to_string(), Paginated::default(), &ctx.db) .await .unwrap(); @@ -1251,13 +1226,13 @@ mod test { async fn test_simple_by_purl_deps_service(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { ctx.ingest_documents(["spdx/simple.json"]).await?; - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let component_purl: Purl = Purl::from_str("pkg:rpm/redhat/AA@0.0.0?arch=src").map_err(Error::Purl)?; let analysis_graph = service - .retrieve_deps_by_purl(component_purl, Paginated::default(), ()) + .retrieve_deps_by_purl(component_purl, Paginated::default(), &ctx.db) .await .unwrap(); @@ -1278,10 +1253,10 @@ mod test { ]) .await?; - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let analysis_graph = service - .retrieve_deps(Query::q("spymemcached"), Paginated::default(), ()) + .retrieve_deps(Query::q("spymemcached"), Paginated::default(), &ctx.db) .await .unwrap(); @@ -1296,10 +1271,10 @@ mod test { ctx.ingest_documents(["cyclonedx/cyclonedx-circular.json"]) .await?; - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let analysis_graph = service - .retrieve_deps_by_name("junit-bom".to_string(), Paginated::default(), ()) + .retrieve_deps_by_name("junit-bom".to_string(), Paginated::default(), &ctx.db) .await .unwrap(); @@ -1311,10 +1286,10 @@ mod test { async fn test_circular_deps_spdx_service(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { ctx.ingest_documents(["spdx/loop.json"]).await?; - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let analysis_graph = service - .retrieve_deps_by_name("A".to_string(), Paginated::default(), ()) + .retrieve_deps_by_name("A".to_string(), Paginated::default(), &ctx.db) .await .unwrap(); @@ -1329,11 +1304,11 @@ mod test { ctx.ingest_documents(["spdx/quarkus-bom-3.2.11.Final-redhat-00001.json"]) .await?; - let service = AnalysisService::new(ctx.db.clone()); + let service = AnalysisService::new(); let component_name = "quarkus-vertx-http".to_string(); let analysis_graph = service - .retrieve_root_components(Query::q(&component_name), Paginated::default(), ()) + .retrieve_root_components(Query::q(&component_name), Paginated::default(), &ctx.db) .await?; let sbom_id = analysis_graph @@ -1344,7 +1319,7 @@ mod test { .parse::()?; let roots = service - .retrieve_all_sbom_roots_by_name(sbom_id, component_name, ()) + .retrieve_all_sbom_roots_by_name(sbom_id, component_name, &ctx.db) .await?; assert_eq!(roots.last().unwrap().name, "quarkus-bom"); diff --git a/modules/fundamental/benches/bench.rs b/modules/fundamental/benches/bench.rs index 71e21749e..9058d2d3b 100644 --- a/modules/fundamental/benches/bench.rs +++ b/modules/fundamental/benches/bench.rs @@ -22,8 +22,6 @@ pub(crate) mod trustify_benches { use sea_orm::ConnectionTrait; use test_context::AsyncTestContext; use tokio::runtime::Runtime; - - use trustify_common::db::Transactional; use trustify_entity::labels::Labels; use trustify_module_ingestor::service::Format; use trustify_test_context::{document, TrustifyContext}; @@ -156,15 +154,11 @@ pub(crate) mod trustify_benches { "vulnerability", ] { ctx.db - .clone() - .connection(&Transactional::None) .execute_unprepared(format!("DELETE FROM {table} WHERE 1=1").as_str()) .await .expect("DELETE ok"); } ctx.db - .clone() - .connection(&Transactional::None) .execute_unprepared("VACUUM ANALYZE") .await .expect("vacuum analyze ok"); diff --git a/modules/fundamental/src/advisory/endpoints/label.rs b/modules/fundamental/src/advisory/endpoints/label.rs index 678e5f0ca..fcfed8967 100644 --- a/modules/fundamental/src/advisory/endpoints/label.rs +++ b/modules/fundamental/src/advisory/endpoints/label.rs @@ -1,6 +1,7 @@ use crate::advisory::service::AdvisoryService; use actix_web::{patch, put, web, HttpResponse, Responder}; use trustify_auth::{authorizer::Require, UpdateAdvisory}; +use trustify_common::db::Database; use trustify_common::id::Id; use trustify_entity::labels::Labels; @@ -20,12 +21,16 @@ use trustify_entity::labels::Labels; #[put("/v1/advisory/{id}/label")] pub async fn set( advisory: web::Data, + db: web::Data, id: web::Path, web::Json(labels): web::Json, _: Require, ) -> actix_web::Result { Ok( - match advisory.set_labels(id.into_inner(), labels, ()).await? { + match advisory + .set_labels(id.into_inner(), labels, db.as_ref()) + .await? + { Some(()) => HttpResponse::NoContent(), None => HttpResponse::NotFound(), }, diff --git a/modules/fundamental/src/advisory/endpoints/mod.rs b/modules/fundamental/src/advisory/endpoints/mod.rs index 48e761ff7..5b991458c 100644 --- a/modules/fundamental/src/advisory/endpoints/mod.rs +++ b/modules/fundamental/src/advisory/endpoints/mod.rs @@ -10,11 +10,12 @@ use crate::{ }, endpoints::Deprecation, purl::service::PurlService, - Error::{self, Internal}, + Error, }; use actix_web::{delete, get, http::header, post, web, HttpResponse, Responder}; use config::Config; use futures_util::TryStreamExt; +use sea_orm::TransactionTrait; use std::str::FromStr; use trustify_auth::authorizer::Require; use trustify_auth::{CreateAdvisory, DeleteAdvisory, ReadAdvisory}; @@ -35,9 +36,10 @@ pub fn configure( upload_limit: usize, ) { let advisory_service = AdvisoryService::new(db.clone()); - let purl_service = PurlService::new(db); + let purl_service = PurlService::new(); config + .app_data(web::Data::new(db)) .app_data(web::Data::new(advisory_service)) .app_data(web::Data::new(purl_service)) .app_data(web::Data::new(Config { upload_limit })) @@ -66,6 +68,7 @@ pub fn configure( /// List advisories pub async fn all( state: web::Data, + db: web::Data, web::Query(search): web::Query, web::Query(paginated): web::Query, web::Query(Deprecation { deprecated }): web::Query, @@ -73,7 +76,7 @@ pub async fn all( ) -> actix_web::Result { Ok(HttpResponse::Ok().json( state - .fetch_advisories(search, paginated, deprecated, ()) + .fetch_advisories(search, paginated, deprecated, db.as_ref()) .await?, )) } @@ -93,11 +96,12 @@ pub async fn all( /// Get an advisory pub async fn get( state: web::Data, + db: web::Data, key: web::Path, _: Require, ) -> actix_web::Result { let hash_key = Id::from_str(&key).map_err(Error::IdKey)?; - let fetched = state.fetch_advisory(hash_key, ()).await?; + let fetched = state.fetch_advisory(hash_key, db.as_ref()).await?; if let Some(fetched) = fetched { Ok(HttpResponse::Ok().json(fetched)) @@ -121,22 +125,26 @@ pub async fn get( /// Delete an advisory pub async fn delete( state: web::Data, + db: web::Data, purl_service: web::Data, key: web::Path, _: Require, -) -> actix_web::Result { - let hash_key = Id::from_str(&key).map_err(Error::IdKey)?; - let fetched = state.fetch_advisory(hash_key, ()).await?; +) -> Result { + let tx = db.begin().await?; + + let hash_key = Id::from_str(&key)?; + let fetched = state.fetch_advisory(hash_key, &tx).await?; if let Some(fetched) = fetched { - let rows_affected = state.delete_advisory(fetched.head.uuid, ()).await?; + let rows_affected = state.delete_advisory(fetched.head.uuid, &tx).await?; match rows_affected { 0 => Ok(HttpResponse::NotFound().finish()), 1 => { - _ = purl_service.gc_purls(()).await; // ignore gc failure.. + let _ = purl_service.gc_purls(&tx).await; // ignore gc failure.. + tx.commit().await?; Ok(HttpResponse::Ok().json(fetched)) } - _ => Err(Internal("Unexpected number of rows affected".into()).into()), + _ => Err(Error::Internal("Unexpected number of rows affected".into())), } } else { Ok(HttpResponse::NotFound().finish()) @@ -199,6 +207,7 @@ pub async fn upload( #[get("/v1/advisory/{key}/download")] /// Download an advisory document pub async fn download( + db: web::Data, ingestor: web::Data, advisory: web::Data, key: web::Path, @@ -208,7 +217,7 @@ pub async fn download( let id = Id::from_str(&key).map_err(Error::IdKey)?; // look up document by id - let Some(advisory) = advisory.fetch_advisory(id, ()).await? else { + let Some(advisory) = advisory.fetch_advisory(id, db.as_ref()).await? else { return Ok(HttpResponse::NotFound().finish()); }; diff --git a/modules/fundamental/src/advisory/endpoints/test.rs b/modules/fundamental/src/advisory/endpoints/test.rs index 47e868c40..a7b04779f 100644 --- a/modules/fundamental/src/advisory/endpoints/test.rs +++ b/modules/fundamental/src/advisory/endpoints/test.rs @@ -11,7 +11,7 @@ use sha2::{Digest, Sha256}; use test_context::test_context; use test_log::test; use time::OffsetDateTime; -use trustify_common::{db::Transactional, hashing::Digests, id::Id, model::PaginatedResults}; +use trustify_common::{hashing::Digests, id::Id, model::PaginatedResults}; use trustify_cvss::cvss3::{ AttackComplexity, AttackVector, Availability, Confidentiality, Cvss3Base, Integrity, PrivilegesRequired, Scope, UserInteraction, @@ -41,12 +41,12 @@ async fn all_advisories(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; let advisory_vuln = advisory - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None, &ctx.db) .await?; advisory_vuln .ingest_cvss3_score( @@ -61,7 +61,7 @@ async fn all_advisories(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { i: Integrity::None, a: Availability::None, }, - (), + &ctx.db, ) .await?; @@ -79,7 +79,7 @@ async fn all_advisories(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; @@ -128,7 +128,7 @@ async fn one_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; @@ -147,12 +147,12 @@ async fn one_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; let advisory_vuln = advisory2 - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None, &ctx.db) .await?; advisory_vuln .ingest_cvss3_score( @@ -167,7 +167,7 @@ async fn one_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { i: Integrity::None, a: Availability::None, }, - (), + &ctx.db, ) .await?; @@ -225,7 +225,7 @@ async fn one_advisory_by_uuid(ctx: &TrustifyContext) -> Result<(), anyhow::Error modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; @@ -244,14 +244,14 @@ async fn one_advisory_by_uuid(ctx: &TrustifyContext) -> Result<(), anyhow::Error modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; let uuid = advisory.advisory.id; let advisory_vuln = advisory - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None, &ctx.db) .await?; advisory_vuln .ingest_cvss3_score( @@ -266,7 +266,7 @@ async fn one_advisory_by_uuid(ctx: &TrustifyContext) -> Result<(), anyhow::Error i: Integrity::None, a: Availability::None, }, - (), + &ctx.db, ) .await?; diff --git a/modules/fundamental/src/advisory/model/details/advisory_vulnerability.rs b/modules/fundamental/src/advisory/model/details/advisory_vulnerability.rs index a22ba465a..16d38875c 100644 --- a/modules/fundamental/src/advisory/model/details/advisory_vulnerability.rs +++ b/modules/fundamental/src/advisory/model/details/advisory_vulnerability.rs @@ -1,7 +1,6 @@ use crate::{vulnerability::model::VulnerabilityHead, Error}; -use sea_orm::{ColumnTrait, EntityTrait, LoaderTrait, QueryFilter}; +use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, LoaderTrait, QueryFilter}; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_common::memo::Memo; use trustify_cvss::cvss3::severity::Severity; use trustify_cvss::{cvss3::score::Score, cvss3::Cvss3Base}; @@ -30,10 +29,10 @@ pub struct AdvisoryVulnerabilityHead { } impl AdvisoryVulnerabilityHead { - pub async fn from_entity( + pub async fn from_entity( advisory: &advisory::Model, vulnerability: &vulnerability::Model, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let cvss3 = cvss3::Entity::find() .filter(cvss3::Column::AdvisoryId.eq(advisory.id)) @@ -68,10 +67,10 @@ impl AdvisoryVulnerabilityHead { } } - pub async fn from_entities( + pub async fn from_entities( advisory: &advisory::Model, vulnerabilities: &[vulnerability::Model], - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let cvss3s = vulnerabilities .load_many( @@ -122,10 +121,10 @@ pub struct AdvisoryVulnerabilitySummary { } impl AdvisoryVulnerabilitySummary { - pub async fn from_entities( + pub async fn from_entities( advisory: &advisory::Model, vulnerabilities: &[vulnerability::Model], - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let mut cvss3s = vulnerabilities .load_many( diff --git a/modules/fundamental/src/advisory/model/details/mod.rs b/modules/fundamental/src/advisory/model/details/mod.rs index 64de87b51..93e7fb3f5 100644 --- a/modules/fundamental/src/advisory/model/details/mod.rs +++ b/modules/fundamental/src/advisory/model/details/mod.rs @@ -4,9 +4,8 @@ use crate::advisory::service::AdvisoryCatcher; use crate::source_document::model::SourceDocument; use crate::{advisory::model::AdvisoryHead, Error}; use advisory_vulnerability::AdvisoryVulnerabilitySummary; -use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QuerySelect}; +use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter, QuerySelect}; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_common::memo::Memo; use trustify_cvss::cvss3::severity::Severity; use trustify_entity::{self as entity}; @@ -33,9 +32,9 @@ pub struct AdvisoryDetails { } impl AdvisoryDetails { - pub async fn from_entity( + pub async fn from_entity( advisory: &AdvisoryCatcher, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let vulnerabilities = entity::vulnerability::Entity::find() .right_join(entity::advisory_vulnerability::Entity) @@ -59,7 +58,7 @@ impl AdvisoryDetails { ) .await?, source_document: if let Some(doc) = &advisory.source_document { - Some(SourceDocument::from_entity(doc, tx).await?) + Some(SourceDocument::from_entity(doc).await?) } else { None }, diff --git a/modules/fundamental/src/advisory/model/mod.rs b/modules/fundamental/src/advisory/model/mod.rs index cab4222da..efa1a94f2 100644 --- a/modules/fundamental/src/advisory/model/mod.rs +++ b/modules/fundamental/src/advisory/model/mod.rs @@ -6,10 +6,9 @@ pub use details::*; pub use summary::*; use crate::{organization::model::OrganizationSummary, Error}; -use sea_orm::{prelude::Uuid, LoaderTrait, ModelTrait}; +use sea_orm::{prelude::Uuid, ConnectionTrait, LoaderTrait, ModelTrait}; use serde::{Deserialize, Serialize}; use time::OffsetDateTime; -use trustify_common::db::ConnectionOrTransaction; use trustify_common::memo::Memo; use trustify_entity::{advisory, labels::Labels, organization}; use utoipa::ToSchema; @@ -55,19 +54,17 @@ pub struct AdvisoryHead { } impl AdvisoryHead { - pub async fn from_advisory( + pub async fn from_advisory( advisory: &advisory::Model, issuer: Memo, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let issuer = match &issuer { - Memo::Provided(Some(issuer)) => { - Some(OrganizationSummary::from_entity(issuer, tx).await?) - } + Memo::Provided(Some(issuer)) => Some(OrganizationSummary::from_entity(issuer).await?), Memo::Provided(None) => None, Memo::NotProvided => { if let Some(issuer) = advisory.find_related(organization::Entity).one(tx).await? { - Some(OrganizationSummary::from_entity(&issuer, tx).await?) + Some(OrganizationSummary::from_entity(&issuer).await?) } else { None } @@ -87,9 +84,9 @@ impl AdvisoryHead { }) } - pub async fn from_entities( + pub async fn from_entities( entities: &[advisory::Model], - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let mut heads = Vec::new(); @@ -97,7 +94,7 @@ impl AdvisoryHead { for (advisory, issuer) in entities.iter().zip(issuers) { let issuer = if let Some(issuer) = issuer { - Some(OrganizationSummary::from_entity(&issuer, tx).await?) + Some(OrganizationSummary::from_entity(&issuer).await?) } else { None }; diff --git a/modules/fundamental/src/advisory/model/summary.rs b/modules/fundamental/src/advisory/model/summary.rs index 89724da69..7ab5bfa51 100644 --- a/modules/fundamental/src/advisory/model/summary.rs +++ b/modules/fundamental/src/advisory/model/summary.rs @@ -1,6 +1,6 @@ -use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QuerySelect}; +use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter, QuerySelect}; use serde::{Deserialize, Serialize}; -use trustify_common::{db::ConnectionOrTransaction, memo::Memo}; +use trustify_common::memo::Memo; use trustify_cvss::cvss3::score::Score; use trustify_entity::{advisory_vulnerability, vulnerability}; use utoipa::ToSchema; @@ -32,9 +32,9 @@ pub struct AdvisorySummary { } impl AdvisorySummary { - pub async fn from_entities( + pub async fn from_entities( entities: &[AdvisoryCatcher], - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let mut summaries = Vec::with_capacity(entities.len()); @@ -63,7 +63,7 @@ impl AdvisorySummary { ) .await?, source_document: if let Some(doc) = &each.source_document { - Some(SourceDocument::from_entity(doc, tx).await?) + Some(SourceDocument::from_entity(doc).await?) } else { None }, diff --git a/modules/fundamental/src/advisory/service/mod.rs b/modules/fundamental/src/advisory/service/mod.rs index 53a3e2162..ac2b193eb 100644 --- a/modules/fundamental/src/advisory/service/mod.rs +++ b/modules/fundamental/src/advisory/service/mod.rs @@ -14,7 +14,7 @@ use trustify_common::{ limiter::LimiterAsModelTrait, multi_model::{FromQueryResultMultiModel, SelectIntoMultiModel}, query::{Columns, Filtering, Query}, - Database, Transactional, + Database, }, id::{Id, TrySelectForId}, model::{Paginated, PaginatedResults}, @@ -37,15 +37,13 @@ impl AdvisoryService { Self { db } } - pub async fn fetch_advisories + Sync + Send>( + pub async fn fetch_advisories( &self, search: Query, paginated: Paginated, deprecation: Deprecation, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - // To be able to ORDER or WHERE using a synthetic column, we must first // SELECT col, extra_col FROM (SELECT col, random as extra_col FROM...) // which involves mucking about inside the Select to re-target from @@ -119,7 +117,7 @@ impl AdvisoryService { }), )? .try_limiting_as_multi_model::( - &connection, + connection, paginated.offset, paginated.limit, )?; @@ -130,17 +128,15 @@ impl AdvisoryService { Ok(PaginatedResults { total, - items: AdvisorySummary::from_entities(&items, &connection).await?, + items: AdvisorySummary::from_entities(&items, connection).await?, }) } - pub async fn fetch_advisory + Sync + Send>( + pub async fn fetch_advisory( &self, id: Id, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - // To be able to ORDER or WHERE using a synthetic column, we must first // SELECT col, extra_col FROM (SELECT col, random as extra_col FROM...) // which involves mucking about inside the Select to re-target from @@ -189,26 +185,24 @@ impl AdvisoryService { ) .try_filter(id)? .try_into_multi_model::()? - .one(&connection) + .one(connection) .await?; if let Some(catcher) = results { Ok(Some( - AdvisoryDetails::from_entity(&catcher, &connection).await?, + AdvisoryDetails::from_entity(&catcher, connection).await?, )) } else { Ok(None) } } - /// delete one sbom - pub async fn delete_advisory>( + /// delete one advisory + pub async fn delete_advisory( &self, id: Uuid, - tx: TX, + connection: &C, ) -> Result { - let connection = self.db.connection(&tx); - let stmt = Statement::from_sql_and_values( connection.get_database_backend(), r#"DELETE FROM advisory WHERE id=$1 RETURNING identifier"#, @@ -220,7 +214,7 @@ impl AdvisoryService { for row in result { let identifier = row.try_get_by_index::(0)?; - UpdateDeprecatedAdvisory::execute(&connection, &identifier).await?; + UpdateDeprecatedAdvisory::execute(connection, &identifier).await?; } Ok(rows_affected as u64) @@ -230,18 +224,16 @@ impl AdvisoryService { /// /// Returns `Ok(Some(()))` if a document was found and updated. If no document was found, it will /// return `Ok(None)`. - pub async fn set_labels( + pub async fn set_labels( &self, id: Id, labels: Labels, - tx: impl AsRef, + connection: &C, ) -> Result, Error> { - let db = self.db.connection(&tx); - let result = advisory::Entity::update_many() .try_filter(id)? .col_expr(advisory::Column::Labels, Expr::value(labels)) - .exec(&db) + .exec(connection) .await?; Ok((result.rows_affected > 0).then_some(())) diff --git a/modules/fundamental/src/advisory/service/test.rs b/modules/fundamental/src/advisory/service/test.rs index 217db2d6f..bd4f912ea 100644 --- a/modules/fundamental/src/advisory/service/test.rs +++ b/modules/fundamental/src/advisory/service/test.rs @@ -1,15 +1,13 @@ use super::*; -use crate::advisory::model::AdvisoryHead; -use crate::source_document::model::SourceDocument; +use crate::{advisory::model::AdvisoryHead, source_document::model::SourceDocument}; use std::str::FromStr; use test_context::test_context; use test_log::test; use time::OffsetDateTime; use trustify_common::{db::query::q, hashing::Digests, model::Paginated, purl::Purl}; -use trustify_cvss::cvss3::severity::Severity; use trustify_cvss::cvss3::{ - AttackComplexity, AttackVector, Availability, Confidentiality, Cvss3Base, Integrity, - PrivilegesRequired, Scope, UserInteraction, + severity::Severity, AttackComplexity, AttackVector, Availability, Confidentiality, Cvss3Base, + Integrity, PrivilegesRequired, Scope, UserInteraction, }; use trustify_entity::version_scheme::VersionScheme; use trustify_module_ingestor::graph::advisory::{ @@ -38,7 +36,7 @@ pub async fn ingest_sample_advisory<'a>( modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await } @@ -47,7 +45,7 @@ pub async fn ingest_and_link_advisory(ctx: &TrustifyContext) -> Result<(), anyho let advisory = ingest_sample_advisory(ctx, "RHSA-1", "RHSA-1").await?; let advisory_vuln = advisory - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None, &ctx.db) .await?; advisory_vuln @@ -63,7 +61,7 @@ pub async fn ingest_and_link_advisory(ctx: &TrustifyContext) -> Result<(), anyho i: Integrity::High, a: Availability::High, }, - (), + &ctx.db, ) .await?; Ok(()) @@ -78,7 +76,7 @@ async fn all_advisories(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { let fetch = AdvisoryService::new(ctx.db.clone()); let fetched = fetch - .fetch_advisories(q(""), Paginated::default(), Default::default(), ()) + .fetch_advisories(q(""), Paginated::default(), Default::default(), &ctx.db) .await?; assert_eq!(fetched.total, 2); @@ -100,7 +98,7 @@ async fn all_advisories_filtered_by_average_score( q("average_score>8"), Paginated::default(), Default::default(), - (), + &ctx.db, ) .await?; @@ -123,7 +121,7 @@ async fn all_advisories_filtered_by_average_severity( q("average_severity>=critical"), Paginated::default(), Default::default(), - (), + &ctx.db, ) .await?; @@ -141,7 +139,7 @@ async fn single_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { let advisory = ingest_sample_advisory(ctx, "RHSA-1", "RHSA-1").await?; let advisory_vuln: trustify_module_ingestor::graph::advisory::advisory_vulnerability::AdvisoryVulnerabilityContext<'_> = advisory - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None,&ctx.db) .await?; advisory_vuln .ingest_cvss3_score( @@ -156,7 +154,7 @@ async fn single_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { i: Integrity::High, a: Availability::High, }, - (), + &ctx.db, ) .await?; @@ -169,7 +167,7 @@ async fn single_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { scheme: VersionScheme::Maven, spec: VersionSpec::Exact("1.2.3".to_string()), }, - (), + &ctx.db, ) .await?; @@ -182,7 +180,7 @@ async fn single_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { scheme: VersionScheme::Maven, spec: VersionSpec::Exact("1.2.3".to_string()), }, - (), + &ctx.db, ) .await?; @@ -192,7 +190,7 @@ async fn single_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { let jenny256 = Id::sha256(&digests.sha256); let jenny384 = Id::sha384(&digests.sha384); let jenny512 = Id::sha512(&digests.sha512); - let fetched = fetch.fetch_advisory(jenny256.clone(), ()).await?; + let fetched = fetch.fetch_advisory(jenny256.clone(), &ctx.db).await?; let id = Id::Uuid(fetched.as_ref().unwrap().head.uuid); assert!(matches!( @@ -211,7 +209,7 @@ async fn single_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { }) if sha256 == jenny256.to_string() && sha384 == jenny384.to_string() && sha512 == jenny512.to_string() && average_severity == Severity::Critical)); - let fetched = fetch.fetch_advisory(id, ()).await?; + let fetched = fetch.fetch_advisory(id, &ctx.db).await?; assert!(matches!( fetched, Some(AdvisoryDetails { @@ -239,7 +237,7 @@ async fn delete_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { let advisory = ingest_sample_advisory(ctx, "RHSA-1", "RHSA-1").await?; let advisory_vuln = advisory - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None, &ctx.db) .await?; advisory_vuln .ingest_cvss3_score( @@ -254,7 +252,7 @@ async fn delete_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { i: Integrity::High, a: Availability::High, }, - (), + &ctx.db, ) .await?; @@ -267,7 +265,7 @@ async fn delete_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { scheme: VersionScheme::Maven, spec: VersionSpec::Exact("1.2.3".to_string()), }, - (), + &ctx.db, ) .await?; @@ -280,20 +278,20 @@ async fn delete_advisory(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { scheme: VersionScheme::Maven, spec: VersionSpec::Exact("1.2.3".to_string()), }, - (), + &ctx.db, ) .await?; let fetch = AdvisoryService::new(ctx.db.clone()); let jenny256 = Id::sha256(&digests.sha256); - let fetched = fetch.fetch_advisory(jenny256.clone(), ()).await?; + let fetched = fetch.fetch_advisory(jenny256.clone(), &ctx.db).await?; let fetched = fetched.expect("Advisory not found"); - let affected = fetch.delete_advisory(fetched.head.uuid, ()).await?; + let affected = fetch.delete_advisory(fetched.head.uuid, &ctx.db).await?; assert_eq!(affected, 1); - let affected = fetch.delete_advisory(fetched.head.uuid, ()).await?; + let affected = fetch.delete_advisory(fetched.head.uuid, &ctx.db).await?; assert_eq!(affected, 0); Ok(()) diff --git a/modules/fundamental/src/ai/endpoints/mod.rs b/modules/fundamental/src/ai/endpoints/mod.rs index 8aa52e71d..260920249 100644 --- a/modules/fundamental/src/ai/endpoints/mod.rs +++ b/modules/fundamental/src/ai/endpoints/mod.rs @@ -35,10 +35,11 @@ pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, d #[post("/v1/ai/completions")] pub async fn completions( service: web::Data, + db: web::Data, request: web::Json, _: Require, ) -> actix_web::Result { - let response = service.completions(&request, ()).await?; + let response = service.completions(&request, db.as_ref()).await?; Ok(HttpResponse::Ok().json(response)) } diff --git a/modules/fundamental/src/ai/service/mod.rs b/modules/fundamental/src/ai/service/mod.rs index 89a698590..235a943b6 100644 --- a/modules/fundamental/src/ai/service/mod.rs +++ b/modules/fundamental/src/ai/service/mod.rs @@ -19,10 +19,11 @@ use langchain_rust::{ prompt_args, tools::Tool, }; +use sea_orm::ConnectionTrait; use std::env; use std::sync::Arc; use tokio::sync::OnceCell; -use trustify_common::db::{Database, Transactional}; +use trustify_common::db::Database; pub const PREFIX: &str = include_str!("prefix.txt"); @@ -173,10 +174,10 @@ impl AiService { .await } - pub async fn completions>( + pub async fn completions( &self, request: &ChatState, - _tx: TX, + _connection: &C, ) -> Result { let llm = match self.llm.clone() { Some(llm) => llm, diff --git a/modules/fundamental/src/ai/service/test.rs b/modules/fundamental/src/ai/service/test.rs index bf1e811c0..46aea6a74 100644 --- a/modules/fundamental/src/ai/service/test.rs +++ b/modules/fundamental/src/ai/service/test.rs @@ -3,7 +3,6 @@ use crate::ai::service::AiService; use test_context::test_context; use test_log::test; -use trustify_common::db::Transactional; use trustify_common::hashing::Digests; use trustify_module_ingestor::graph::product::ProductInformation; use trustify_test_context::TrustifyContext; @@ -16,7 +15,7 @@ pub async fn ingest_fixtures(ctx: &TrustifyContext) -> Result<(), anyhow::Error> &Digests::digest("RHSA-1"), "a", (), - Transactional::None, + &ctx.db, ) .await?; @@ -28,11 +27,11 @@ pub async fn ingest_fixtures(ctx: &TrustifyContext) -> Result<(), anyhow::Error> vendor: Some("Red Hat".to_string()), cpe: None, }, - (), + &ctx.db, ) .await?; - pr.ingest_product_version("37.17.9".to_string(), Some(sbom.sbom.sbom_id), ()) + pr.ingest_product_version("37.17.9".to_string(), Some(sbom.sbom.sbom_id), &ctx.db) .await?; ctx.ingest_documents(["osv/RUSTSEC-2021-0079.json", "cve/CVE-2021-32714.json"]) @@ -78,7 +77,7 @@ async fn test_completions_sbom_info(ctx: &TrustifyContext) -> Result<(), anyhow: .into(), ); - let result = service.completions(&req, ()).await?; + let result = service.completions(&req, &ctx.db).await?; log::info!("result: {:#?}", result); let last_message_content = result.messages.last().unwrap().content.clone(); @@ -106,7 +105,7 @@ async fn test_completions_package_info(ctx: &TrustifyContext) -> Result<(), anyh let mut req = ChatState::new(); req.add_human_message("List the httpclient packages with their identifiers".into()); - let result = service.completions(&req, ()).await?; + let result = service.completions(&req, &ctx.db).await?; log::info!("result: {:#?}", result); let last_message_content = result.messages.last().unwrap().content.clone(); @@ -135,7 +134,7 @@ async fn test_completions_cve_info(ctx: &TrustifyContext) -> Result<(), anyhow:: let mut req = ChatState::new(); req.add_human_message("Give me details for CVE-2021-32714".into()); - let result = service.completions(&req, ()).await?; + let result = service.completions(&req, &ctx.db).await?; log::info!("result: {:#?}", result); let last_message_content = result.messages.last().unwrap().content.clone(); @@ -163,7 +162,7 @@ async fn test_completions_advisory_info(ctx: &TrustifyContext) -> Result<(), any let mut req = ChatState::new(); req.add_human_message("Give me details for the RHSA-2024_3666 advisory".into()); - let result = service.completions(&req, ()).await?; + let result = service.completions(&req, &ctx.db).await?; log::info!("result: {:#?}", result); let last_message_content = result.messages.last().unwrap().content.clone(); diff --git a/modules/fundamental/src/ai/service/tools/advisory_info.rs b/modules/fundamental/src/ai/service/tools/advisory_info.rs index 1518620a4..5e2a2a23e 100644 --- a/modules/fundamental/src/ai/service/tools/advisory_info.rs +++ b/modules/fundamental/src/ai/service/tools/advisory_info.rs @@ -1,18 +1,29 @@ -use crate::advisory::service::AdvisoryService; -use crate::ai::service::tools; -use crate::ai::service::tools::input_description; +use crate::{ + advisory::service::AdvisoryService, + ai::service::tools::{self, input_description}, +}; use async_trait::async_trait; use langchain_rust::tools::Tool; use serde::Serialize; use serde_json::Value; use std::error::Error; use time::OffsetDateTime; -use trustify_common::db::query::Query; -use trustify_common::id::Id; +use trustify_common::db::Database; +use trustify_common::{db::query::Query, id::Id}; use trustify_module_ingestor::common::Deprecation; use uuid::Uuid; -pub struct AdvisoryInfo(pub AdvisoryService); +pub struct AdvisoryInfo { + db: Database, + service: AdvisoryService, +} + +impl AdvisoryInfo { + pub fn new(db: Database) -> Self { + let service = AdvisoryService::new(db.clone()); + Self { db, service } + } +} #[async_trait] impl Tool for AdvisoryInfo { @@ -20,10 +31,6 @@ impl Tool for AdvisoryInfo { String::from("advisory-info") } - fn parameters(&self) -> Value { - input_description("UUID of the Advisory. Example: 2fd0d1b7-a908-4d63-9310-d57a7f77c6df") - } - fn description(&self) -> String { String::from( r##" @@ -39,8 +46,12 @@ Advisories have a UUID that uniquely identifies the advisory. ) } + fn parameters(&self) -> Value { + input_description("UUID of the Advisory. Example: 2fd0d1b7-a908-4d63-9310-d57a7f77c6df") + } + async fn run(&self, input: Value) -> Result> { - let service = &self.0; + let service = &self.service; let input = input .as_str() @@ -48,7 +59,7 @@ Advisories have a UUID that uniquely identifies the advisory. .to_string(); let item = match Uuid::parse_str(input.as_str()).ok() { - Some(x) => service.fetch_advisory(Id::Uuid(x), ()).await?, + Some(x) => service.fetch_advisory(Id::Uuid(x), &self.db).await?, None => { // search for possible matches let results = service @@ -59,7 +70,7 @@ Advisories have a UUID that uniquely identifies the advisory. }, Default::default(), Deprecation::Ignore, - (), + &self.db, ) .await?; @@ -84,7 +95,7 @@ Advisories have a UUID that uniquely identifies the advisory. // let's show the details service - .fetch_advisory(Id::Uuid(results.items[0].head.uuid), ()) + .fetch_advisory(Id::Uuid(results.items[0].head.uuid), &self.db) .await? } }; @@ -152,7 +163,7 @@ mod tests { crate::advisory::service::test::ingest_and_link_advisory(ctx).await?; crate::advisory::service::test::ingest_sample_advisory(ctx, "RHSA-2", "RHSA-2").await?; - let tool = Rc::new(AdvisoryInfo(AdvisoryService::new(ctx.db.clone()))); + let tool = Rc::new(AdvisoryInfo::new(ctx.db.clone())); assert_tool_contains( tool.clone(), diff --git a/modules/fundamental/src/ai/service/tools/cve_info.rs b/modules/fundamental/src/ai/service/tools/cve_info.rs index f7d27c6ef..eaf570730 100644 --- a/modules/fundamental/src/ai/service/tools/cve_info.rs +++ b/modules/fundamental/src/ai/service/tools/cve_info.rs @@ -1,18 +1,32 @@ -use crate::ai::service::tools; -use crate::ai::service::tools::input_description; -use crate::vulnerability::service::VulnerabilityService; +use crate::{ + ai::service::tools::{self, input_description}, + vulnerability::service::VulnerabilityService, +}; use async_trait::async_trait; use langchain_rust::tools::Tool; use serde::Serialize; use serde_json::Value; -use std::error::Error; -use std::fmt::Write; +use std::{error::Error, fmt::Write}; use time::OffsetDateTime; -use trustify_common::db::query::Query; -use trustify_common::purl::Purl; +use trustify_common::{ + db::{query::Query, Database}, + purl::Purl, +}; use trustify_module_ingestor::common::Deprecation; -pub struct CVEInfo(pub VulnerabilityService); +pub struct CVEInfo { + pub db: Database, + pub service: VulnerabilityService, +} + +impl CVEInfo { + pub fn new(db: Database) -> Self { + Self { + db, + service: VulnerabilityService::new(), + } + } +} #[async_trait] impl Tool for CVEInfo { @@ -20,16 +34,6 @@ impl Tool for CVEInfo { String::from("cve-info") } - fn parameters(&self) -> Value { - input_description( - r#" -The input should be the partial or full name of the Vulnerability to search for. Example: -* CVE-2014-0160 - - "#, - ) - } - fn description(&self) -> String { String::from( r##" @@ -45,8 +49,18 @@ Vulnerability are identified by their CVE Identifier. ) } + fn parameters(&self) -> Value { + input_description( + r#" +The input should be the partial or full name of the Vulnerability to search for. Example: +* CVE-2014-0160 + + "#, + ) + } + async fn run(&self, input: Value) -> Result> { - let service = &self.0; + let service = &self.service; let input = input .as_str() @@ -54,7 +68,7 @@ Vulnerability are identified by their CVE Identifier. .to_string(); let item = match service - .fetch_vulnerability(input.as_str(), Deprecation::Ignore, ()) + .fetch_vulnerability(input.as_str(), Deprecation::Ignore, &self.db) .await? { Some(v) => v, @@ -68,7 +82,7 @@ Vulnerability are identified by their CVE Identifier. }, Default::default(), Deprecation::Ignore, - (), + &self.db, ) .await?; @@ -96,7 +110,7 @@ Vulnerability are identified by their CVE Identifier. .fetch_vulnerability( results.items[0].head.identifier.as_str(), Deprecation::Ignore, - (), + &self.db, ) .await? { @@ -171,7 +185,7 @@ mod tests { #[test(actix_web::test)] async fn cve_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { ingest_fixtures(ctx).await?; - let tool = Rc::new(CVEInfo(VulnerabilityService::new(ctx.db.clone()))); + let tool = Rc::new(CVEInfo::new(ctx.db.clone())); assert_tool_contains( tool.clone(), "CVE-2021-32714", diff --git a/modules/fundamental/src/ai/service/tools/mod.rs b/modules/fundamental/src/ai/service/tools/mod.rs index ec7ef7004..cd2879321 100644 --- a/modules/fundamental/src/ai/service/tools/mod.rs +++ b/modules/fundamental/src/ai/service/tools/mod.rs @@ -1,19 +1,12 @@ -use crate::advisory::service::AdvisoryService; -use crate::ai::service::tools::advisory_info::AdvisoryInfo; -use crate::ai::service::tools::cve_info::CVEInfo; -use crate::ai::service::tools::logger::ToolLogger; -use crate::ai::service::tools::package_info::PackageInfo; -use crate::ai::service::tools::sbom_info::SbomInfo; -use crate::purl::service::PurlService; -use crate::sbom::service::SbomService; -use crate::vulnerability::service::VulnerabilityService; +use crate::ai::service::tools::{ + advisory_info::AdvisoryInfo, cve_info::CVEInfo, logger::ToolLogger, package_info::PackageInfo, + sbom_info::SbomInfo, +}; use langchain_rust::tools::Tool; use serde::Serialize; use serde_json::{json, Value}; -use std::error::Error; -use std::sync::Arc; -use trustify_common::db::Database; -use trustify_common::model::PaginatedResults; +use std::{error::Error, sync::Arc}; +use trustify_common::{db::Database, model::PaginatedResults}; pub mod advisory_info; pub mod cve_info; @@ -26,13 +19,10 @@ pub mod sbom_info; pub fn new(db: Database) -> Vec> { vec![ // Arc::new(ToolLogger(ProductInfo(ProductService::new(db.clone())))), - Arc::new(ToolLogger(CVEInfo(VulnerabilityService::new(db.clone())))), - Arc::new(ToolLogger(AdvisoryInfo(AdvisoryService::new(db.clone())))), - Arc::new(ToolLogger(PackageInfo(( - PurlService::new(db.clone()), - SbomService::new(db.clone()), - )))), - Arc::new(ToolLogger(SbomInfo(SbomService::new(db.clone())))), + Arc::new(ToolLogger(CVEInfo::new(db.clone()))), + Arc::new(ToolLogger(AdvisoryInfo::new(db.clone()))), + Arc::new(ToolLogger(PackageInfo::new(db.clone()))), + Arc::new(ToolLogger(SbomInfo::new(db.clone()))), ] } diff --git a/modules/fundamental/src/ai/service/tools/package_info.rs b/modules/fundamental/src/ai/service/tools/package_info.rs index a53643e94..9e0ef812d 100644 --- a/modules/fundamental/src/ai/service/tools/package_info.rs +++ b/modules/fundamental/src/ai/service/tools/package_info.rs @@ -1,17 +1,29 @@ -use crate::ai::service::tools; -use crate::purl::service::PurlService; -use crate::sbom::service::SbomService; +use crate::{ai::service::tools, purl::service::PurlService, sbom::service::SbomService}; use async_trait::async_trait; use langchain_rust::tools::Tool; use serde::Serialize; use serde_json::Value; use std::error::Error; -use trustify_common::db::query::Query; -use trustify_common::purl::Purl; +use trustify_common::{ + db::{query::Query, Database}, + purl::Purl, +}; use trustify_module_ingestor::common::Deprecation; use uuid::Uuid; -pub struct PackageInfo(pub (PurlService, SbomService)); +pub struct PackageInfo { + pub db: Database, + pub purl: PurlService, + pub sbom: SbomService, +} + +impl PackageInfo { + pub fn new(db: Database) -> Self { + let purl = PurlService::new(); + let sbom = SbomService::new(db.clone()); + Self { db, purl, sbom } + } +} #[async_trait] impl Tool for PackageInfo { @@ -43,7 +55,11 @@ Input: The package name, its Identifier URI, or UUID. } async fn run(&self, input: Value) -> Result> { - let (service, sbom_service) = &self.0; + let Self { + purl: service, + sbom: sbom_service, + db, + } = &self; let input = input .as_str() @@ -53,14 +69,14 @@ Input: The package name, its Identifier URI, or UUID. // Try lookup as a PURL let mut purl_details = match Purl::try_from(input.clone()) { Err(_) => None, - Ok(purl) => service.purl_by_purl(&purl, Deprecation::Ignore, ()).await?, + Ok(purl) => service.purl_by_purl(&purl, Deprecation::Ignore, db).await?, }; // Try lookup as a UUID if purl_details.is_none() { purl_details = match Uuid::parse_str(input.as_str()) { Err(_) => None, - Ok(uuid) => service.purl_by_uuid(&uuid, Deprecation::Ignore, ()).await?, + Ok(uuid) => service.purl_by_uuid(&uuid, Deprecation::Ignore, db).await?, }; } @@ -74,7 +90,7 @@ Input: The package name, its Identifier URI, or UUID. ..Default::default() }, Default::default(), - (), + &db, ) .await?; @@ -82,7 +98,7 @@ Input: The package name, its Identifier URI, or UUID. 0 => None, 1 => { service - .purl_by_uuid(&results.items[0].head.uuid, Deprecation::Ignore, ()) + .purl_by_uuid(&results.items[0].head.uuid, Deprecation::Ignore, db) .await? } _ => { @@ -111,7 +127,7 @@ Input: The package name, its Identifier URI, or UUID. }; let sboms = sbom_service - .find_related_sboms(item.head.uuid, Default::default(), Default::default(), ()) + .find_related_sboms(item.head.uuid, Default::default(), Default::default(), db) .await?; #[derive(Serialize)] @@ -205,10 +221,7 @@ mod tests { ctx.ingest_document("quarkus-bom-2.13.8.Final-redhat-00004.json") .await?; - let tool = Rc::new(PackageInfo(( - PurlService::new(ctx.db.clone()), - SbomService::new(ctx.db.clone()), - ))); + let tool = Rc::new(PackageInfo::new(ctx.db.clone())); assert_tool_contains( tool.clone(), diff --git a/modules/fundamental/src/ai/service/tools/product_info.rs b/modules/fundamental/src/ai/service/tools/product_info.rs index afd7c9246..5a1dc2685 100644 --- a/modules/fundamental/src/ai/service/tools/product_info.rs +++ b/modules/fundamental/src/ai/service/tools/product_info.rs @@ -1,15 +1,26 @@ -use crate::ai::service::tools; -use crate::ai::service::tools::input_description; -use crate::product::service::ProductService; +use crate::{ + ai::service::tools::{self, input_description}, + product::service::ProductService, +}; use async_trait::async_trait; use langchain_rust::tools::Tool; use serde::Serialize; use serde_json::Value; use std::error::Error; -use trustify_common::db::query::Query; +use trustify_common::db::{query::Query, Database}; use uuid::Uuid; -pub struct ProductInfo(pub ProductService); +pub struct ProductInfo { + pub db: Database, + pub service: ProductService, +} + +impl ProductInfo { + pub fn new(db: Database) -> Self { + let service = ProductService::new(); + Self { db, service } + } +} #[async_trait] impl Tool for ProductInfo { @@ -17,10 +28,6 @@ impl Tool for ProductInfo { String::from("product-info") } - fn parameters(&self) -> Value { - input_description("The name of the product to search for.") - } - fn description(&self) -> String { String::from( r##" @@ -34,12 +41,16 @@ Products are names of Software Products. Examples: * Quay "## - .trim(), + .trim(), ) } + fn parameters(&self) -> Value { + input_description("The name of the product to search for.") + } + async fn run(&self, input: Value) -> Result> { - let service = &self.0; + let service = &self.service; let input = input .as_str() .ok_or("Input should be a string")? @@ -52,7 +63,7 @@ Products are names of Software Products. Examples: ..Default::default() }, Default::default(), - (), + &self.db, ) .await?; @@ -90,7 +101,7 @@ mod tests { #[test(actix_web::test)] async fn product_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { ingest_fixtures(ctx).await?; - let tool = Rc::new(ProductInfo(ProductService::new(ctx.db.clone()))); + let tool = Rc::new(ProductInfo::new(ctx.db.clone())); assert_tool_contains( tool.clone(), "Trusted Profile Analyzer", diff --git a/modules/fundamental/src/ai/service/tools/sbom_info.rs b/modules/fundamental/src/ai/service/tools/sbom_info.rs index 12b1fa7c9..a32e75622 100644 --- a/modules/fundamental/src/ai/service/tools/sbom_info.rs +++ b/modules/fundamental/src/ai/service/tools/sbom_info.rs @@ -1,20 +1,28 @@ -use crate::ai::service::tools; -use crate::sbom::service::SbomService; - -use crate::ai::service::tools::input_description; +use crate::{ + ai::service::tools::{self, input_description}, + sbom::service::SbomService, +}; use async_trait::async_trait; use itertools::Itertools; use langchain_rust::tools::Tool; use serde::Serialize; use serde_json::Value; -use std::error::Error; -use std::str::FromStr; +use std::{error::Error, str::FromStr}; use time::OffsetDateTime; -use trustify_common::db::query::Query; -use trustify_common::id::Id; +use trustify_common::{db::query::Query, db::Database, id::Id}; use uuid::Uuid; -pub struct SbomInfo(pub SbomService); +pub struct SbomInfo { + pub db: Database, + pub service: SbomService, +} + +impl SbomInfo { + pub fn new(db: Database) -> Self { + let service = SbomService::new(db.clone()); + Self { db, service } + } +} #[async_trait] impl Tool for SbomInfo { @@ -22,17 +30,6 @@ impl Tool for SbomInfo { String::from("sbom-info") } - fn parameters(&self) -> Value { - input_description( - r#" -An SBOM identifier or a product name. -A full SBOM name typically combines the product name and version (e.g., "product-version"). -If a user specifies both, use the product name get a list of best matching SBOMs. -For example, input "quarkus" instead of "quarkus 3.2.11". -"#, - ) - } - fn description(&self) -> String { String::from( r##" @@ -47,8 +44,19 @@ The tool provides a list of advisories/CVEs affecting the SBOM. ) } + fn parameters(&self) -> Value { + input_description( + r#" +An SBOM identifier or a product name. +A full SBOM name typically combines the product name and version (e.g., "product-version"). +If a user specifies both, use the product name get a list of best matching SBOMs. +For example, input "quarkus" instead of "quarkus 3.2.11". +"#, + ) + } + async fn run(&self, input: Value) -> Result> { - let service = &self.0; + let service = &self.service; let input = input .as_str() @@ -59,7 +67,7 @@ The tool provides a list of advisories/CVEs affecting the SBOM. Err(_) => None, Ok(id) => { log::info!("Fetching SBOM details by Id: {}", id); - service.fetch_sbom_details(id, ()).await? + service.fetch_sbom_details(id, &self.db).await? } }; @@ -68,7 +76,7 @@ The tool provides a list of advisories/CVEs affecting the SBOM. Err(_) => None, Ok(id) => { log::info!("Fetching SBOM details by UUID: {}", id); - service.fetch_sbom_details(Id::Uuid(id), ()).await? + service.fetch_sbom_details(Id::Uuid(id), &self.db).await? } }; } @@ -84,7 +92,7 @@ The tool provides a list of advisories/CVEs affecting the SBOM. }, Default::default(), (), - (), + &self.db, ) .await?; @@ -92,7 +100,7 @@ The tool provides a list of advisories/CVEs affecting the SBOM. 0 => None, 1 => { service - .fetch_sbom_details(Id::Uuid(results.items[0].head.id), ()) + .fetch_sbom_details(Id::Uuid(results.items[0].head.id), &self.db) .await? } _ => { @@ -221,7 +229,7 @@ mod tests { ctx.ingest_document("quarkus/v1/quarkus-bom-2.13.8.Final-redhat-00004.json") .await?; - let tool = Rc::new(SbomInfo(SbomService::new(ctx.db.clone()))); + let tool = Rc::new(SbomInfo::new(ctx.db.clone())); assert_tool_contains( tool.clone(), diff --git a/modules/fundamental/src/license/model/mod.rs b/modules/fundamental/src/license/model/mod.rs index c6ca4367f..f4ec7b7c9 100644 --- a/modules/fundamental/src/license/model/mod.rs +++ b/modules/fundamental/src/license/model/mod.rs @@ -1,7 +1,6 @@ use crate::{purl::model::VersionedPurlHead, sbom::model::SbomHead, Error}; -use sea_orm::{ModelTrait, PaginatorTrait}; +use sea_orm::{ConnectionTrait, ModelTrait, PaginatorTrait}; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::{license, purl_license_assertion}; use utoipa::ToSchema; use uuid::Uuid; @@ -18,11 +17,7 @@ pub struct LicenseSummary { } impl LicenseSummary { - pub async fn from_entity( - license: &license::Model, - purls: u64, - _tx: &ConnectionOrTransaction<'_>, - ) -> Result { + pub async fn from_entity(license: &license::Model, purls: u64) -> Result { Ok(LicenseSummary { id: license.id, license: license.text.clone(), @@ -36,18 +31,18 @@ impl LicenseSummary { }) } - pub async fn from_entities( + pub async fn from_entities( licenses: &[license::Model], - tx: &ConnectionOrTransaction<'_>, + connection: &C, ) -> Result, Error> { let mut summaries = Vec::new(); for license in licenses { let purls = license .find_related(purl_license_assertion::Entity) - .count(tx) + .count(connection) .await?; - summaries.push(Self::from_entity(license, purls, tx).await?) + summaries.push(Self::from_entity(license, purls).await?) } Ok(summaries) diff --git a/modules/fundamental/src/license/service/mod.rs b/modules/fundamental/src/license/service/mod.rs index 4567389db..3c44d2977 100644 --- a/modules/fundamental/src/license/service/mod.rs +++ b/modules/fundamental/src/license/service/mod.rs @@ -16,7 +16,7 @@ use trustify_common::{ limiter::{LimiterAsModelTrait, LimiterTrait}, multi_model::{FromQueryResultMultiModel, SelectIntoMultiModel}, query::{Filtering, Query}, - ConnectionOrTransaction, Database, + Database, }, model::{Paginated, PaginatedResults}, }; @@ -38,7 +38,6 @@ impl LicenseService { paginated: Paginated, ) -> Result, Error> { let tx = self.db.begin().await?; - let tx = (&tx).into(); let limiter = license::Entity::find().filtering(search)?.limiting( &self.db, @@ -56,16 +55,13 @@ impl LicenseService { pub async fn get_license(&self, id: Uuid) -> Result, Error> { let tx = self.db.begin().await?; - let tx = (&tx).into(); if let Some(license) = license::Entity::find_by_id(id).one(&tx).await? { let purls = license .find_related(purl_license_assertion::Entity) .count(&tx) .await?; - return Ok(Some( - LicenseSummary::from_entity(&license, purls, &tx).await?, - )); + return Ok(Some(LicenseSummary::from_entity(&license, purls).await?)); } Ok(None) @@ -108,7 +104,6 @@ impl LicenseService { } let tx = self.db.begin().await?; - let tx: ConnectionOrTransaction = (&tx).into(); let licensed_purls = versioned_purl::Entity::find() .join(JoinType::Join, versioned_purl::Relation::BasePurl.def()) diff --git a/modules/fundamental/src/organization/endpoints/mod.rs b/modules/fundamental/src/organization/endpoints/mod.rs index 42371a134..f60e1e9b7 100644 --- a/modules/fundamental/src/organization/endpoints/mod.rs +++ b/modules/fundamental/src/organization/endpoints/mod.rs @@ -14,8 +14,9 @@ use trustify_common::{ use uuid::Uuid; pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, db: Database) { - let service = OrganizationService::new(db); + let service = OrganizationService::new(); config + .app_data(web::Data::new(db)) .app_data(web::Data::new(service)) .service(all) .service(get); @@ -36,11 +37,16 @@ pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, d /// List organizations pub async fn all( state: web::Data, + db: web::Data, web::Query(search): web::Query, web::Query(paginated): web::Query, _: Require, ) -> actix_web::Result { - Ok(HttpResponse::Ok().json(state.fetch_organizations(search, paginated, ()).await?)) + Ok(HttpResponse::Ok().json( + state + .fetch_organizations(search, paginated, db.as_ref()) + .await?, + )) } #[utoipa::path( @@ -58,10 +64,11 @@ pub async fn all( /// Retrieve organization details pub async fn get( state: web::Data, + db: web::Data, id: web::Path, _: Require, ) -> actix_web::Result { - let fetched = state.fetch_organization(*id, ()).await?; + let fetched = state.fetch_organization(*id, db.as_ref()).await?; if let Some(fetched) = fetched { Ok(HttpResponse::Ok().json(fetched)) diff --git a/modules/fundamental/src/organization/endpoints/test.rs b/modules/fundamental/src/organization/endpoints/test.rs index bb357ffbe..ebf4b7959 100644 --- a/modules/fundamental/src/organization/endpoints/test.rs +++ b/modules/fundamental/src/organization/endpoints/test.rs @@ -6,7 +6,6 @@ use serde_json::{json, Value}; use test_context::test_context; use test_log::test; use trustify_common::db::query::Query; -use trustify_common::db::Transactional; use trustify_common::hashing::Digests; use trustify_common::model::Paginated; use trustify_module_ingestor::graph::advisory::AdvisoryInformation; @@ -31,7 +30,7 @@ async fn all_organizations(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; @@ -49,7 +48,7 @@ async fn all_organizations(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; @@ -92,18 +91,18 @@ async fn one_organization(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; advisory - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None, &ctx.db) .await?; - let service = crate::organization::service::OrganizationService::new(ctx.db.clone()); + let service = crate::organization::service::OrganizationService::new(); let orgs = service - .fetch_organizations(Query::default(), Paginated::default(), ()) + .fetch_organizations(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(1, orgs.total); diff --git a/modules/fundamental/src/organization/model/details/mod.rs b/modules/fundamental/src/organization/model/details/mod.rs index 8d32f33d4..de3328c61 100644 --- a/modules/fundamental/src/organization/model/details/mod.rs +++ b/modules/fundamental/src/organization/model/details/mod.rs @@ -1,9 +1,8 @@ -use sea_orm::ModelTrait; +use sea_orm::{ConnectionTrait, ModelTrait}; use serde::{Deserialize, Serialize}; use utoipa::ToSchema; use crate::advisory::model::AdvisoryHead; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::{advisory, organization}; use crate::organization::model::OrganizationHead; @@ -19,13 +18,13 @@ pub struct OrganizationDetails { } impl OrganizationDetails { - pub async fn from_entity( + pub async fn from_entity( org: &organization::Model, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let advisories = org.find_related(advisory::Entity).all(tx).await?; Ok(OrganizationDetails { - head: OrganizationHead::from_entity(org, tx).await?, + head: OrganizationHead::from_entity(org).await?, advisories: AdvisoryHead::from_entities(&advisories, tx).await?, }) } diff --git a/modules/fundamental/src/organization/model/mod.rs b/modules/fundamental/src/organization/model/mod.rs index 7b24ea704..ebc7cca85 100644 --- a/modules/fundamental/src/organization/model/mod.rs +++ b/modules/fundamental/src/organization/model/mod.rs @@ -8,7 +8,6 @@ mod summary; use crate::Error; pub use details::*; pub use summary::*; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::organization; /// An organization who may issue advisories, product SBOMs, or @@ -31,10 +30,7 @@ pub struct OrganizationHead { } impl OrganizationHead { - pub async fn from_entity( - organization: &organization::Model, - _tx: &ConnectionOrTransaction<'_>, - ) -> Result { + pub async fn from_entity(organization: &organization::Model) -> Result { Ok(OrganizationHead { id: organization.id, name: organization.name.clone(), diff --git a/modules/fundamental/src/organization/model/summary.rs b/modules/fundamental/src/organization/model/summary.rs index 9b856edbc..622d73a5e 100644 --- a/modules/fundamental/src/organization/model/summary.rs +++ b/modules/fundamental/src/organization/model/summary.rs @@ -1,7 +1,6 @@ use crate::organization::model::OrganizationHead; use crate::Error; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::organization; use utoipa::ToSchema; @@ -12,24 +11,18 @@ pub struct OrganizationSummary { } impl OrganizationSummary { - pub async fn from_entity( - organization: &organization::Model, - tx: &ConnectionOrTransaction<'_>, - ) -> Result { + pub async fn from_entity(organization: &organization::Model) -> Result { Ok(OrganizationSummary { - head: OrganizationHead::from_entity(organization, tx).await?, + head: OrganizationHead::from_entity(organization).await?, }) } - pub async fn from_entities( - organizations: &[organization::Model], - tx: &ConnectionOrTransaction<'_>, - ) -> Result, Error> { + pub async fn from_entities(organizations: &[organization::Model]) -> Result, Error> { let mut summaries = Vec::new(); for org in organizations { summaries.push(OrganizationSummary { - head: OrganizationHead::from_entity(org, tx).await?, + head: OrganizationHead::from_entity(org).await?, }); } diff --git a/modules/fundamental/src/organization/service/mod.rs b/modules/fundamental/src/organization/service/mod.rs index c07e81b3a..c2c196add 100644 --- a/modules/fundamental/src/organization/service/mod.rs +++ b/modules/fundamental/src/organization/service/mod.rs @@ -1,32 +1,34 @@ -use crate::organization::model::{OrganizationDetails, OrganizationSummary}; -use crate::Error; -use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; -use trustify_common::db::limiter::LimiterTrait; -use trustify_common::db::query::{Filtering, Query}; -use trustify_common::db::{Database, Transactional}; -use trustify_common::model::{Paginated, PaginatedResults}; +use crate::{ + organization::model::{OrganizationDetails, OrganizationSummary}, + Error, +}; +use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter}; +use trustify_common::{ + db::{ + limiter::LimiterTrait, + query::{Filtering, Query}, + }, + model::{Paginated, PaginatedResults}, +}; use trustify_entity::organization; use uuid::Uuid; -pub struct OrganizationService { - db: Database, -} +#[derive(Default)] +pub struct OrganizationService {} impl OrganizationService { - pub fn new(db: Database) -> Self { - Self { db } + pub fn new() -> Self { + Self {} } - pub async fn fetch_organizations + Sync + Send>( + pub async fn fetch_organizations( &self, search: Query, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let limiter = organization::Entity::find().filtering(search)?.limiting( - &connection, + connection, paginated.offset, paginated.limit, ); @@ -35,23 +37,21 @@ impl OrganizationService { Ok(PaginatedResults { total, - items: OrganizationSummary::from_entities(&limiter.fetch().await?, &connection).await?, + items: OrganizationSummary::from_entities(&limiter.fetch().await?).await?, }) } - pub async fn fetch_organization + Sync + Send>( + pub async fn fetch_organization( &self, id: Uuid, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - if let Some(organization) = organization::Entity::find() .filter(organization::Column::Id.eq(id)) - .one(&connection) + .one(connection) .await? { Ok(Some( - OrganizationDetails::from_entity(&organization, &connection).await?, + OrganizationDetails::from_entity(&organization, connection).await?, )) } else { Ok(None) diff --git a/modules/fundamental/src/organization/service/test.rs b/modules/fundamental/src/organization/service/test.rs index 94d70d2ce..2608864b1 100644 --- a/modules/fundamental/src/organization/service/test.rs +++ b/modules/fundamental/src/organization/service/test.rs @@ -24,14 +24,14 @@ async fn all_organizations(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; - let service = crate::organization::service::OrganizationService::new(ctx.db.clone()); + let service = crate::organization::service::OrganizationService::new(); let orgs = service - .fetch_organizations(Query::default(), Paginated::default(), ()) + .fetch_organizations(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(1, orgs.total); diff --git a/modules/fundamental/src/product/endpoints/mod.rs b/modules/fundamental/src/product/endpoints/mod.rs index 8060f9f8b..2402f9dfd 100644 --- a/modules/fundamental/src/product/endpoints/mod.rs +++ b/modules/fundamental/src/product/endpoints/mod.rs @@ -6,9 +6,10 @@ use crate::{ model::{details::ProductDetails, summary::ProductSummary}, service::ProductService, }, - Error::Internal, + Error, }; use actix_web::{delete, get, web, HttpResponse, Responder}; +use sea_orm::TransactionTrait; use trustify_auth::{authorizer::Require, DeleteMetadata, ReadMetadata}; use trustify_common::{ db::{query::Query, Database}, @@ -17,8 +18,9 @@ use trustify_common::{ use uuid::Uuid; pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, db: Database) { - let service = ProductService::new(db); + let service = ProductService::new(); config + .app_data(web::Data::new(db)) .app_data(web::Data::new(service)) .service(all) .service(delete) @@ -39,11 +41,12 @@ pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, d #[get("/v1/product")] pub async fn all( state: web::Data, + db: web::Data, web::Query(search): web::Query, web::Query(paginated): web::Query, _: Require, ) -> actix_web::Result { - Ok(HttpResponse::Ok().json(state.fetch_products(search, paginated, ()).await?)) + Ok(HttpResponse::Ok().json(state.fetch_products(search, paginated, db.as_ref()).await?)) } #[utoipa::path( @@ -60,10 +63,11 @@ pub async fn all( #[get("/v1/product/{id}")] pub async fn get( state: web::Data, + db: web::Data, id: web::Path, _: Require, ) -> actix_web::Result { - let fetched = state.fetch_product(*id, ()).await?; + let fetched = state.fetch_product(*id, db.as_ref()).await?; if let Some(fetched) = fetched { Ok(HttpResponse::Ok().json(fetched)) } else { @@ -85,16 +89,22 @@ pub async fn get( #[delete("/v1/product/{id}")] pub async fn delete( state: web::Data, + db: web::Data, id: web::Path, _: Require, -) -> actix_web::Result { - match state.fetch_product(*id, ()).await? { +) -> Result { + let tx = db.begin().await?; + + match state.fetch_product(*id, &tx).await? { Some(v) => { - let rows_affected = state.delete_product(v.head.id, ()).await?; + let rows_affected = state.delete_product(v.head.id, &tx).await?; match rows_affected { 0 => Ok(HttpResponse::NotFound().finish()), - 1 => Ok(HttpResponse::Ok().json(v)), - _ => Err(Internal("Unexpected number of rows affected".into()).into()), + 1 => { + tx.commit().await?; + Ok(HttpResponse::Ok().json(v)) + } + _ => Err(Error::Internal("Unexpected number of rows affected".into())), } } None => Ok(HttpResponse::NotFound().finish()), diff --git a/modules/fundamental/src/product/endpoints/test.rs b/modules/fundamental/src/product/endpoints/test.rs index 5cb6344b4..cb9e33945 100644 --- a/modules/fundamental/src/product/endpoints/test.rs +++ b/modules/fundamental/src/product/endpoints/test.rs @@ -22,7 +22,7 @@ async fn all_products(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { vendor: Some("Red Hat".to_string()), cpe: None, }, - (), + &ctx.db, ) .await?; @@ -33,7 +33,7 @@ async fn all_products(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { vendor: Some("Red Hat".to_string()), cpe: None, }, - (), + &ctx.db, ) .await?; @@ -62,14 +62,14 @@ async fn one_product(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { vendor: Some("Red Hat".to_string()), cpe: None, }, - (), + &ctx.db, ) .await?; - let service = crate::product::service::ProductService::new(ctx.db.clone()); + let service = crate::product::service::ProductService::new(); let products = service - .fetch_products(Query::default(), Paginated::default(), ()) + .fetch_products(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(1, products.total); @@ -102,14 +102,14 @@ async fn delete_product(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { vendor: Some("Red Hat".to_string()), cpe: None, }, - (), + &ctx.db, ) .await?; - let service = crate::product::service::ProductService::new(ctx.db.clone()); + let service = crate::product::service::ProductService::new(); let products = service - .fetch_products(Query::default(), Paginated::default(), ()) + .fetch_products(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(1, products.total); @@ -126,7 +126,7 @@ async fn delete_product(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert_eq!(response.status(), StatusCode::OK); let products = service - .fetch_products(Query::default(), Paginated::default(), ()) + .fetch_products(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(0, products.total); diff --git a/modules/fundamental/src/product/model/details.rs b/modules/fundamental/src/product/model/details.rs index 7bf472f45..363e14fd6 100644 --- a/modules/fundamental/src/product/model/details.rs +++ b/modules/fundamental/src/product/model/details.rs @@ -2,11 +2,10 @@ use crate::organization::model::OrganizationSummary; use crate::product::model::{ProductHead, ProductVersionHead}; use crate::Error; use itertools::izip; -use sea_orm::LoaderTrait; use sea_orm::ModelTrait; +use sea_orm::{ConnectionTrait, LoaderTrait}; use serde::{Deserialize, Serialize}; use time::OffsetDateTime; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::labels::Labels; use trustify_entity::{organization, product, product_version, sbom}; use utoipa::ToSchema; @@ -21,22 +20,22 @@ pub struct ProductDetails { } impl ProductDetails { - pub async fn from_entity( + pub async fn from_entity( product: &product::Model, org: Option, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let product_versions = product .find_related(product_version::Entity) .all(tx) .await?; let vendor = if let Some(org) = org { - Some(OrganizationSummary::from_entity(&org, tx).await?) + Some(OrganizationSummary::from_entity(&org).await?) } else { None }; Ok(ProductDetails { - head: ProductHead::from_entity(product, tx).await?, + head: ProductHead::from_entity(product).await?, versions: ProductVersionDetails::from_entities(&product_versions, tx).await?, vendor, }) @@ -54,29 +53,28 @@ impl ProductVersionDetails { pub async fn from_entity( product_version: &product_version::Model, sbom: Option, - tx: &ConnectionOrTransaction<'_>, ) -> Result { let sbom = if let Some(sbom) = sbom { - Some(ProductSbomHead::from_entity(&sbom, tx).await?) + Some(ProductSbomHead::from_entity(&sbom).await?) } else { None }; Ok(ProductVersionDetails { - head: ProductVersionHead::from_entity(product_version, tx).await?, + head: ProductVersionHead::from_entity(product_version).await?, sbom, }) } - pub async fn from_entities( + pub async fn from_entities( product_versions: &[product_version::Model], - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let mut details = Vec::new(); let sboms = product_versions.load_one(sbom::Entity, tx).await?; for (version, sbom) in izip!(product_versions, sboms) { - details.push(ProductVersionDetails::from_entity(version, sbom, tx).await?); + details.push(ProductVersionDetails::from_entity(version, sbom).await?); } Ok(details) @@ -92,10 +90,7 @@ pub struct ProductSbomHead { } impl ProductSbomHead { - pub async fn from_entity( - sbom: &trustify_entity::sbom::Model, - _tx: &ConnectionOrTransaction<'_>, - ) -> Result { + pub async fn from_entity(sbom: &sbom::Model) -> Result { Ok(ProductSbomHead { labels: sbom.labels.clone(), published: sbom.published, diff --git a/modules/fundamental/src/product/model/mod.rs b/modules/fundamental/src/product/model/mod.rs index a129e1d1d..4313ee03c 100644 --- a/modules/fundamental/src/product/model/mod.rs +++ b/modules/fundamental/src/product/model/mod.rs @@ -6,7 +6,6 @@ pub mod details; pub mod summary; use crate::Error; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::{product, product_version}; #[derive(Serialize, Deserialize, Debug, Clone, ToSchema)] @@ -18,10 +17,7 @@ pub struct ProductHead { } impl ProductHead { - pub async fn from_entity( - product: &product::Model, - _tx: &ConnectionOrTransaction<'_>, - ) -> Result { + pub async fn from_entity(product: &product::Model) -> Result { Ok(ProductHead { id: product.id, name: product.name.clone(), @@ -45,10 +41,7 @@ pub struct ProductVersionHead { } impl ProductVersionHead { - pub async fn from_entity( - product_version: &product_version::Model, - _tx: &ConnectionOrTransaction<'_>, - ) -> Result { + pub async fn from_entity(product_version: &product_version::Model) -> Result { Ok(ProductVersionHead { id: product_version.id, version: product_version.version.clone(), @@ -58,12 +51,11 @@ impl ProductVersionHead { pub async fn from_entities( product_versions: &[product_version::Model], - tx: &ConnectionOrTransaction<'_>, ) -> Result, Error> { let mut heads = Vec::new(); for entity in product_versions { - heads.push(ProductVersionHead::from_entity(entity, tx).await?); + heads.push(ProductVersionHead::from_entity(entity).await?); } Ok(heads) diff --git a/modules/fundamental/src/product/model/summary.rs b/modules/fundamental/src/product/model/summary.rs index de1b9a386..446da081e 100644 --- a/modules/fundamental/src/product/model/summary.rs +++ b/modules/fundamental/src/product/model/summary.rs @@ -2,9 +2,8 @@ use crate::organization::model::OrganizationSummary; use crate::product::model::{ProductHead, ProductVersionHead}; use crate::Error; use itertools::izip; -use sea_orm::LoaderTrait; +use sea_orm::{ConnectionTrait, LoaderTrait}; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::{organization, product, product_version}; use utoipa::ToSchema; @@ -22,23 +21,22 @@ impl ProductSummary { product: &product::Model, org: Option, versions: &[product_version::Model], - tx: &ConnectionOrTransaction<'_>, ) -> Result { let vendor = if let Some(org) = org { - Some(OrganizationSummary::from_entity(&org, tx).await?) + Some(OrganizationSummary::from_entity(&org).await?) } else { None }; Ok(ProductSummary { - head: ProductHead::from_entity(product, tx).await?, - versions: ProductVersionHead::from_entities(versions, tx).await?, + head: ProductHead::from_entity(product).await?, + versions: ProductVersionHead::from_entities(versions).await?, vendor, }) } - pub async fn from_entities( + pub async fn from_entities( products: &[product::Model], - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let versions = products.load_many(product_version::Entity, tx).await?; let orgs = products.load_one(organization::Entity, tx).await?; @@ -46,7 +44,7 @@ impl ProductSummary { let mut summaries = Vec::new(); for (product, org, version) in izip!(products, orgs, versions) { - summaries.push(ProductSummary::from_entity(product, org, &version, tx).await?); + summaries.push(ProductSummary::from_entity(product, org, &version).await?); } Ok(summaries) diff --git a/modules/fundamental/src/product/service/mod.rs b/modules/fundamental/src/product/service/mod.rs index 7853e68c1..7d12803d7 100644 --- a/modules/fundamental/src/product/service/mod.rs +++ b/modules/fundamental/src/product/service/mod.rs @@ -1,33 +1,32 @@ use super::model::summary::ProductSummary; -use crate::product::model::details::ProductDetails; -use crate::Error; -use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; -use trustify_common::db::limiter::LimiterTrait; -use trustify_common::db::query::{Filtering, Query}; -use trustify_common::db::{Database, Transactional}; -use trustify_common::model::{Paginated, PaginatedResults}; +use crate::{product::model::details::ProductDetails, Error}; +use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter}; +use trustify_common::{ + db::{ + limiter::LimiterTrait, + query::{Filtering, Query}, + }, + model::{Paginated, PaginatedResults}, +}; use trustify_entity::product; use uuid::Uuid; -pub struct ProductService { - db: Database, -} +#[derive(Default)] +pub struct ProductService {} impl ProductService { - pub fn new(db: Database) -> Self { - Self { db } + pub fn new() -> Self { + Self {} } - pub async fn fetch_products + Sync + Send>( + pub async fn fetch_products( &self, search: Query, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let limiter = product::Entity::find().filtering(search)?.limiting( - &connection, + connection, paginated.offset, paginated.limit, ); @@ -36,41 +35,37 @@ impl ProductService { Ok(PaginatedResults { total, - items: ProductSummary::from_entities(&limiter.fetch().await?, &connection).await?, + items: ProductSummary::from_entities(&limiter.fetch().await?, connection).await?, }) } - pub async fn fetch_product + Sync + Send>( + pub async fn fetch_product( &self, id: Uuid, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - if let Some(product) = product::Entity::find() .find_also_related(trustify_entity::organization::Entity) .filter(product::Column::Id.eq(id)) - .one(&connection) + .one(connection) .await? { Ok(Some( - ProductDetails::from_entity(&product.0, product.1, &connection).await?, + ProductDetails::from_entity(&product.0, product.1, connection).await?, )) } else { Ok(None) } } - pub async fn delete_product + Sync + Send>( + pub async fn delete_product( &self, id: Uuid, - tx: TX, + connection: &C, ) -> Result { - let connection = self.db.connection(&tx); - let query = product::Entity::delete_by_id(id); - let result = query.exec(&connection).await?; + let result = query.exec(connection).await?; Ok(result.rows_affected) } diff --git a/modules/fundamental/src/product/service/test.rs b/modules/fundamental/src/product/service/test.rs index 30616c07a..0887dfe5d 100644 --- a/modules/fundamental/src/product/service/test.rs +++ b/modules/fundamental/src/product/service/test.rs @@ -3,7 +3,6 @@ use test_context::test_context; use test_log::test; use trustify_common::cpe::Cpe; use trustify_common::db::query::Query; -use trustify_common::db::Transactional; use trustify_common::hashing::Digests; use trustify_common::model::Paginated; use trustify_module_ingestor::graph::product::ProductInformation; @@ -19,7 +18,7 @@ async fn all_products(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { &Digests::digest("RHSA-1"), "a", (), - Transactional::None, + &ctx.db, ) .await?; @@ -31,27 +30,24 @@ async fn all_products(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { vendor: Some("Red Hat".to_string()), cpe: None, }, - (), + &ctx.db, ) .await?; let ver = pr - .ingest_product_version("1.0.0".to_string(), Some(sbom.sbom.sbom_id), ()) + .ingest_product_version("1.0.0".to_string(), Some(sbom.sbom.sbom_id), &ctx.db) .await?; - let service = crate::product::service::ProductService::new(ctx.db.clone()); + let service = crate::product::service::ProductService::new(); let prods = service - .fetch_products(Query::default(), Paginated::default(), ()) + .fetch_products(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(1, prods.total); assert_eq!(1, prods.items.len()); - let ver_sbom = ver - .get_sbom(Transactional::None) - .await? - .expect("No sbom found"); + let ver_sbom = ver.get_sbom(&ctx.db).await?.expect("No sbom found"); assert_eq!(ver_sbom.sbom.sbom_id, sbom.sbom.sbom_id); Ok(()) @@ -68,12 +64,12 @@ async fn link_sbom_to_product(ctx: &TrustifyContext) -> Result<(), anyhow::Error vendor: Some("Red Hat".to_string()), cpe: Some(Cpe::from_str("cpe:/a:redhat:tpa:2.0.0")?), }, - (), + &ctx.db, ) .await?; let prv = pr - .ingest_product_version("1.0.0".to_string(), None, ()) + .ingest_product_version("1.0.0".to_string(), None, &ctx.db) .await?; let sbom = ctx @@ -83,21 +79,18 @@ async fn link_sbom_to_product(ctx: &TrustifyContext) -> Result<(), anyhow::Error &Digests::digest("RHSA-1"), "a", (), - Transactional::None, + &ctx.db, ) .await?; - let prv = sbom.link_to_product(prv, Transactional::None).await?; + let prv = sbom.link_to_product(prv, &ctx.db).await?; assert_eq!( sbom.sbom.sbom_id, prv.product_version.sbom_id.expect("no sbom") ); - let product = sbom - .get_product(Transactional::None) - .await? - .expect("No product"); + let product = sbom.get_product(&ctx.db).await?.expect("No product"); assert_eq!("Trusted Profile Analyzer", product.product.product.name); assert_eq!("1.0.0", product.product_version.version); @@ -112,7 +105,7 @@ async fn link_sbom_to_product(ctx: &TrustifyContext) -> Result<(), anyhow::Error let org = product .product - .get_vendor(Transactional::None) + .get_vendor(&ctx.db) .await? .expect("no organization"); assert_eq!("Red Hat", org.organization.name); @@ -132,23 +125,23 @@ async fn delete_product(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { vendor: Some("Red Hat".to_string()), cpe: None, }, - (), + &ctx.db, ) .await?; - let service = crate::product::service::ProductService::new(ctx.db.clone()); + let service = crate::product::service::ProductService::new(); let prods = service - .fetch_products(Query::default(), Paginated::default(), ()) + .fetch_products(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(1, prods.total); assert_eq!(1, prods.items.len()); - let result = service.delete_product(pr.product.id, ()).await?; + let result = service.delete_product(pr.product.id, &ctx.db).await?; assert_eq!(1, result); - let result = service.delete_product(pr.product.id, ()).await?; + let result = service.delete_product(pr.product.id, &ctx.db).await?; assert_eq!(0, result); Ok(()) diff --git a/modules/fundamental/src/purl/endpoints/base.rs b/modules/fundamental/src/purl/endpoints/base.rs index 7f4375efc..78362ae63 100644 --- a/modules/fundamental/src/purl/endpoints/base.rs +++ b/modules/fundamental/src/purl/endpoints/base.rs @@ -10,7 +10,10 @@ use sea_orm::prelude::Uuid; use std::str::FromStr; use trustify_auth::{authorizer::Require, ReadSbom}; use trustify_common::{ - db::query::Query, id::IdError, model::Paginated, model::PaginatedResults, purl::Purl, + db::{query::Query, Database}, + id::IdError, + model::{Paginated, PaginatedResults}, + purl::Purl, }; #[utoipa::path( @@ -27,15 +30,16 @@ use trustify_common::{ /// Retrieve details about a base versionless pURL pub async fn get_base_purl( service: web::Data, + db: web::Data, key: web::Path, _: Require, ) -> actix_web::Result { if key.starts_with("pkg:") { let purl = Purl::from_str(&key).map_err(|e| Error::IdKey(IdError::Purl(e)))?; - Ok(HttpResponse::Ok().json(service.base_purl_by_purl(&purl, ()).await?)) + Ok(HttpResponse::Ok().json(service.base_purl_by_purl(&purl, db.as_ref()).await?)) } else { let uuid = Uuid::from_str(&key).map_err(|e| Error::IdKey(IdError::InvalidUuid(e)))?; - Ok(HttpResponse::Ok().json(service.base_purl_by_uuid(&uuid, ()).await?)) + Ok(HttpResponse::Ok().json(service.base_purl_by_uuid(&uuid, db.as_ref()).await?)) } } @@ -54,8 +58,9 @@ pub async fn get_base_purl( /// List base versionless pURLs pub async fn all_base_purls( service: web::Data, + db: web::Data, web::Query(search): web::Query, web::Query(paginated): web::Query, ) -> actix_web::Result { - Ok(HttpResponse::Ok().json(service.base_purls(search, paginated, ()).await?)) + Ok(HttpResponse::Ok().json(service.base_purls(search, paginated, db.as_ref()).await?)) } diff --git a/modules/fundamental/src/purl/endpoints/mod.rs b/modules/fundamental/src/purl/endpoints/mod.rs index 357b33261..561e97b6e 100644 --- a/modules/fundamental/src/purl/endpoints/mod.rs +++ b/modules/fundamental/src/purl/endpoints/mod.rs @@ -20,9 +20,10 @@ mod r#type; mod version; pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, db: Database) { - let purl_service = PurlService::new(db); + let purl_service = PurlService::new(); config + .app_data(web::Data::new(db)) .app_data(web::Data::new(purl_service)) .service(r#type::all_purl_types) .service(r#type::get_purl_type) @@ -50,16 +51,17 @@ pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, d /// Retrieve details of a fully-qualified pURL pub async fn get( service: web::Data, + db: web::Data, key: web::Path, web::Query(Deprecation { deprecated }): web::Query, _: Require, ) -> actix_web::Result { if key.starts_with("pkg") { let purl = Purl::from_str(&key).map_err(Error::Purl)?; - Ok(HttpResponse::Ok().json(service.purl_by_purl(&purl, deprecated, ()).await?)) + Ok(HttpResponse::Ok().json(service.purl_by_purl(&purl, deprecated, db.as_ref()).await?)) } else { let id = Uuid::from_str(&key).map_err(|e| Error::IdKey(IdError::InvalidUuid(e)))?; - Ok(HttpResponse::Ok().json(service.purl_by_uuid(&id, deprecated, ()).await?)) + Ok(HttpResponse::Ok().json(service.purl_by_uuid(&id, deprecated, db.as_ref()).await?)) } } @@ -78,11 +80,12 @@ pub async fn get( /// List fully-qualified pURLs pub async fn all( service: web::Data, + db: web::Data, web::Query(search): web::Query, web::Query(paginated): web::Query, _: Require, ) -> actix_web::Result { - Ok(HttpResponse::Ok().json(service.purls(search, paginated, ()).await?)) + Ok(HttpResponse::Ok().json(service.purls(search, paginated, db.as_ref()).await?)) } #[cfg(test)] diff --git a/modules/fundamental/src/purl/endpoints/test.rs b/modules/fundamental/src/purl/endpoints/test.rs index 65804cfa5..363380b0c 100644 --- a/modules/fundamental/src/purl/endpoints/test.rs +++ b/modules/fundamental/src/purl/endpoints/test.rs @@ -10,59 +10,59 @@ use serde_json::Value; use std::str::FromStr; use test_context::test_context; use test_log::test; -use trustify_common::db::Transactional; +use trustify_common::db::Database; use trustify_common::model::PaginatedResults; use trustify_common::purl::Purl; use trustify_module_ingestor::graph::Graph; use trustify_test_context::{call::CallService, TrustifyContext}; -async fn setup(graph: &Graph) -> Result<(), anyhow::Error> { +async fn setup(db: &Database, graph: &Graph) -> Result<(), anyhow::Error> { let log4j = graph - .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, db) .await?; let log4j_123 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, db) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?jdk=11")?, - (), + db, ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?jdk=17")?, - (), + db, ) .await?; let log4j_345 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@3.4.5")?, ()) + .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@3.4.5")?, db) .await?; log4j_345 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5?repository_url=http://jboss.org/")?, - (), + db, ) .await?; log4j_345 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5?repository_url=http://jboss.org/")?, - (), + db, ) .await?; let sendmail = graph - .ingest_package(&Purl::from_str("pkg:rpm/sendmail")?, ()) + .ingest_package(&Purl::from_str("pkg:rpm/sendmail")?, db) .await?; let _sendmail_444 = sendmail - .ingest_package_version(&Purl::from_str("pkg:rpm/sendmail@4.4.4")?, ()) + .ingest_package_version(&Purl::from_str("pkg:rpm/sendmail@4.4.4")?, db) .await?; Ok(()) @@ -71,7 +71,7 @@ async fn setup(graph: &Graph) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn types(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - setup(&ctx.graph).await?; + setup(&ctx.db, &ctx.graph).await?; let app = caller(ctx).await?; let uri = "/api/v1/purl/type"; @@ -99,7 +99,7 @@ async fn types(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn r#type(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - setup(&ctx.graph).await?; + setup(&ctx.db, &ctx.graph).await?; let app = caller(ctx).await?; let uri = "/api/v1/purl/type/maven"; @@ -130,7 +130,7 @@ async fn r#type(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn type_package(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - setup(&ctx.graph).await?; + setup(&ctx.db, &ctx.graph).await?; let app = caller(ctx).await?; let uri = "/api/v1/purl/type/maven/org.apache/log4j"; @@ -161,7 +161,7 @@ async fn type_package(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn type_package_version(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - setup(&ctx.graph).await?; + setup(&ctx.db, &ctx.graph).await?; let app = caller(ctx).await?; let uri = "/api/v1/purl/type/maven/org.apache/log4j@1.2.3"; @@ -188,7 +188,7 @@ async fn type_package_version(ctx: &TrustifyContext) -> Result<(), anyhow::Error #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn package(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - setup(&ctx.graph).await?; + setup(&ctx.db, &ctx.graph).await?; let app = caller(ctx).await?; let uri = "/api/v1/purl/type/maven/org.apache/log4j@1.2.3"; @@ -218,7 +218,7 @@ async fn package(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn version(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - setup(&ctx.graph).await?; + setup(&ctx.db, &ctx.graph).await?; let app = caller(ctx).await?; let uri = "/api/v1/purl/type/maven/org.apache/log4j@1.2.3"; @@ -238,7 +238,7 @@ async fn version(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn base(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - setup(&ctx.graph).await?; + setup(&ctx.db, &ctx.graph).await?; let app = caller(ctx).await?; let uri = "/api/v1/purl/type/maven/org.apache/log4j"; @@ -257,7 +257,7 @@ async fn base(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn base_packages(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - setup(&ctx.graph).await?; + setup(&ctx.db, &ctx.graph).await?; let app = caller(ctx).await?; let uri = "/api/v1/purl/base?q=log4j"; @@ -272,7 +272,7 @@ async fn base_packages(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn qualified_packages(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - setup(&ctx.graph).await?; + setup(&ctx.db, &ctx.graph).await?; let app = caller(ctx).await?; let uri = "/api/v1/purl?q=log4j"; @@ -287,7 +287,7 @@ async fn qualified_packages(ctx: &TrustifyContext) -> Result<(), anyhow::Error> #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn qualified_packages_filtering(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - setup(&ctx.graph).await?; + setup(&ctx.db, &ctx.graph).await?; let app = caller(ctx).await?; let uri = "/api/v1/purl?q=type%3Dmaven"; @@ -304,10 +304,7 @@ async fn qualified_packages_filtering(ctx: &TrustifyContext) -> Result<(), anyho async fn package_with_status(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { ctx.ingestor .graph() - .ingest_qualified_package( - &Purl::from_str("pkg:cargo/hyper@0.14.1")?, - Transactional::None, - ) + .ingest_qualified_package(&Purl::from_str("pkg:cargo/hyper@0.14.1")?, &ctx.db) .await?; ctx.ingest_documents(["osv/RUSTSEC-2021-0079.json", "cve/CVE-2021-32714.json"]) diff --git a/modules/fundamental/src/purl/endpoints/type.rs b/modules/fundamental/src/purl/endpoints/type.rs index fbb8f6b6c..ca987f6ea 100644 --- a/modules/fundamental/src/purl/endpoints/type.rs +++ b/modules/fundamental/src/purl/endpoints/type.rs @@ -7,7 +7,10 @@ use crate::purl::{ }; use actix_web::{get, web, HttpResponse, Responder}; use trustify_auth::{authorizer::Require, ReadSbom}; -use trustify_common::{db::query::Query, model::Paginated, model::PaginatedResults}; +use trustify_common::{ + db::{query::Query, Database}, + model::{Paginated, PaginatedResults}, +}; #[utoipa::path( tag = "purl type", @@ -22,9 +25,10 @@ use trustify_common::{db::query::Query, model::Paginated, model::PaginatedResult /// List known pURL types pub async fn all_purl_types( service: web::Data, + db: web::Data, _: Require, ) -> actix_web::Result { - Ok(HttpResponse::Ok().json(service.purl_types(()).await?)) + Ok(HttpResponse::Ok().json(service.purl_types(db.as_ref()).await?)) } #[utoipa::path( @@ -43,6 +47,7 @@ pub async fn all_purl_types( /// Retrieve details about a pURL type pub async fn get_purl_type( service: web::Data, + db: web::Data, r#type: web::Path, web::Query(search): web::Query, web::Query(paginated): web::Query, @@ -50,7 +55,7 @@ pub async fn get_purl_type( ) -> actix_web::Result { Ok(HttpResponse::Ok().json( service - .base_purls_by_type(&r#type, search, paginated, ()) + .base_purls_by_type(&r#type, search, paginated, db.as_ref()) .await?, )) } @@ -71,6 +76,7 @@ pub async fn get_purl_type( /// Retrieve base pURL details of a type pub async fn get_base_purl_of_type( service: web::Data, + db: web::Data, path: web::Path<(String, String)>, _: Require, ) -> actix_web::Result { @@ -82,7 +88,11 @@ pub async fn get_base_purl_of_type( (None, namespace_and_name) }; - Ok(HttpResponse::Ok().json(service.base_purl(&r#type, namespace, &name, ()).await?)) + Ok(HttpResponse::Ok().json( + service + .base_purl(&r#type, namespace, &name, db.as_ref()) + .await?, + )) } #[utoipa::path( @@ -101,6 +111,7 @@ pub async fn get_base_purl_of_type( /// Retrieve versioned pURL details of a type pub async fn get_versioned_purl_of_type( service: web::Data, + db: web::Data, path: web::Path<(String, String, String)>, _: Require, ) -> actix_web::Result { @@ -114,7 +125,7 @@ pub async fn get_versioned_purl_of_type( Ok(HttpResponse::Ok().json( service - .versioned_purl(&r#type, namespace, &name, &version, ()) + .versioned_purl(&r#type, namespace, &name, &version, db.as_ref()) .await?, )) } diff --git a/modules/fundamental/src/purl/endpoints/version.rs b/modules/fundamental/src/purl/endpoints/version.rs index dbacc074e..b5aaf1a81 100644 --- a/modules/fundamental/src/purl/endpoints/version.rs +++ b/modules/fundamental/src/purl/endpoints/version.rs @@ -5,9 +5,8 @@ use crate::{ use actix_web::{get, web, HttpResponse, Responder}; use sea_orm::prelude::Uuid; use std::str::FromStr; -use trustify_auth::authorizer::Require; -use trustify_auth::ReadSbom; -use trustify_common::{id::IdError, purl::Purl}; +use trustify_auth::{authorizer::Require, ReadSbom}; +use trustify_common::{db::Database, id::IdError, purl::Purl}; #[utoipa::path( tag = "versioned purl", @@ -23,14 +22,15 @@ use trustify_common::{id::IdError, purl::Purl}; /// Retrieve details of a versioned, non-qualified pURL pub async fn get_versioned_purl( service: web::Data, + db: web::Data, key: web::Path, _: Require, ) -> actix_web::Result { if key.starts_with("pkg:") { let purl = Purl::from_str(&key).map_err(|e| Error::IdKey(IdError::Purl(e)))?; - Ok(HttpResponse::Ok().json(service.versioned_purl_by_purl(&purl, ()).await?)) + Ok(HttpResponse::Ok().json(service.versioned_purl_by_purl(&purl, db.as_ref()).await?)) } else { let uuid = Uuid::from_str(&key).map_err(|e| Error::IdKey(IdError::InvalidUuid(e)))?; - Ok(HttpResponse::Ok().json(service.versioned_purl_by_uuid(&uuid, ()).await?)) + Ok(HttpResponse::Ok().json(service.versioned_purl_by_uuid(&uuid, db.as_ref()).await?)) } } diff --git a/modules/fundamental/src/purl/model/details/base_purl.rs b/modules/fundamental/src/purl/model/details/base_purl.rs index d96a1b06c..8ea62bf97 100644 --- a/modules/fundamental/src/purl/model/details/base_purl.rs +++ b/modules/fundamental/src/purl/model/details/base_purl.rs @@ -1,9 +1,8 @@ use crate::purl::model::summary::versioned_purl::VersionedPurlSummary; use crate::purl::model::BasePurlHead; use crate::Error; -use sea_orm::ModelTrait; +use sea_orm::{ConnectionTrait, ModelTrait}; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::{base_purl, versioned_purl}; use utoipa::ToSchema; @@ -15,14 +14,14 @@ pub struct BasePurlDetails { } impl BasePurlDetails { - pub async fn from_entity( + pub async fn from_entity( package: &base_purl::Model, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let package_versions = package.find_related(versioned_purl::Entity).all(tx).await?; Ok(Self { - head: BasePurlHead::from_entity(package, tx).await?, + head: BasePurlHead::from_entity(package).await?, versions: VersionedPurlSummary::from_entities_with_common_package( package, &package_versions, diff --git a/modules/fundamental/src/purl/model/details/purl.rs b/modules/fundamental/src/purl/model/details/purl.rs index 2dd8b0866..8587ad209 100644 --- a/modules/fundamental/src/purl/model/details/purl.rs +++ b/modules/fundamental/src/purl/model/details/purl.rs @@ -1,24 +1,25 @@ -use crate::sbom::model::SbomHead; use crate::{ advisory::model::AdvisoryHead, purl::model::{BasePurlHead, PurlHead, VersionedPurlHead}, + sbom::model::SbomHead, vulnerability::model::VulnerabilityHead, Error, }; use ::cpe::uri::OwnedUri; use sea_orm::{ - ColumnTrait, DbErr, EntityTrait, FromQueryResult, LoaderTrait, ModelTrait, QueryFilter, - QueryOrder, QueryResult, QuerySelect, QueryTrait, RelationTrait, Select, + ColumnTrait, ConnectionTrait, DbErr, EntityTrait, FromQueryResult, LoaderTrait, ModelTrait, + QueryFilter, QueryOrder, QueryResult, QuerySelect, QueryTrait, RelationTrait, Select, }; use sea_query::{Asterisk, ColumnRef, Expr, Func, IntoIden, JoinType, SimpleExpr}; use serde::{Deserialize, Serialize}; -use std::collections::hash_map::Entry; -use std::collections::HashMap; +use std::{collections::hash_map::Entry, collections::HashMap}; use strum::IntoEnumIterator; -use trustify_common::db::multi_model::{FromQueryResultMultiModel, SelectIntoMultiModel}; -use trustify_common::db::{ConnectionOrTransaction, VersionMatches}; -use trustify_common::memo::Memo; -use trustify_common::purl::Purl; +use trustify_common::{ + db::multi_model::{FromQueryResultMultiModel, SelectIntoMultiModel}, + db::VersionMatches, + memo::Memo, + purl::Purl, +}; use trustify_entity::{ advisory, base_purl, cpe, license, organization, product, product_status, product_version, product_version_range, purl_license_assertion, purl_status, qualified_purl, sbom, sbom_package, @@ -39,12 +40,12 @@ pub struct PurlDetails { } impl PurlDetails { - pub async fn from_entity( + pub async fn from_entity( package: Option, package_version: Option, qualified_package: &qualified_purl::Model, deprecation: Deprecation, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let package_version = if let Some(package_version) = package_version { package_version @@ -115,15 +116,15 @@ impl PurlDetails { Ok(PurlDetails { head: PurlHead::from_entity(&package, &package_version, qualified_package, tx).await?, version: VersionedPurlHead::from_entity(&package, &package_version, tx).await?, - base: BasePurlHead::from_entity(&package, tx).await?, + base: BasePurlHead::from_entity(&package).await?, advisories: PurlAdvisory::from_entities(purl_statuses, product_statuses, tx).await?, licenses: PurlLicenseSummary::from_entities(&licenses, tx).await?, }) } } -async fn get_product_statuses_for_purl( - tx: &ConnectionOrTransaction<'_>, +async fn get_product_statuses_for_purl( + tx: &C, qualified_package_id: Uuid, purl_name: &str, namespace_name: Option<&str>, @@ -188,10 +189,10 @@ pub struct PurlAdvisory { } impl PurlAdvisory { - pub async fn from_entities( + pub async fn from_entities( purl_statuses: Vec, product_statuses: Vec, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let vulns = purl_statuses.load_one(vulnerability::Entity, tx).await?; @@ -306,10 +307,10 @@ pub enum StatusContext { } impl PurlStatus { - pub async fn from_entity( + pub async fn from_entity( vuln: &vulnerability::Model, package_status: &purl_status::Model, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let status = status::Entity::find_by_id(package_status.status_id) .one(tx) @@ -345,9 +346,9 @@ pub struct PurlLicenseSummary { } impl PurlLicenseSummary { - pub async fn from_entities( + pub async fn from_entities( entities: &[LicenseCatcher], - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let mut summaries = HashMap::new(); diff --git a/modules/fundamental/src/purl/model/details/versioned_purl.rs b/modules/fundamental/src/purl/model/details/versioned_purl.rs index cf26aa92d..e09ba5fbe 100644 --- a/modules/fundamental/src/purl/model/details/versioned_purl.rs +++ b/modules/fundamental/src/purl/model/details/versioned_purl.rs @@ -5,12 +5,12 @@ use crate::{ Error, }; use sea_orm::{ - ColumnTrait, EntityTrait, LoaderTrait, ModelTrait, QueryFilter, QuerySelect, RelationTrait, + ColumnTrait, ConnectionTrait, EntityTrait, LoaderTrait, ModelTrait, QueryFilter, QuerySelect, + RelationTrait, }; use sea_query::{Asterisk, Expr, Func, JoinType, SimpleExpr}; use serde::{Deserialize, Serialize}; -use trustify_common::db::{ConnectionOrTransaction, VersionMatches}; -use trustify_common::memo::Memo; +use trustify_common::{db::VersionMatches, memo::Memo}; use trustify_entity::{ advisory, base_purl, organization, purl_status, qualified_purl, status, version_range, versioned_purl, vulnerability, @@ -27,10 +27,10 @@ pub struct VersionedPurlDetails { } impl VersionedPurlDetails { - pub async fn from_entity( + pub async fn from_entity( package: Option, package_version: &versioned_purl::Model, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let package = if let Some(package) = package { package @@ -75,7 +75,7 @@ impl VersionedPurlDetails { Ok(Self { head: VersionedPurlHead::from_entity(&package, package_version, tx).await?, - base: BasePurlHead::from_entity(&package, tx).await?, + base: BasePurlHead::from_entity(&package).await?, purls: qualified_packages, advisories: VersionedPurlAdvisory::from_entities(statuses, tx).await?, }) @@ -90,9 +90,9 @@ pub struct VersionedPurlAdvisory { } impl VersionedPurlAdvisory { - pub async fn from_entities( + pub async fn from_entities( statuses: Vec, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let vulns = statuses.load_one(vulnerability::Entity, tx).await?; @@ -134,10 +134,10 @@ pub struct VersionedPurlStatus { } impl VersionedPurlStatus { - pub async fn from_entity( + pub async fn from_entity( vuln: &vulnerability::Model, package_status: &purl_status::Model, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let status = package_status.find_related(status::Entity).one(tx).await?; diff --git a/modules/fundamental/src/purl/model/mod.rs b/modules/fundamental/src/purl/model/mod.rs index ad906bd72..629e87523 100644 --- a/modules/fundamental/src/purl/model/mod.rs +++ b/modules/fundamental/src/purl/model/mod.rs @@ -1,7 +1,7 @@ use crate::Error; use sea_orm::prelude::Uuid; +use sea_orm::ConnectionTrait; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_common::purl::Purl; use trustify_entity::{base_purl, qualified_purl, versioned_purl}; use utoipa::ToSchema; @@ -18,10 +18,7 @@ pub struct BasePurlHead { } impl BasePurlHead { - pub async fn from_entity( - entity: &base_purl::Model, - _tx: &ConnectionOrTransaction<'_>, - ) -> Result { + pub async fn from_entity(entity: &base_purl::Model) -> Result { Ok(BasePurlHead { uuid: entity.id, purl: Purl { @@ -36,12 +33,11 @@ impl BasePurlHead { pub async fn from_package_entities( entities: &Vec, - tx: &ConnectionOrTransaction<'_>, ) -> Result, Error> { let mut heads = Vec::new(); for entity in entities { - heads.push(Self::from_entity(entity, tx).await?) + heads.push(Self::from_entity(entity).await?) } Ok(heads) @@ -59,10 +55,10 @@ pub struct VersionedPurlHead { } impl VersionedPurlHead { - pub async fn from_entity( + pub async fn from_entity( package: &base_purl::Model, package_version: &versioned_purl::Model, - _tx: &ConnectionOrTransaction<'_>, + _db: &C, ) -> Result { Ok(Self { uuid: package_version.id, @@ -87,11 +83,11 @@ pub struct PurlHead { } impl PurlHead { - pub async fn from_entity( + pub async fn from_entity( package: &base_purl::Model, package_version: &versioned_purl::Model, qualified_package: &qualified_purl::Model, - _tx: &ConnectionOrTransaction<'_>, + _db: &C, ) -> Result { Ok(Self { uuid: qualified_package.id, @@ -105,11 +101,11 @@ impl PurlHead { }) } - pub async fn from_entities( + pub async fn from_entities( package: &base_purl::Model, package_version: &versioned_purl::Model, qualified_packages: &Vec, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let mut heads = Vec::new(); diff --git a/modules/fundamental/src/purl/model/summary/base_purl.rs b/modules/fundamental/src/purl/model/summary/base_purl.rs index 9e05f2aa4..0485b960b 100644 --- a/modules/fundamental/src/purl/model/summary/base_purl.rs +++ b/modules/fundamental/src/purl/model/summary/base_purl.rs @@ -1,7 +1,6 @@ use crate::purl::model::BasePurlHead; use crate::Error; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::base_purl; use utoipa::ToSchema; @@ -12,15 +11,12 @@ pub struct BasePurlSummary { } impl BasePurlSummary { - pub async fn from_entities( - entities: &Vec, - tx: &ConnectionOrTransaction<'_>, - ) -> Result, Error> { + pub async fn from_entities(entities: &Vec) -> Result, Error> { let mut summaries = Vec::new(); for entity in entities { summaries.push(BasePurlSummary { - head: BasePurlHead::from_entity(entity, tx).await?, + head: BasePurlHead::from_entity(entity).await?, }) } diff --git a/modules/fundamental/src/purl/model/summary/purl.rs b/modules/fundamental/src/purl/model/summary/purl.rs index e8bb73ec3..c0f33cfe2 100644 --- a/modules/fundamental/src/purl/model/summary/purl.rs +++ b/modules/fundamental/src/purl/model/summary/purl.rs @@ -1,9 +1,8 @@ use crate::purl::model::{BasePurlHead, PurlHead, VersionedPurlHead}; use crate::Error; -use sea_orm::{LoaderTrait, ModelTrait}; +use sea_orm::{ConnectionTrait, LoaderTrait, ModelTrait}; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::{base_purl, qualified_purl, versioned_purl}; use utoipa::ToSchema; @@ -17,9 +16,9 @@ pub struct PurlSummary { } impl PurlSummary { - pub async fn from_entities( + pub async fn from_entities( qualified_packages: &Vec, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let package_versions = qualified_packages .load_one(versioned_purl::Entity, tx) @@ -45,7 +44,7 @@ impl PurlSummary { tx, ) .await?, - base: BasePurlHead::from_entity(&package, tx).await?, + base: BasePurlHead::from_entity(&package).await?, version: VersionedPurlHead::from_entity(&package, package_version, tx) .await?, qualifiers: qualified_package.qualifiers.0.clone(), @@ -57,16 +56,16 @@ impl PurlSummary { Ok(summaries) } - pub async fn from_entity( + pub async fn from_entity( base_purl: &base_purl::Model, versioned_purl: &versioned_purl::Model, purl: &qualified_purl::Model, - tx: &ConnectionOrTransaction<'_>, + db: &C, ) -> Result { Ok(PurlSummary { - head: PurlHead::from_entity(base_purl, versioned_purl, purl, tx).await?, - base: BasePurlHead::from_entity(base_purl, tx).await?, - version: VersionedPurlHead::from_entity(base_purl, versioned_purl, tx).await?, + head: PurlHead::from_entity(base_purl, versioned_purl, purl, db).await?, + base: BasePurlHead::from_entity(base_purl).await?, + version: VersionedPurlHead::from_entity(base_purl, versioned_purl, db).await?, qualifiers: purl.qualifiers.0.clone(), }) } diff --git a/modules/fundamental/src/purl/model/summary/type.rs b/modules/fundamental/src/purl/model/summary/type.rs index 3f16ae194..77ba3b4cc 100644 --- a/modules/fundamental/src/purl/model/summary/type.rs +++ b/modules/fundamental/src/purl/model/summary/type.rs @@ -1,8 +1,9 @@ use crate::purl::model::TypeHead; use crate::Error; -use sea_orm::{ColumnTrait, DeriveColumn, EntityTrait, EnumIter, QueryFilter, QuerySelect}; +use sea_orm::{ + ColumnTrait, ConnectionTrait, DeriveColumn, EntityTrait, EnumIter, QueryFilter, QuerySelect, +}; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::{base_purl, qualified_purl, versioned_purl}; use utoipa::ToSchema; @@ -21,9 +22,9 @@ pub struct TypeCounts { } impl TypeSummary { - pub async fn from_names( + pub async fn from_names( names: &Vec, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryAs { diff --git a/modules/fundamental/src/purl/model/summary/versioned_purl.rs b/modules/fundamental/src/purl/model/summary/versioned_purl.rs index d261ca9e1..d0c392760 100644 --- a/modules/fundamental/src/purl/model/summary/versioned_purl.rs +++ b/modules/fundamental/src/purl/model/summary/versioned_purl.rs @@ -1,8 +1,7 @@ use crate::purl::model::{BasePurlHead, PurlHead, VersionedPurlHead}; use crate::Error; -use sea_orm::LoaderTrait; +use sea_orm::{ConnectionTrait, LoaderTrait}; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::{base_purl, qualified_purl, versioned_purl}; use utoipa::ToSchema; @@ -15,10 +14,10 @@ pub struct VersionedPurlSummary { } impl VersionedPurlSummary { - pub async fn from_entities_with_common_package( + pub async fn from_entities_with_common_package( package: &base_purl::Model, package_versions: &Vec, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let mut summaries = Vec::new(); @@ -31,7 +30,7 @@ impl VersionedPurlSummary { { summaries.push(Self { head: VersionedPurlHead::from_entity(package, package_version, tx).await?, - base: BasePurlHead::from_entity(package, tx).await?, + base: BasePurlHead::from_entity(package).await?, purls: PurlHead::from_entities(package, package_version, qualified_packages, tx) .await?, }) @@ -40,14 +39,14 @@ impl VersionedPurlSummary { Ok(summaries) } - pub async fn from_entity( + pub async fn from_entity( base_purl: &base_purl::Model, versioned_purl: &versioned_purl::Model, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { Ok(Self { head: VersionedPurlHead::from_entity(base_purl, versioned_purl, tx).await?, - base: BasePurlHead::from_entity(base_purl, tx).await?, + base: BasePurlHead::from_entity(base_purl).await?, purls: vec![], }) } diff --git a/modules/fundamental/src/purl/service/mod.rs b/modules/fundamental/src/purl/service/mod.rs index 78fddf776..ce32e4e6e 100644 --- a/modules/fundamental/src/purl/service/mod.rs +++ b/modules/fundamental/src/purl/service/mod.rs @@ -17,7 +17,6 @@ use trustify_common::{ db::{ limiter::LimiterTrait, query::{Filtering, Query}, - Database, Transactional, }, model::{Paginated, PaginatedResults}, purl::{Purl, PurlErr}, @@ -25,26 +24,23 @@ use trustify_common::{ use trustify_entity::{base_purl, qualified_purl, versioned_purl}; use trustify_module_ingestor::common::Deprecation; -pub struct PurlService { - db: Database, -} +#[derive(Default)] +pub struct PurlService {} impl PurlService { - pub fn new(db: Database) -> Self { - Self { db } + pub fn new() -> Self { + Self {} } - pub async fn purl_types>( + pub async fn purl_types( &self, - tx: TX, + connection: &C, ) -> Result, Error> { #[derive(FromQueryResult)] struct Ecosystem { r#type: String, } - let connection = self.db.connection(&tx); - let ecosystems: Vec<_> = base_purl::Entity::find() .select_only() .column(base_purl::Column::Type) @@ -52,46 +48,42 @@ impl PurlService { .distinct() .order_by(base_purl::Column::Type, Order::Asc) .into_model::() - .all(&connection) + .all(connection) .await? .into_iter() .map(|e| e.r#type) .collect(); - TypeSummary::from_names(&ecosystems, &connection).await + TypeSummary::from_names(&ecosystems, connection).await } - pub async fn base_purls_by_type>( + pub async fn base_purls_by_type( &self, r#type: &str, query: Query, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let limiter = base_purl::Entity::find() .filter(base_purl::Column::Type.eq(r#type)) .filtering(query)? - .limiting(&connection, paginated.offset, paginated.limit); + .limiting(connection, paginated.offset, paginated.limit); let total = limiter.total().await?; Ok(PaginatedResults { - items: BasePurlSummary::from_entities(&limiter.fetch().await?, &connection).await?, + items: BasePurlSummary::from_entities(&limiter.fetch().await?).await?, total, }) } - pub async fn base_purl>( + pub async fn base_purl( &self, r#type: &str, namespace: Option, name: &str, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let mut query = base_purl::Entity::find() .filter(base_purl::Column::Type.eq(r#type)) .filter(base_purl::Column::Name.eq(name)); @@ -102,25 +94,23 @@ impl PurlService { query = query.filter(base_purl::Column::Namespace.is_null()); } - if let Some(package) = query.one(&connection).await? { + if let Some(package) = query.one(connection).await? { Ok(Some( - BasePurlDetails::from_entity(&package, &connection).await?, + BasePurlDetails::from_entity(&package, connection).await?, )) } else { Ok(None) } } - pub async fn versioned_purl>( + pub async fn versioned_purl( &self, r#type: &str, namespace: Option, name: &str, version: &str, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let mut query = versioned_purl::Entity::find() .left_join(base_purl::Entity) .filter(base_purl::Column::Type.eq(r#type)) @@ -133,42 +123,39 @@ impl PurlService { query = query.filter(base_purl::Column::Namespace.is_null()); } - let package_version = query.one(&connection).await?; + let package_version = query.one(connection).await?; if let Some(package_version) = package_version { Ok(Some( - VersionedPurlDetails::from_entity(None, &package_version, &connection).await?, + VersionedPurlDetails::from_entity(None, &package_version, connection).await?, )) } else { Ok(None) } } - pub async fn base_purl_by_uuid>( + pub async fn base_purl_by_uuid( &self, base_purl_uuid: &Uuid, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - if let Some(package) = base_purl::Entity::find_by_id(*base_purl_uuid) - .one(&connection) + .one(connection) .await? { Ok(Some( - BasePurlDetails::from_entity(&package, &connection).await?, + BasePurlDetails::from_entity(&package, connection).await?, )) } else { Ok(None) } } - pub async fn base_purl_by_purl>( + pub async fn base_purl_by_purl( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); let mut query = base_purl::Entity::find() .filter(base_purl::Column::Type.eq(&purl.ty)) .filter(base_purl::Column::Name.eq(&purl.name)); @@ -179,41 +166,37 @@ impl PurlService { query = query.filter(base_purl::Column::Namespace.is_null()); } - if let Some(base_purl) = query.one(&connection).await? { + if let Some(base_purl) = query.one(connection).await? { Ok(Some( - BasePurlDetails::from_entity(&base_purl, &connection).await?, + BasePurlDetails::from_entity(&base_purl, connection).await?, )) } else { Ok(None) } } - pub async fn versioned_purl_by_uuid>( + pub async fn versioned_purl_by_uuid( &self, purl_version_uuid: &Uuid, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - if let Some(package_version) = versioned_purl::Entity::find_by_id(*purl_version_uuid) - .one(&connection) + .one(connection) .await? { Ok(Some( - VersionedPurlDetails::from_entity(None, &package_version, &connection).await?, + VersionedPurlDetails::from_entity(None, &package_version, connection).await?, )) } else { Ok(None) } } - pub async fn versioned_purl_by_purl>( + pub async fn versioned_purl_by_purl( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - if let Some(version) = &purl.version { let mut query = versioned_purl::Entity::find() .left_join(base_purl::Entity) @@ -227,11 +210,11 @@ impl PurlService { query = query.filter(base_purl::Column::Namespace.is_null()); } - let package_version = query.one(&connection).await?; + let package_version = query.one(connection).await?; if let Some(package_version) = package_version { Ok(Some( - VersionedPurlDetails::from_entity(None, &package_version, &connection).await?, + VersionedPurlDetails::from_entity(None, &package_version, connection).await?, )) } else { Ok(None) @@ -243,13 +226,12 @@ impl PurlService { } } - pub async fn purl_by_purl>( + pub async fn purl_by_purl( &self, purl: &Purl, deprecation: Deprecation, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); if let Some(version) = &purl.version { let mut query = qualified_purl::Entity::find() .left_join(versioned_purl::Entity) @@ -264,11 +246,11 @@ impl PurlService { query = query.filter(base_purl::Column::Namespace.is_null()); } - let purl = query.one(&connection).await?; + let purl = query.one(connection).await?; if let Some(purl) = purl { Ok(Some( - PurlDetails::from_entity(None, None, &purl, deprecation, &connection).await?, + PurlDetails::from_entity(None, None, &purl, deprecation, connection).await?, )) } else { Ok(None) @@ -280,21 +262,19 @@ impl PurlService { } } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn purl_by_uuid>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn purl_by_uuid( &self, purl_uuid: &Uuid, deprecation: Deprecation, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - if let Some(qualified_package) = qualified_purl::Entity::find_by_id(*purl_uuid) - .one(&connection) + .one(connection) .await? { Ok(Some( - PurlDetails::from_entity(None, None, &qualified_package, deprecation, &connection) + PurlDetails::from_entity(None, None, &qualified_package, deprecation, connection) .await?, )) } else { @@ -302,16 +282,14 @@ impl PurlService { } } - pub async fn base_purls>( + pub async fn base_purls( &self, query: Query, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let limiter = base_purl::Entity::find().filtering(query)?.limiting( - &connection, + connection, paginated.offset, paginated.limit, ); @@ -319,20 +297,18 @@ impl PurlService { let total = limiter.total().await?; Ok(PaginatedResults { - items: BasePurlSummary::from_entities(&limiter.fetch().await?, &connection).await?, + items: BasePurlSummary::from_entities(&limiter.fetch().await?).await?, total, }) } - #[instrument(skip(self, tx), err)] - pub async fn purls>( + #[instrument(skip(self, connection), err)] + pub async fn purls( &self, query: Query, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - // TODO: this would be the condition used to select from jsonb name key let _unused_condition = Expr::cust_with_exprs( "$1->>'name' ~~* $2", @@ -356,20 +332,18 @@ impl PurlService { ), ), ) - .limiting(&connection, paginated.offset, paginated.limit); + .limiting(connection, paginated.offset, paginated.limit); let total = limiter.total().await?; Ok(PaginatedResults { - items: PurlSummary::from_entities(&limiter.fetch().await?, &connection).await?, + items: PurlSummary::from_entities(&limiter.fetch().await?, connection).await?, total, }) } - #[instrument(skip(self, tx), err)] - pub async fn gc_purls>(&self, tx: TX) -> Result { - let connection = self.db.connection(&tx); - + #[instrument(skip(self, connection), err)] + pub async fn gc_purls(&self, connection: &C) -> Result { let res = connection .execute_unprepared(include_str!("gc_purls.sql")) .await?; diff --git a/modules/fundamental/src/purl/service/test.rs b/modules/fundamental/src/purl/service/test.rs index bc8d21425..f159fdfef 100644 --- a/modules/fundamental/src/purl/service/test.rs +++ b/modules/fundamental/src/purl/service/test.rs @@ -7,10 +7,7 @@ use std::str::FromStr; use test_context::test_context; use test_log::test; use trustify_common::{ - db::{ - query::{q, Query}, - Transactional, - }, + db::query::{q, Query}, id::Id, model::Paginated, purl::Purl, @@ -19,10 +16,10 @@ use trustify_test_context::TrustifyContext; async fn ingest_extra_packages(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { ctx.graph - .ingest_package(&Purl::from_str("pkg:maven/org.myspace/tom")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.myspace/tom")?, &ctx.db) .await?; ctx.graph - .ingest_package(&Purl::from_str("pkg:rpm/sendmail")?, ()) + .ingest_package(&Purl::from_str("pkg:rpm/sendmail")?, &ctx.db) .await?; Ok(()) @@ -31,31 +28,37 @@ async fn ingest_extra_packages(ctx: &TrustifyContext) -> Result<(), anyhow::Erro #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn types(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let log4j = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, &ctx.db) .await?; let log4j_123 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; log4j_123 - .ingest_qualified_package(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .ingest_qualified_package( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?repository_url=http://jboss.org")?, - (), + &ctx.db, ) .await?; ingest_extra_packages(ctx).await?; - let types = service.purl_types(()).await?; + let types = service.purl_types(&ctx.db).await?; assert_eq!(2, types.len()); @@ -82,29 +85,38 @@ async fn types(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn packages_for_type(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let log4j = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, &ctx.db) .await?; log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.4")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.4")?, + &ctx.db, + ) .await?; log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.5")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.5")?, + &ctx.db, + ) .await?; ingest_extra_packages(ctx).await?; let packages = service - .base_purls_by_type("maven", Query::default(), Paginated::default(), ()) + .base_purls_by_type("maven", Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(packages.total, 2); @@ -125,29 +137,38 @@ async fn packages_for_type(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn packages_for_type_with_filtering(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let log4j = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, &ctx.db) .await?; log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.4")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.4")?, + &ctx.db, + ) .await?; log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.5")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.5")?, + &ctx.db, + ) .await?; ingest_extra_packages(ctx).await?; let packages = service - .base_purls_by_type("maven", q("myspace"), Paginated::default(), ()) + .base_purls_by_type("maven", q("myspace"), Paginated::default(), &ctx.db) .await?; assert_eq!(packages.total, 1); @@ -163,64 +184,73 @@ async fn packages_for_type_with_filtering(ctx: &TrustifyContext) -> Result<(), a #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn package(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let log4j = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, &ctx.db) .await?; let log4j_123 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?repository_url=http://maven.org")?, - (), + &ctx.db, ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?repository_url=http://jboss.org")?, - (), + &ctx.db, ) .await?; let _log4j_124 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.4")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.4")?, + &ctx.db, + ) .await?; log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.5")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.5")?, + &ctx.db, + ) .await?; let tom = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.myspace/tom")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.myspace/tom")?, &ctx.db) .await?; - tom.ingest_package_version(&Purl::from_str("pkg:maven/org.myspace/tom@1.1.1")?, ()) + tom.ingest_package_version(&Purl::from_str("pkg:maven/org.myspace/tom@1.1.1")?, &ctx.db) .await?; - tom.ingest_package_version(&Purl::from_str("pkg:maven/org.myspace/tom@9.9.9")?, ()) + tom.ingest_package_version(&Purl::from_str("pkg:maven/org.myspace/tom@9.9.9")?, &ctx.db) .await?; ctx.graph - .ingest_package(&Purl::from_str("pkg:rpm/sendmail")?, ()) + .ingest_package(&Purl::from_str("pkg:rpm/sendmail")?, &ctx.db) .await?; let bind = ctx .graph - .ingest_package(&Purl::from_str("pkg:rpm/bind")?, ()) + .ingest_package(&Purl::from_str("pkg:rpm/bind")?, &ctx.db) .await?; - bind.ingest_package_version(&Purl::from_str("pkg:rpm/bind@4.4.4")?, ()) + bind.ingest_package_version(&Purl::from_str("pkg:rpm/bind@4.4.4")?, &ctx.db) .await?; let results = service - .base_purl("maven", Some("org.apache".to_string()), "log4j", ()) + .base_purl("maven", Some("org.apache".to_string()), "log4j", &ctx.db) .await?; assert!(results.is_some()); @@ -235,46 +265,52 @@ async fn package(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn package_version(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let log4j = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, &ctx.db) .await?; let log4j_123 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?jdk=11")?, - (), + &ctx.db, ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?jdk=17")?, - (), + &ctx.db, ) .await?; let log4j_345 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@3.4.5")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5")?, + &ctx.db, + ) .await?; log4j_345 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5?repository_url=http://jboss.org/")?, - (), + &ctx.db, ) .await?; log4j_345 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5?repository_url=http://jboss.org/")?, - (), + &ctx.db, ) .await?; @@ -284,7 +320,7 @@ async fn package_version(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { Some("org.apache".to_string()), "log4j", "1.2.3", - (), + &ctx.db, ) .await?; @@ -315,51 +351,57 @@ async fn package_version(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn package_version_by_uuid(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let log4j = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, &ctx.db) .await?; let log4j_123 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?jdk=11")?, - (), + &ctx.db, ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?jdk=17")?, - (), + &ctx.db, ) .await?; let log4j_345 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@3.4.5")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5")?, + &ctx.db, + ) .await?; log4j_345 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5?repository_url=http://jboss.org/")?, - (), + &ctx.db, ) .await?; log4j_345 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5?repository_url=http://jboss.org/")?, - (), + &ctx.db, ) .await?; let result = service - .versioned_purl_by_uuid(&log4j_123.package_version.id, ()) + .versioned_purl_by_uuid(&log4j_123.package_version.id, &ctx.db) .await?; assert!(result.is_some()); @@ -389,85 +431,94 @@ async fn package_version_by_uuid(ctx: &TrustifyContext) -> Result<(), anyhow::Er #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn packages(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let log4j = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, &ctx.db) .await?; let log4j_123 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?jdk=11")?, - (), + &ctx.db, ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?jdk=17")?, - (), + &ctx.db, ) .await?; let log4j_345 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@3.4.5")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5")?, + &ctx.db, + ) .await?; log4j_345 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5?repository_url=http://jboss.org/")?, - (), + &ctx.db, ) .await?; log4j_345 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5?repository_url=http://jboss.org/")?, - (), + &ctx.db, ) .await?; let quarkus = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.jboss/quarkus")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.jboss/quarkus")?, &ctx.db) .await?; let quarkus_123 = quarkus - .ingest_package_version(&Purl::from_str("pkg:maven/org.jboss/quarkus@1.2.3")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.jboss/quarkus@1.2.3")?, + &ctx.db, + ) .await?; quarkus_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.jboss/quarkus@1.2.3?repository_url=http://jboss.org/")?, - (), + &ctx.db, ) .await?; let results = service - .base_purls(q("log4j"), Paginated::default(), ()) + .base_purls(q("log4j"), Paginated::default(), &ctx.db) .await?; assert_eq!(1, results.items.len()); let results = service - .base_purls(q("quarkus"), Paginated::default(), ()) + .base_purls(q("quarkus"), Paginated::default(), &ctx.db) .await?; assert_eq!(1, results.items.len()); let results = service - .base_purls(q("jboss"), Paginated::default(), ()) + .base_purls(q("jboss"), Paginated::default(), &ctx.db) .await?; assert_eq!(1, results.items.len()); let results = service - .base_purls(q("maven"), Paginated::default(), ()) + .base_purls(q("maven"), Paginated::default(), &ctx.db) .await?; assert_eq!(2, results.items.len()); @@ -478,66 +529,77 @@ async fn packages(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn qualified_packages(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let log4j = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, &ctx.db) .await?; let log4j_123 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?jdk=11")?, - (), + &ctx.db, ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?jdk=17")?, - (), + &ctx.db, ) .await?; let log4j_345 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@3.4.5")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5")?, + &ctx.db, + ) .await?; log4j_345 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5?repository_url=http://jboss.org/")?, - (), + &ctx.db, ) .await?; log4j_345 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@3.4.5?repository_url=http://jboss.org/")?, - (), + &ctx.db, ) .await?; let quarkus = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.jboss/quarkus")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.jboss/quarkus")?, &ctx.db) .await?; let quarkus_123 = quarkus - .ingest_package_version(&Purl::from_str("pkg:maven/org.jboss/quarkus@1.2.3")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.jboss/quarkus@1.2.3")?, + &ctx.db, + ) .await?; quarkus_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.jboss/quarkus@1.2.3?repository_url=http://jboss.org/")?, - (), + &ctx.db, ) .await?; - let results = service.purls(q("log4j"), Paginated::default(), ()).await?; + let results = service + .purls(q("log4j"), Paginated::default(), &ctx.db) + .await?; log::debug!("{:#?}", results); @@ -547,20 +609,17 @@ async fn qualified_packages(ctx: &TrustifyContext) -> Result<(), anyhow::Error> #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn statuses(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); ctx.ingest_documents(["osv/RUSTSEC-2021-0079.json", "cve/CVE-2021-32714.json"]) .await?; ctx.ingestor .graph() - .ingest_qualified_package( - &Purl::from_str("pkg:cargo/hyper@0.14.1")?, - Transactional::None, - ) + .ingest_qualified_package(&Purl::from_str("pkg:cargo/hyper@0.14.1")?, &ctx.db) .await?; let results = service - .purls(Query::default(), Paginated::default(), Transactional::None) + .purls(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(1, results.items.len()); @@ -568,7 +627,7 @@ async fn statuses(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { let uuid = results.items[0].head.uuid; let _results = service - .purl_by_uuid(&uuid, Default::default(), Transactional::None) + .purl_by_uuid(&uuid, Default::default(), &ctx.db) .await?; Ok(()) @@ -577,12 +636,12 @@ async fn statuses(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn contextual_status(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); ctx.ingest_document("csaf/rhsa-2024_3666.json").await?; let results = service - .purls(Query::default(), Paginated::default(), Transactional::None) + .purls(Query::default(), Paginated::default(), &ctx.db) .await?; let tomcat_jsp = results @@ -597,7 +656,7 @@ async fn contextual_status(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { let uuid = tomcat_jsp.head.uuid; let tomcat_jsp = service - .purl_by_uuid(&uuid, Default::default(), Transactional::None) + .purl_by_uuid(&uuid, Default::default(), &ctx.db) .await?; assert!(tomcat_jsp.is_some()); @@ -628,11 +687,11 @@ async fn contextual_status(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn gc_purls(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let purl_service = PurlService::new(ctx.db.clone()); + let purl_service = PurlService::new(); assert_eq!( 0, purl_service - .purls(Query::default(), Paginated::default(), Transactional::None) + .purls(Query::default(), Paginated::default(), &ctx.db) .await? .items .len() @@ -647,7 +706,7 @@ async fn gc_purls(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert_eq!( 880, purl_service - .purls(Query::default(), Paginated::default(), Transactional::None) + .purls(Query::default(), Paginated::default(), &ctx.db) .await? .items .len() @@ -660,7 +719,7 @@ async fn gc_purls(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert_eq!( 1490, purl_service - .purls(Query::default(), Paginated::default(), Transactional::None) + .purls(Query::default(), Paginated::default(), &ctx.db) .await? .items .len() @@ -673,13 +732,13 @@ async fn gc_purls(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> { let sbom_service = SbomService::new(ctx.db.clone()); let sbom = sbom_service - .fetch_sbom_details(id, ()) + .fetch_sbom_details(id, &ctx.db) .await? .expect("fetch_sbom"); assert_eq!( 1, sbom_service - .delete_sbom(sbom.summary.head.id, Transactional::None) + .delete_sbom(sbom.summary.head.id, &ctx.db) .await? ); @@ -689,7 +748,7 @@ async fn gc_purls(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert_eq!( 1, advisory_service - .delete_advisory(a.head.uuid, Transactional::None) + .delete_advisory(a.head.uuid, &ctx.db) .await? ); } @@ -701,16 +760,16 @@ async fn gc_purls(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { // it should leave behind orphaned purls let result = purl_service - .purls(Query::default(), Paginated::default(), Transactional::None) + .purls(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(1490, result.items.len()); // running the gc, should delete those orphaned purls - let deleted_records_count = purl_service.gc_purls(()).await?; + let deleted_records_count = purl_service.gc_purls(&ctx.db).await?; assert_eq!(792, deleted_records_count); let result = purl_service - .purls(Query::default(), Paginated::default(), Transactional::None) + .purls(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(880, result.items.len()); @@ -720,16 +779,16 @@ async fn gc_purls(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { // it should leave behind orphaned purls let result = purl_service - .purls(Query::default(), Paginated::default(), Transactional::None) + .purls(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(880, result.items.len()); // running the gc, should delete those orphaned purls - let deleted_records_count = purl_service.gc_purls(()).await?; + let deleted_records_count = purl_service.gc_purls(&ctx.db).await?; assert_eq!(1759, deleted_records_count); let result = purl_service - .purls(Query::default(), Paginated::default(), Transactional::None) + .purls(Query::default(), Paginated::default(), &ctx.db) .await?; assert_eq!(0, result.items.len()); @@ -739,17 +798,20 @@ async fn gc_purls(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { async fn ingest_some_log4j_data(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { let log4j = ctx .graph - .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, ()) + .ingest_package(&Purl::from_str("pkg:maven/org.apache/log4j")?, &ctx.db) .await?; let log4j_123 = log4j - .ingest_package_version(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .ingest_package_version( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; log4j_123 .ingest_qualified_package( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3?jdk=11")?, - (), + &ctx.db, ) .await?; Ok(()) @@ -758,7 +820,7 @@ async fn ingest_some_log4j_data(ctx: &TrustifyContext) -> Result<(), anyhow::Err #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn purl_by_purl(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); ingest_some_log4j_data(ctx).await?; @@ -766,7 +828,7 @@ async fn purl_by_purl(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { .purl_by_purl( &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, Default::default(), - (), + &ctx.db, ) .await?; @@ -778,12 +840,15 @@ async fn purl_by_purl(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn base_purl_by_purl(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); ingest_some_log4j_data(ctx).await?; let results = service - .base_purl_by_purl(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .base_purl_by_purl( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; assert!(!results.unwrap().versions.is_empty()); @@ -794,12 +859,15 @@ async fn base_purl_by_purl(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn versioned_base_purl_by_purl(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); ingest_some_log4j_data(ctx).await?; let results = service - .versioned_purl_by_purl(&Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, ()) + .versioned_purl_by_purl( + &Purl::from_str("pkg:maven/org.apache/log4j@1.2.3")?, + &ctx.db, + ) .await?; assert!(!results.unwrap().purls.is_empty()); @@ -810,7 +878,7 @@ async fn versioned_base_purl_by_purl(ctx: &TrustifyContext) -> Result<(), anyhow #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn license_information(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); ctx.ingest_document("ubi9-9.2-755.1697625012.json").await?; @@ -818,7 +886,7 @@ async fn license_information(ctx: &TrustifyContext) -> Result<(), anyhow::Error> .purl_by_purl( &Purl::try_from("pkg:rpm/redhat/libsepol@3.5-1.el9?arch=s390x")?, Default::default(), - Transactional::None, + &ctx.db, ) .await?; diff --git a/modules/fundamental/src/sbom/endpoints/label.rs b/modules/fundamental/src/sbom/endpoints/label.rs index 8fc587ef3..5a0ad1fda 100644 --- a/modules/fundamental/src/sbom/endpoints/label.rs +++ b/modules/fundamental/src/sbom/endpoints/label.rs @@ -1,6 +1,7 @@ use crate::sbom::service::SbomService; use actix_web::{patch, put, web, HttpResponse, Responder}; use trustify_auth::{authorizer::Require, UpdateSbom}; +use trustify_common::db::Database; use trustify_common::id::Id; use trustify_entity::labels::Labels; @@ -51,12 +52,18 @@ pub async fn update( #[put("/v1/sbom/{id}/label")] pub async fn set( sbom: web::Data, + db: web::Data, id: web::Path, web::Json(labels): web::Json, _: Require, ) -> actix_web::Result { - Ok(match sbom.set_labels(id.into_inner(), labels, ()).await? { - Some(()) => HttpResponse::NoContent(), - None => HttpResponse::NotFound(), - }) + Ok( + match sbom + .set_labels(id.into_inner(), labels, db.as_ref()) + .await? + { + Some(()) => HttpResponse::NoContent(), + None => HttpResponse::NotFound(), + }, + ) } diff --git a/modules/fundamental/src/sbom/endpoints/mod.rs b/modules/fundamental/src/sbom/endpoints/mod.rs index 0117a81ea..bf704230c 100644 --- a/modules/fundamental/src/sbom/endpoints/mod.rs +++ b/modules/fundamental/src/sbom/endpoints/mod.rs @@ -19,6 +19,7 @@ use actix_web::{delete, get, http::header, post, web, HttpResponse, Responder, R use config::Config; use futures_util::TryStreamExt; use sea_orm::prelude::Uuid; +use sea_orm::TransactionTrait; use std::{ fmt::{Display, Formatter}, str::FromStr, @@ -50,9 +51,10 @@ pub fn configure( upload_limit: usize, ) { let sbom_service = SbomService::new(db.clone()); - let purl_service = PurlService::new(db); + let purl_service = PurlService::new(); config + .app_data(web::Data::new(db)) .app_data(web::Data::new(sbom_service)) .app_data(web::Data::new(purl_service)) .app_data(web::Data::new(Config { upload_limit })) @@ -84,6 +86,7 @@ pub fn configure( #[get("/v1/sbom")] pub async fn all( fetch: web::Data, + db: web::Data, web::Query(search): web::Query, web::Query(paginated): web::Query, authorizer: web::Data, @@ -91,7 +94,9 @@ pub async fn all( ) -> actix_web::Result { authorizer.require(&user, Permission::ReadSbom)?; - let result = fetch.fetch_sboms(search, paginated, (), ()).await?; + let result = fetch + .fetch_sboms(search, paginated, (), db.as_ref()) + .await?; Ok(HttpResponse::Ok().json(result)) } @@ -165,6 +170,7 @@ impl TryFrom for Uuid { #[get("/v1/sbom/by-package")] pub async fn all_related( sbom: web::Data, + db: web::Data, web::Query(search): web::Query, web::Query(paginated): web::Query, web::Query(all_related): web::Query, @@ -175,7 +181,9 @@ pub async fn all_related( let id = all_related.try_into()?; - let result = sbom.find_related_sboms(id, paginated, search, ()).await?; + let result = sbom + .find_related_sboms(id, paginated, search, db.as_ref()) + .await?; Ok(HttpResponse::Ok().json(result)) } @@ -197,6 +205,7 @@ pub async fn all_related( #[get("/v1/sbom/count-by-package")] pub async fn count_related( sbom: web::Data, + db: web::Data, web::Json(ids): web::Json>, _: Require, ) -> actix_web::Result { @@ -205,7 +214,7 @@ pub async fn count_related( .map(Uuid::try_from) .collect::, _>>()?; - let result = sbom.count_related_sboms(ids, ()).await?; + let result = sbom.count_related_sboms(ids, db.as_ref()).await?; Ok(HttpResponse::Ok().json(result)) } @@ -224,11 +233,12 @@ pub async fn count_related( #[get("/v1/sbom/{id}")] pub async fn get( fetcher: web::Data, + db: web::Data, id: web::Path, _: Require, ) -> actix_web::Result { let id = Id::from_str(&id).map_err(Error::IdKey)?; - match fetcher.fetch_sbom_summary(id, ()).await? { + match fetcher.fetch_sbom_summary(id, db.as_ref()).await? { Some(v) => Ok(HttpResponse::Ok().json(v)), None => Ok(HttpResponse::NotFound().finish()), } @@ -248,11 +258,12 @@ pub async fn get( #[get("/v1/sbom/{id}/advisory")] pub async fn get_sbom_advisories( fetcher: web::Data, + db: web::Data, id: web::Path, _: Require, ) -> actix_web::Result { let id = Id::from_str(&id).map_err(Error::IdKey)?; - match fetcher.fetch_sbom_details(id, ()).await? { + match fetcher.fetch_sbom_details(id, db.as_ref()).await? { Some(v) => Ok(HttpResponse::Ok().json(v.advisories)), None => Ok(HttpResponse::NotFound().finish()), } @@ -274,21 +285,25 @@ all!(GetSbomAdvisories -> ReadSbom, ReadAdvisory); #[delete("/v1/sbom/{id}")] pub async fn delete( service: web::Data, + db: web::Data, purl_service: web::Data, id: web::Path, _: Require, -) -> actix_web::Result { - let id = Id::from_str(&id).map_err(Error::IdKey)?; - match service.fetch_sbom_summary(id.clone(), ()).await? { +) -> Result { + let tx = db.begin().await?; + + let id = Id::from_str(&id)?; + match service.fetch_sbom_summary(id.clone(), &tx).await? { Some(v) => { - let rows_affected = service.delete_sbom(v.head.id, ()).await?; + let rows_affected = service.delete_sbom(v.head.id, &tx).await?; match rows_affected { 0 => Ok(HttpResponse::NotFound().finish()), 1 => { - let _ = purl_service.gc_purls(()).await; // ignore gc failure.. + let _ = purl_service.gc_purls(&tx).await; // ignore gc failure.. + tx.commit().await?; Ok(HttpResponse::Ok().json(v)) } - _ => Err(Internal("Unexpected number of rows affected".into()).into()), + _ => Err(Internal("Unexpected number of rows affected".into())), } } None => Ok(HttpResponse::NotFound().finish()), @@ -311,13 +326,14 @@ pub async fn delete( #[get("/v1/sbom/{id}/packages")] pub async fn packages( fetch: web::Data, + db: web::Data, id: web::Path, web::Query(search): web::Query, web::Query(paginated): web::Query, _: Require, ) -> actix_web::Result { let result = fetch - .fetch_sbom_packages(id.into_inner(), search, paginated, ()) + .fetch_sbom_packages(id.into_inner(), search, paginated, db.as_ref()) .await?; Ok(HttpResponse::Ok().json(result)) @@ -353,6 +369,7 @@ struct RelatedQuery { #[get("/v1/sbom/{id}/related")] pub async fn related( fetch: web::Data, + db: web::Data, id: web::Path, web::Query(search): web::Query, web::Query(paginated): web::Query, @@ -372,7 +389,7 @@ pub async fn related( Some(id) => SbomPackageReference::Package(id), }, related.relationship, - (), + db.as_ref(), ) .await?; @@ -431,13 +448,14 @@ pub async fn upload( #[get("/v1/sbom/{key}/download")] pub async fn download( ingestor: web::Data, + db: web::Data, sbom: web::Data, key: web::Path, _: Require, ) -> Result { let id = Id::from_str(&key).map_err(Error::IdKey)?; - let Some(sbom) = sbom.fetch_sbom_summary(id, ()).await? else { + let Some(sbom) = sbom.fetch_sbom_summary(id, db.as_ref()).await? else { return Ok(HttpResponse::NotFound().finish()); }; diff --git a/modules/fundamental/src/sbom/model/details.rs b/modules/fundamental/src/sbom/model/details.rs index 909fb701c..6c5d9d89b 100644 --- a/modules/fundamental/src/sbom/model/details.rs +++ b/modules/fundamental/src/sbom/model/details.rs @@ -11,8 +11,8 @@ use crate::{ }; use cpe::uri::OwnedUri; use sea_orm::{ - DbErr, EntityTrait, FromQueryResult, JoinType, ModelTrait, QueryFilter, QueryOrder, - QueryResult, QuerySelect, RelationTrait, Select, + ConnectionTrait, DbErr, EntityTrait, FromQueryResult, JoinType, ModelTrait, QueryFilter, + QueryOrder, QueryResult, QuerySelect, RelationTrait, Select, }; use sea_query::{Asterisk, Expr, Func, SimpleExpr}; use serde::{Deserialize, Serialize}; @@ -21,7 +21,7 @@ use trustify_common::{ cpe::CpeCompare, db::{ multi_model::{FromQueryResultMultiModel, SelectIntoMultiModel}, - ConnectionOrTransaction, VersionMatches, + VersionMatches, }, memo::Memo, }; @@ -43,10 +43,10 @@ pub struct SbomDetails { impl SbomDetails { /// turn an (sbom, sbom_node) row into an [`SbomDetails`], if possible - pub async fn from_entity( + pub async fn from_entity( (sbom, node): (sbom::Model, Option), service: &SbomService, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let relevant_advisory_info = sbom .find_related(sbom_package::Entity) @@ -142,11 +142,11 @@ pub struct SbomAdvisory { } impl SbomAdvisory { - pub async fn from_models( + pub async fn from_models( described_by: &[SbomPackage], statuses: &[QueryCatcher], product_statuses: &[ProductStatusCatcher], - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let mut advisories = HashMap::new(); @@ -299,12 +299,12 @@ pub struct SbomStatus { } impl SbomStatus { - pub async fn new( + pub async fn new( vulnerability: &vulnerability::Model, status: String, cpe: Option, packages: Vec, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let cvss3 = vulnerability.find_related(cvss3::Entity).all(tx).await?; let average_severity = Score::from_iter(cvss3.iter().map(Cvss3Base::from)).severity(); diff --git a/modules/fundamental/src/sbom/model/mod.rs b/modules/fundamental/src/sbom/model/mod.rs index 1f4f95bb8..e42bb67e6 100644 --- a/modules/fundamental/src/sbom/model/mod.rs +++ b/modules/fundamental/src/sbom/model/mod.rs @@ -5,10 +5,10 @@ use crate::{ purl::model::summary::purl::PurlSummary, source_document::model::SourceDocument, Error, }; use async_graphql::SimpleObject; -use sea_orm::{prelude::Uuid, ModelTrait, PaginatorTrait}; +use sea_orm::{prelude::Uuid, ConnectionTrait, ModelTrait, PaginatorTrait}; use serde::{Deserialize, Serialize}; use time::OffsetDateTime; -use trustify_common::{db::ConnectionOrTransaction, model::Paginated}; +use trustify_common::model::Paginated; use trustify_entity::{ labels::Labels, relationship::Relationship, sbom, sbom_node, sbom_package, source_document, }; @@ -37,12 +37,12 @@ pub struct SbomHead { } impl SbomHead { - pub async fn from_entity( + pub async fn from_entity( sbom: &sbom::Model, sbom_node: Option, - tx: &ConnectionOrTransaction<'_>, + db: &C, ) -> Result { - let number_of_packages = sbom.find_related(sbom_package::Entity).count(tx).await?; + let number_of_packages = sbom.find_related(sbom_package::Entity).count(db).await?; Ok(Self { id: sbom.sbom_id, document_id: sbom.document_id.clone(), @@ -70,24 +70,24 @@ pub struct SbomSummary { } impl SbomSummary { - pub async fn from_entity( + pub async fn from_entity( (sbom, node): (sbom::Model, Option), service: &SbomService, - tx: &ConnectionOrTransaction<'_>, + db: &C, ) -> Result, Error> { // TODO: consider improving the n-select issues here let described_by = service - .describes_packages(sbom.sbom_id, Paginated::default(), ()) + .describes_packages(sbom.sbom_id, Paginated::default(), db) .await? .items; - let source_document = sbom.find_related(source_document::Entity).one(tx).await?; + let source_document = sbom.find_related(source_document::Entity).one(db).await?; Ok(match node { Some(_) => Some(SbomSummary { - head: SbomHead::from_entity(&sbom, node, tx).await?, + head: SbomHead::from_entity(&sbom, node, db).await?, source_document: if let Some(doc) = &source_document { - Some(SourceDocument::from_entity(doc, tx).await?) + Some(SourceDocument::from_entity(doc).await?) } else { None }, diff --git a/modules/fundamental/src/sbom/service/label.rs b/modules/fundamental/src/sbom/service/label.rs index 16d769e78..597ac5d32 100644 --- a/modules/fundamental/src/sbom/service/label.rs +++ b/modules/fundamental/src/sbom/service/label.rs @@ -1,13 +1,10 @@ use crate::{sbom::service::SbomService, Error}; use sea_orm::{ - ActiveModelTrait, ActiveValue::Set, DatabaseBackend, EntityTrait, IntoActiveModel, QueryTrait, - TransactionTrait, + ActiveModelTrait, ActiveValue::Set, ConnectionTrait, DatabaseBackend, EntityTrait, + IntoActiveModel, QueryTrait, TransactionTrait, }; use sea_query::Expr; -use trustify_common::{ - db::Transactional, - id::{Id, TrySelectForId}, -}; +use trustify_common::id::{Id, TrySelectForId}; use trustify_entity::{labels::Labels, sbom}; impl SbomService { @@ -15,18 +12,16 @@ impl SbomService { /// /// Returns `Ok(Some(()))` if a document was found and updated. If no document was found, it will /// return `Ok(None)`. - pub async fn set_labels( + pub async fn set_labels( &self, id: Id, labels: Labels, - tx: impl AsRef, + connection: &C, ) -> Result, Error> { - let db = self.db.connection(&tx); - let result = sbom::Entity::update_many() .try_filter(id)? .col_expr(sbom::Column::Labels, Expr::value(labels)) - .exec(&db) + .exec(connection) .await?; Ok((result.rows_affected > 0).then_some(())) diff --git a/modules/fundamental/src/sbom/service/sbom.rs b/modules/fundamental/src/sbom/service/sbom.rs index 972a787b8..d0cd746bb 100644 --- a/modules/fundamental/src/sbom/service/sbom.rs +++ b/modules/fundamental/src/sbom/service/sbom.rs @@ -9,8 +9,9 @@ use crate::{ }; use futures_util::{stream, StreamExt, TryStreamExt}; use sea_orm::{ - prelude::Uuid, ColumnTrait, DbErr, EntityTrait, FromQueryResult, IntoSimpleExpr, QueryFilter, - QueryOrder, QueryResult, QuerySelect, RelationTrait, Select, SelectColumns, + prelude::Uuid, ColumnTrait, ConnectionTrait, DbErr, EntityTrait, FromQueryResult, + IntoSimpleExpr, QueryFilter, QueryOrder, QueryResult, QuerySelect, RelationTrait, Select, + SelectColumns, }; use sea_query::{extension::postgres::PgExpr, Expr, Func, JoinType, SimpleExpr}; use serde::Deserialize; @@ -23,7 +24,7 @@ use trustify_common::{ limiter::{limit_selector, LimiterTrait}, multi_model::{FromQueryResultMultiModel, SelectIntoMultiModel}, query::{Columns, Filtering, IntoColumns, Query}, - ArrayAgg, ConnectionOrTransaction, JsonBuildObject, ToJson, Transactional, + ArrayAgg, JsonBuildObject, ToJson, }, id::{Id, TrySelectForId}, model::{Paginated, PaginatedResults}, @@ -41,74 +42,67 @@ use trustify_entity::{ }; impl SbomService { - async fn fetch_sbom>( + async fn fetch_sbom( &self, id: Id, - tx: TX, + connection: &C, ) -> Result)>, Error> { - let connection = self.db.connection(&tx); - let select = sbom::Entity::find() .join(JoinType::LeftJoin, sbom::Relation::SourceDocument.def()) .try_filter(id)?; Ok(select .find_also_linked(SbomNodeLink) - .one(&connection) + .one(connection) .await?) } /// fetch one sbom - pub async fn fetch_sbom_details>( + pub async fn fetch_sbom_details( &self, id: Id, - tx: TX, + connection: &C, ) -> Result, Error> { - Ok(match self.fetch_sbom(id, &tx).await? { - Some(row) => SbomDetails::from_entity(row, self, &self.db.connection(&tx)).await?, + Ok(match self.fetch_sbom(id, connection).await? { + Some(row) => SbomDetails::from_entity(row, self, connection).await?, None => None, }) } /// fetch the summary of one sbom - pub async fn fetch_sbom_summary>( + pub async fn fetch_sbom_summary( &self, id: Id, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - - Ok(match self.fetch_sbom(id, &tx).await? { - Some(row) => SbomSummary::from_entity(row, self, &connection).await?, + Ok(match self.fetch_sbom(id, connection).await? { + Some(row) => SbomSummary::from_entity(row, self, connection).await?, None => None, }) } /// delete one sbom - pub async fn delete_sbom>( + pub async fn delete_sbom( &self, id: Uuid, - tx: TX, + connection: &C, ) -> Result { - let connection = self.db.connection(&tx); - let query = sbom::Entity::delete_by_id(id); - let result = query.exec(&connection).await?; + let result = query.exec(connection).await?; Ok(result.rows_affected) } /// fetch all SBOMs - pub async fn fetch_sboms>( + pub async fn fetch_sboms( &self, search: Query, paginated: Paginated, labels: impl Into, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); let labels = labels.into(); let query = if labels.is_empty() { @@ -124,13 +118,13 @@ impl SbomService { .add_columns(sbom_node::Entity) .alias("sbom_node", "r0"), )? - .limiting(&connection, paginated.offset, paginated.limit); + .limiting(connection, paginated.offset, paginated.limit); let total = limiter.total().await?; let sboms = limiter.fetch().await?; let items = stream::iter(sboms.into_iter()) - .then(|row| async { SbomSummary::from_entity(row, self, &connection).await }) + .then(|row| async { SbomSummary::from_entity(row, self, connection).await }) .try_filter_map(futures_util::future::ok) .try_collect() .await?; @@ -142,16 +136,14 @@ impl SbomService { /// /// If you need to find packages based on their relationship, even in the relationship to /// SBOM itself, use [`Self::fetch_related_packages`]. - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn fetch_sbom_packages>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn fetch_sbom_packages( &self, sbom_id: Uuid, search: Query, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let db = self.db.connection(&tx); - let mut query = sbom_package::Entity::find() .filter(sbom_package::Column::SbomId.eq(sbom_id)) .join(JoinType::Join, sbom_package::Relation::Node.def()) @@ -182,7 +174,7 @@ impl SbomService { // limit and execute let limiter = limit_selector::<'_, _, _, _, PackageCatcher>( - &db, + connection, query, paginated.offset, paginated.limit, @@ -196,19 +188,19 @@ impl SbomService { let mut items = Vec::new(); for row in packages { - items.push(package_from_row(row, &self.db.connection(&tx)).await?); + items.push(package_from_row(row, connection).await?); } Ok(PaginatedResults { items, total }) } /// Get all packages describing the SBOM. - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn describes_packages>( + #[instrument(skip(self, db), err(level=tracing::Level::INFO))] + pub async fn describes_packages( &self, sbom_id: Uuid, paginated: Paginated, - tx: TX, + db: &C, ) -> Result, Error> { self.fetch_related_packages( sbom_id, @@ -217,20 +209,18 @@ impl SbomService { Which::Right, SbomPackageReference::All, Some(Relationship::DescribedBy), - tx, + db, ) .await .map(|r| r.map(|rel| rel.package)) } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn count_related_sboms( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn count_related_sboms( &self, qualified_package_ids: Vec, - tx: impl AsRef, + connection: &C, ) -> Result, Error> { - let db = self.db.connection(&tx); - let query = sbom::Entity::find() .join(JoinType::Join, sbom::Relation::Packages.def()) .join(JoinType::Join, sbom_package::Relation::Purl.def()) @@ -242,7 +232,7 @@ impl SbomService { .column(sbom_package_purl_ref::Column::QualifiedPurlId) .column_as(sbom_package::Column::SbomId.count(), "count") .into_tuple::<(Uuid, i64)>() - .all(&db) + .all(connection) .await?; // turn result into a map @@ -261,16 +251,14 @@ impl SbomService { Ok(result) } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn find_related_sboms( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn find_related_sboms( &self, qualified_package_id: Uuid, paginated: Paginated, query: Query, - tx: impl AsRef, + connection: &C, ) -> Result, Error> { - let db = self.db.connection(&tx); - let query = sbom::Entity::find() .join(JoinType::Join, sbom::Relation::Packages.def()) .join(JoinType::Join, sbom_package::Relation::Purl.def()) @@ -280,7 +268,7 @@ impl SbomService { // limit and execute - let limiter = query.limiting(&db, paginated.offset, paginated.limit); + let limiter = query.limiting(connection, paginated.offset, paginated.limit); let total = limiter.total().await?; let sboms = limiter.fetch().await?; @@ -288,7 +276,7 @@ impl SbomService { // collect results let items = stream::iter(sboms.into_iter()) - .then(|row| async { SbomSummary::from_entity(row, self, &db).await }) + .then(|row| async { SbomSummary::from_entity(row, self, connection).await }) .try_filter_map(futures_util::future::ok) .try_collect() .await?; @@ -298,8 +286,8 @@ impl SbomService { /// Fetch all related packages in the context of an SBOM. #[allow(clippy::too_many_arguments)] - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn fetch_related_packages>( + #[instrument(skip(self, db), err(level=tracing::Level::INFO))] + pub async fn fetch_related_packages( &self, sbom_id: Uuid, search: Query, @@ -307,9 +295,9 @@ impl SbomService { which: Which, reference: impl Into> + Debug, relationship: Option, - tx: TX, + db: &C, ) -> Result, Error> { - let db = self.db.connection(&tx); + // let db = self.db.connection(connection); // which way @@ -378,7 +366,7 @@ impl SbomService { // limit and execute let limiter = limit_selector::<'_, _, _, _, PackageCatcher>( - &db, + db, query, paginated.offset, paginated.limit, @@ -395,7 +383,7 @@ impl SbomService { if let Some(relationship) = row.relationship { items.push(SbomPackageRelation { relationship, - package: package_from_row(row, &self.db.connection(&tx)).await?, + package: package_from_row(row, db).await?, }); } } @@ -406,12 +394,12 @@ impl SbomService { /// A simplified version of [`Self::fetch_related_packages`]. /// /// It uses [`Which::Right`] and the provided reference, [`Default::default`] for the rest. - pub async fn related_packages>( + pub async fn related_packages( &self, sbom_id: Uuid, relationship: impl Into>, pkg: impl Into> + Debug, - tx: TX, + tx: &C, ) -> Result, Error> { let result = self .fetch_related_packages( @@ -526,9 +514,9 @@ struct PackageCatcher { } /// Convert values from a "package row" into an SBOM package -async fn package_from_row( +async fn package_from_row( row: PackageCatcher, - tx: &ConnectionOrTransaction<'_>, + db: &C, ) -> Result { let mut purls = Vec::new(); @@ -565,7 +553,7 @@ async fn package_from_row( qualifiers: dto.qualifiers, purl: cp, }, - tx, + db, ) .await?, ); @@ -703,7 +691,7 @@ mod test { &Digests::digest("RHSA-1"), "http://redhat.com/test.json", (), - Transactional::None, + &ctx.db, ) .await?; let sbom_v1_again = ctx @@ -713,7 +701,7 @@ mod test { &Digests::digest("RHSA-1"), "http://redhat.com/test.json", (), - Transactional::None, + &ctx.db, ) .await?; let sbom_v2 = ctx @@ -723,7 +711,7 @@ mod test { &Digests::digest("RHSA-2"), "http://myspace.com/test.json", (), - Transactional::None, + &ctx.db, ) .await?; @@ -734,7 +722,7 @@ mod test { &Digests::digest("RHSA-3"), "http://geocities.com/other.json", (), - Transactional::None, + &ctx.db, ) .await?; @@ -748,7 +736,7 @@ mod test { q("MySpAcE").sort("name,authors,published"), Paginated::default(), (), - (), + &ctx.db, ) .await?; @@ -771,7 +759,7 @@ mod test { &Digests::digest("RHSA-1"), "http://redhat.com/test1.json", (), - Transactional::None, + &ctx.db, ) .await?; @@ -785,7 +773,7 @@ mod test { &Digests::digest("RHSA-2"), "http://redhat.com/test2.json", (), - Transactional::None, + &ctx.db, ) .await?; @@ -799,34 +787,54 @@ mod test { &Digests::digest("RHSA-3"), "http://redhat.com/test3.json", (), - Transactional::None, + &ctx.db, ) .await?; let service = SbomService::new(ctx.db.clone()); let fetched = service - .fetch_sboms(Query::default(), Paginated::default(), ("ci", "job1"), ()) + .fetch_sboms( + Query::default(), + Paginated::default(), + ("ci", "job1"), + &ctx.db, + ) .await?; assert_eq!(1, fetched.total); let fetched = service - .fetch_sboms(Query::default(), Paginated::default(), ("ci", "job2"), ()) + .fetch_sboms( + Query::default(), + Paginated::default(), + ("ci", "job2"), + &ctx.db, + ) .await?; assert_eq!(2, fetched.total); let fetched = service - .fetch_sboms(Query::default(), Paginated::default(), ("ci", "job3"), ()) + .fetch_sboms( + Query::default(), + Paginated::default(), + ("ci", "job3"), + &ctx.db, + ) .await?; assert_eq!(0, fetched.total); let fetched = service - .fetch_sboms(Query::default(), Paginated::default(), ("foo", "bar"), ()) + .fetch_sboms( + Query::default(), + Paginated::default(), + ("foo", "bar"), + &ctx.db, + ) .await?; assert_eq!(0, fetched.total); let fetched = service - .fetch_sboms(Query::default(), Paginated::default(), (), ()) + .fetch_sboms(Query::default(), Paginated::default(), (), &ctx.db) .await?; assert_eq!(3, fetched.total); @@ -835,7 +843,7 @@ mod test { Query::default(), Paginated::default(), [("ci", "job2"), ("team", "a")], - (), + &ctx.db, ) .await?; assert_eq!(1, fetched.total); @@ -853,18 +861,18 @@ mod test { &Digests::digest("RHSA-1"), "http://redhat.com/test.json", (), - Transactional::None, + &ctx.db, ) .await?; let service = SbomService::new(ctx.db.clone()); - let affected = service.delete_sbom(sbom_v1.sbom.sbom_id, ()).await?; + let affected = service.delete_sbom(sbom_v1.sbom.sbom_id, &ctx.db).await?; log::debug!("{:#?}", affected); assert_eq!(1, affected); - let affected = service.delete_sbom(sbom_v1.sbom.sbom_id, ()).await?; + let affected = service.delete_sbom(sbom_v1.sbom.sbom_id, &ctx.db).await?; log::debug!("{:#?}", affected); assert_eq!(0, affected); diff --git a/modules/fundamental/src/sbom/service/test.rs b/modules/fundamental/src/sbom/service/test.rs index 76776223c..850997c16 100644 --- a/modules/fundamental/src/sbom/service/test.rs +++ b/modules/fundamental/src/sbom/service/test.rs @@ -2,7 +2,6 @@ use crate::sbom::service::SbomService; use std::str::FromStr; use test_context::test_context; use test_log::test; -use trustify_common::db::Transactional; use trustify_common::id::Id; use trustify_common::purl::Purl; use trustify_test_context::TrustifyContext; @@ -23,9 +22,7 @@ async fn sbom_details_status(ctx: &TrustifyContext) -> Result<(), anyhow::Error> let id_3_2_12 = results[3].id.clone(); - let details = service - .fetch_sbom_details(id_3_2_12, Transactional::None) - .await?; + let details = service.fetch_sbom_details(id_3_2_12, &ctx.db).await?; assert!(details.is_some()); @@ -34,7 +31,7 @@ async fn sbom_details_status(ctx: &TrustifyContext) -> Result<(), anyhow::Error> log::debug!("{details:#?}"); let details = service - .fetch_sbom_details(Id::Uuid(details.summary.head.id), Transactional::None) + .fetch_sbom_details(Id::Uuid(details.summary.head.id), &ctx.db) .await?; assert!(details.is_some()); @@ -64,7 +61,7 @@ async fn count_sboms(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { both.qualifier_uuid(), one.qualifier_uuid(), ], - (), + &ctx.db, ) .await?; diff --git a/modules/fundamental/src/source_document/model/mod.rs b/modules/fundamental/src/source_document/model/mod.rs index d41b9ee17..15696f758 100644 --- a/modules/fundamental/src/source_document/model/mod.rs +++ b/modules/fundamental/src/source_document/model/mod.rs @@ -1,7 +1,6 @@ use crate::Error; use serde::{Deserialize, Serialize}; use std::str::FromStr; -use trustify_common::db::ConnectionOrTransaction; use trustify_common::id::{Id, IdError}; use trustify_entity::source_document; use trustify_module_storage::service::StorageKey; @@ -16,10 +15,7 @@ pub struct SourceDocument { } impl SourceDocument { - pub async fn from_entity( - source_document: &source_document::Model, - _tx: &ConnectionOrTransaction<'_>, - ) -> Result { + pub async fn from_entity(source_document: &source_document::Model) -> Result { Ok(Self { sha256: format!("sha256:{}", source_document.sha256), sha384: format!("sha384:{}", source_document.sha384), diff --git a/modules/fundamental/src/vulnerability/endpoints/mod.rs b/modules/fundamental/src/vulnerability/endpoints/mod.rs index b32887bac..3a0bda172 100644 --- a/modules/fundamental/src/vulnerability/endpoints/mod.rs +++ b/modules/fundamental/src/vulnerability/endpoints/mod.rs @@ -7,19 +7,22 @@ use crate::{ model::{VulnerabilityDetails, VulnerabilitySummary}, service::VulnerabilityService, }, + Error, Error::Internal, }; use actix_web::{delete, get, web, HttpResponse, Responder}; +use sea_orm::TransactionTrait; use trustify_auth::{authorizer::Require, DeleteVulnerability, ReadAdvisory}; use trustify_common::{ - db::{query::Query, Database, Transactional}, + db::{query::Query, Database}, model::{Paginated, PaginatedResults}, }; pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, db: Database) { - let service = VulnerabilityService::new(db); + let service = VulnerabilityService::new(); config .app_data(web::Data::new(service)) + .app_data(web::Data::new(db)) .service(all) .service(delete) .service(get); @@ -40,6 +43,7 @@ pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, d /// List vulnerabilities pub async fn all( state: web::Data, + db: web::Data, web::Query(search): web::Query, web::Query(paginated): web::Query, web::Query(Deprecation { deprecated }): web::Query, @@ -47,7 +51,7 @@ pub async fn all( ) -> actix_web::Result { Ok(HttpResponse::Ok().json( state - .fetch_vulnerabilities(search, paginated, deprecated, Transactional::None) + .fetch_vulnerabilities(search, paginated, deprecated, db.as_ref()) .await?, )) } @@ -67,12 +71,13 @@ pub async fn all( /// Retrieve vulnerability details pub async fn get( state: web::Data, + db: web::Data, id: web::Path, web::Query(Deprecation { deprecated }): web::Query, _: Require, ) -> actix_web::Result { let vuln = state - .fetch_vulnerability(&id, deprecated, Transactional::None) + .fetch_vulnerability(&id, deprecated, db.as_ref()) .await?; if let Some(vuln) = vuln { Ok(HttpResponse::Ok().json(vuln)) @@ -96,26 +101,32 @@ pub async fn get( /// Delete vulnerability pub async fn delete( state: web::Data, + db: web::Data, id: web::Path, _: Require, -) -> actix_web::Result { +) -> Result { + let tx = db.begin().await?; + let vuln = state // we ignore deprecated advisories, as we delete the vulnerability anyway. .fetch_vulnerability( &id, trustify_module_ingestor::common::Deprecation::Ignore, - Transactional::None, + &tx, ) .await?; if let Some(vuln) = vuln { let rows_affected = state - .delete_vulnerability(&vuln.head.identifier, ()) + .delete_vulnerability(&vuln.head.identifier, &tx) .await?; match rows_affected { 0 => Ok(HttpResponse::NotFound().finish()), - 1 => Ok(HttpResponse::Ok().json(vuln)), - _ => Err(Internal("Unexpected number of rows affected".into()).into()), + 1 => { + tx.commit().await?; + Ok(HttpResponse::Ok().json(vuln)) + } + _ => Err(Internal("Unexpected number of rows affected".into())), } } else { Ok(HttpResponse::NotFound().finish()) diff --git a/modules/fundamental/src/vulnerability/endpoints/test.rs b/modules/fundamental/src/vulnerability/endpoints/test.rs index 9a34dee55..ef2495076 100644 --- a/modules/fundamental/src/vulnerability/endpoints/test.rs +++ b/modules/fundamental/src/vulnerability/endpoints/test.rs @@ -6,15 +6,14 @@ use serde_json::Value; use test_context::test_context; use test_log::test; use time::{macros::datetime, OffsetDateTime}; -use trustify_common::db::Transactional; -use trustify_common::hashing::Digests; -use trustify_common::model::PaginatedResults; +use trustify_common::{hashing::Digests, model::PaginatedResults}; use trustify_cvss::cvss3::{ AttackComplexity, AttackVector, Availability, Confidentiality, Cvss3Base, Integrity, PrivilegesRequired, Scope, UserInteraction, }; -use trustify_module_ingestor::graph::advisory::AdvisoryInformation; -use trustify_module_ingestor::graph::vulnerability::VulnerabilityInformation; +use trustify_module_ingestor::graph::{ + advisory::AdvisoryInformation, vulnerability::VulnerabilityInformation, +}; use trustify_test_context::{call::CallService, TrustifyContext}; #[test_context(TrustifyContext)] @@ -37,12 +36,12 @@ async fn all_vulnerabilities(ctx: &TrustifyContext) -> Result<(), anyhow::Error> modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; let advisory_vuln = advisory - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None, &ctx.db) .await?; advisory_vuln .ingest_cvss3_score( @@ -57,7 +56,7 @@ async fn all_vulnerabilities(ctx: &TrustifyContext) -> Result<(), anyhow::Error> i: Integrity::None, a: Availability::None, }, - (), + &ctx.db, ) .await?; @@ -76,16 +75,20 @@ async fn all_vulnerabilities(ctx: &TrustifyContext) -> Result<(), anyhow::Error> modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; advisory - .link_to_vulnerability("CVE-345", None, Transactional::None) + .link_to_vulnerability("CVE-345", None, &ctx.db) .await?; - ctx.graph.ingest_vulnerability("CVE-123", (), ()).await?; - ctx.graph.ingest_vulnerability("CVE-345", (), ()).await?; + ctx.graph + .ingest_vulnerability("CVE-123", (), &ctx.db) + .await?; + ctx.graph + .ingest_vulnerability("CVE-345", (), &ctx.db) + .await?; let uri = "/api/v1/vulnerability"; @@ -121,12 +124,12 @@ async fn one_vulnerability(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; let advisory_vuln = advisory - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None, &ctx.db) .await?; advisory_vuln @@ -142,7 +145,7 @@ async fn one_vulnerability(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { i: Integrity::High, a: Availability::Low, }, - (), + &ctx.db, ) .await?; @@ -161,12 +164,12 @@ async fn one_vulnerability(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; advisory - .link_to_vulnerability("CVE-345", None, Transactional::None) + .link_to_vulnerability("CVE-345", None, &ctx.db) .await?; ctx.graph @@ -180,7 +183,7 @@ async fn one_vulnerability(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { withdrawn: None, cwes: None, }, - (), + &ctx.db, ) .await?; @@ -216,12 +219,12 @@ async fn delete_vulnerability(ctx: &TrustifyContext) -> Result<(), anyhow::Error modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; let advisory_vuln = advisory - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None, &ctx.db) .await?; advisory_vuln @@ -237,7 +240,7 @@ async fn delete_vulnerability(ctx: &TrustifyContext) -> Result<(), anyhow::Error i: Integrity::High, a: Availability::Low, }, - (), + &ctx.db, ) .await?; @@ -256,12 +259,12 @@ async fn delete_vulnerability(ctx: &TrustifyContext) -> Result<(), anyhow::Error modified: None, withdrawn: None, }, - (), + &ctx.db, ) .await?; advisory - .link_to_vulnerability("CVE-345", None, Transactional::None) + .link_to_vulnerability("CVE-345", None, &ctx.db) .await?; ctx.graph @@ -275,7 +278,7 @@ async fn delete_vulnerability(ctx: &TrustifyContext) -> Result<(), anyhow::Error withdrawn: None, cwes: None, }, - (), + &ctx.db, ) .await?; diff --git a/modules/fundamental/src/vulnerability/model/details/mod.rs b/modules/fundamental/src/vulnerability/model/details/mod.rs index c08720484..9d201c7ff 100644 --- a/modules/fundamental/src/vulnerability/model/details/mod.rs +++ b/modules/fundamental/src/vulnerability/model/details/mod.rs @@ -3,9 +3,9 @@ mod vulnerability_advisory; pub use vulnerability_advisory::*; use crate::{vulnerability::model::VulnerabilityHead, Error}; -use sea_orm::ModelTrait; +use sea_orm::{ConnectionTrait, ModelTrait}; use serde::{Deserialize, Serialize}; -use trustify_common::{db::ConnectionOrTransaction, memo::Memo}; +use trustify_common::memo::Memo; use trustify_cvss::cvss3::{score::Score, severity::Severity, Cvss3Base}; use trustify_entity::{advisory_vulnerability, cvss3, vulnerability}; use trustify_module_ingestor::common::{Deprecation, DeprecationForExt}; @@ -29,10 +29,10 @@ pub struct VulnerabilityDetails { } impl VulnerabilityDetails { - pub async fn from_entity( + pub async fn from_entity( vulnerability: &vulnerability::Model, deprecation: Deprecation, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let advisory_vulnerabilities = vulnerability .find_related(advisory_vulnerability::Entity) diff --git a/modules/fundamental/src/vulnerability/model/details/vulnerability_advisory.rs b/modules/fundamental/src/vulnerability/model/details/vulnerability_advisory.rs index 8500bf74e..83ad695ec 100644 --- a/modules/fundamental/src/vulnerability/model/details/vulnerability_advisory.rs +++ b/modules/fundamental/src/vulnerability/model/details/vulnerability_advisory.rs @@ -7,8 +7,9 @@ use crate::{ use ::cpe::cpe::Cpe; use ::cpe::uri::OwnedUri; use sea_orm::{ - ColumnTrait, DbErr, EntityTrait, FromQueryResult, IntoIdentity, LoaderTrait, ModelTrait, - PaginatorTrait, QueryFilter, QueryOrder, QueryResult, QuerySelect, RelationTrait, Select, + ColumnTrait, ConnectionTrait, DbErr, EntityTrait, FromQueryResult, IntoIdentity, LoaderTrait, + ModelTrait, PaginatorTrait, QueryFilter, QueryOrder, QueryResult, QuerySelect, RelationTrait, + Select, }; use sea_query::{Asterisk, Expr, Func, IntoCondition, JoinType, NullOrdering, SimpleExpr}; use serde::{Deserialize, Serialize}; @@ -17,17 +18,16 @@ use trustify_common::{ cpe::CpeCompare, db::{ multi_model::{FromQueryResultMultiModel, SelectIntoMultiModel}, - ConnectionOrTransaction, VersionMatches, + VersionMatches, }, memo::Memo, purl::Purl, }; use trustify_cvss::cvss3::{score::Score, severity::Severity, Cvss3Base}; -use trustify_entity::{self as entity}; use trustify_entity::{ - advisory, advisory_vulnerability, base_purl, cpe, cvss3, organization, purl_status, - qualified_purl, sbom, sbom_node, sbom_package, sbom_package_cpe_ref, sbom_package_purl_ref, - status, version_range, versioned_purl, vulnerability, + self as entity, advisory, advisory_vulnerability, base_purl, cpe, cvss3, organization, + purl_status, qualified_purl, sbom, sbom_node, sbom_package, sbom_package_cpe_ref, + sbom_package_purl_ref, status, version_range, versioned_purl, vulnerability, }; use utoipa::ToSchema; use uuid::Uuid; @@ -43,10 +43,10 @@ pub struct VulnerabilityAdvisoryHead { } impl VulnerabilityAdvisoryHead { - pub async fn from_entity( + pub async fn from_entity( vulnerability: &vulnerability::Model, advisory_vulnerability: &advisory_vulnerability::Model, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let cvss3 = cvss3::Entity::find() .filter(cvss3::Column::AdvisoryId.eq(advisory_vulnerability.advisory_id)) @@ -74,11 +74,11 @@ impl VulnerabilityAdvisoryHead { Err(Error::Data("Underlying advisory is missing".to_string())) } } - pub async fn from_entities( + pub async fn from_entities( vulnerability: &vulnerability::Model, vuln_advisories: &[advisory::Model], vuln_cvss3s: &[cvss3::Model], - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let mut heads = Vec::new(); @@ -127,11 +127,11 @@ pub struct VulnerabilityAdvisorySummary { } impl VulnerabilityAdvisorySummary { - pub async fn from_entities( + pub async fn from_entities( vulnerability: &vulnerability::Model, advisory_vulnerabilities: &[advisory_vulnerability::Model], vuln_cvss3: &[cvss3::Model], - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let purl_status_query = purl_status::Entity::find() .left_join(status::Entity) @@ -304,7 +304,6 @@ impl VulnerabilityAdvisorySummary { purls: VulnerabilityAdvisoryStatus::from_models( purl_statuses, vuln_product_statuses.iter(), - tx, ) .await?, sboms: VulnerabilitySbomStatus::from_models( @@ -443,7 +442,6 @@ impl VulnerabilityAdvisoryStatus { >( purls: I, products: J, - _tx: &ConnectionOrTransaction<'_>, ) -> Result>, Error> { let mut statuses = HashMap::new(); @@ -614,14 +612,15 @@ pub struct VulnerabilitySbomStatus { } impl VulnerabilitySbomStatus { - async fn from_models<'i, 'j, I, J>( + async fn from_models<'i, 'j, I, J, C>( sbom_purl_status: I, products: J, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> where I: Iterator + Clone, J: Iterator + Clone, + C: ConnectionTrait, { let mut sboms = HashMap::new(); diff --git a/modules/fundamental/src/vulnerability/model/mod.rs b/modules/fundamental/src/vulnerability/model/mod.rs index 4cc406b97..bba19ec2a 100644 --- a/modules/fundamental/src/vulnerability/model/mod.rs +++ b/modules/fundamental/src/vulnerability/model/mod.rs @@ -3,13 +3,12 @@ mod summary; use async_graphql::SimpleObject; pub use details::*; -use sea_orm::{ColumnTrait, ModelTrait, QueryFilter}; +use sea_orm::{ColumnTrait, ConnectionTrait, ModelTrait, QueryFilter}; pub use summary::*; use crate::Error; use serde::{Deserialize, Serialize}; use time::OffsetDateTime; -use trustify_common::db::ConnectionOrTransaction; use trustify_common::memo::Memo; use trustify_entity::{advisory_vulnerability, vulnerability, vulnerability_description}; use utoipa::ToSchema; @@ -68,10 +67,10 @@ pub struct VulnerabilityHead { } impl VulnerabilityHead { - pub async fn from_vulnerability_entity( + pub async fn from_vulnerability_entity( entity: &vulnerability::Model, description: Memo, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result { let description = match description { Memo::Provided(inner) => inner.map(|inner| inner.description), diff --git a/modules/fundamental/src/vulnerability/model/summary.rs b/modules/fundamental/src/vulnerability/model/summary.rs index 6198f6dee..20bbd79ae 100644 --- a/modules/fundamental/src/vulnerability/model/summary.rs +++ b/modules/fundamental/src/vulnerability/model/summary.rs @@ -2,9 +2,8 @@ use crate::{ vulnerability::model::{VulnerabilityAdvisoryHead, VulnerabilityHead}, Error, }; -use sea_orm::{ColumnTrait, EntityTrait, LoaderTrait, QueryFilter}; +use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, LoaderTrait, QueryFilter}; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_common::memo::Memo; use trustify_cvss::cvss3::severity::Severity; use trustify_entity::{ @@ -31,11 +30,11 @@ pub struct VulnerabilitySummary { } impl VulnerabilitySummary { - pub async fn from_entities( + pub async fn from_entities( vulnerabilities: &[vulnerability::Model], averages: &[(Option, Option)], deprecation: Deprecation, - tx: &ConnectionOrTransaction<'_>, + tx: &C, ) -> Result, Error> { let advisories = vulnerabilities .load_many_to_many( diff --git a/modules/fundamental/src/vulnerability/service/mod.rs b/modules/fundamental/src/vulnerability/service/mod.rs index 140ebf4a3..4caceb856 100644 --- a/modules/fundamental/src/vulnerability/service/mod.rs +++ b/modules/fundamental/src/vulnerability/service/mod.rs @@ -9,7 +9,6 @@ use trustify_common::{ limiter::LimiterAsModelTrait, multi_model::{FromQueryResultMultiModel, SelectIntoMultiModel}, query::{Columns, Filtering, Query}, - Database, Transactional, }, model::{Paginated, PaginatedResults}, }; @@ -19,24 +18,21 @@ use trustify_entity::{ }; use trustify_module_ingestor::common::Deprecation; -pub struct VulnerabilityService { - db: Database, -} +#[derive(Default)] +pub struct VulnerabilityService {} impl VulnerabilityService { - pub fn new(db: Database) -> Self { - Self { db } + pub fn new() -> Self { + Self {} } - pub async fn fetch_vulnerabilities + Sync + Send>( + pub async fn fetch_vulnerabilities( &self, search: Query, paginated: Paginated, deprecation: Deprecation, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - let inner_query = vulnerability::Entity::find() .left_join(cvss3::Entity) .expr_as_( @@ -102,7 +98,7 @@ impl VulnerabilityService { }), )? .try_limiting_as_multi_model::( - &connection, + connection, paginated.offset, paginated.limit, )?; @@ -124,42 +120,38 @@ impl VulnerabilityService { &vulnerabilities, &averages, deprecation, - &connection, + connection, ) .await?, }) } - pub async fn fetch_vulnerability + Sync + Send>( + pub async fn fetch_vulnerability( &self, identifier: &str, deprecation: Deprecation, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.db.connection(&tx); - if let Some(vulnerability) = vulnerability::Entity::find_by_id(identifier) - .one(&connection) + .one(connection) .await? { Ok(Some( - VulnerabilityDetails::from_entity(&vulnerability, deprecation, &connection).await?, + VulnerabilityDetails::from_entity(&vulnerability, deprecation, connection).await?, )) } else { Ok(None) } } - pub async fn delete_vulnerability + Sync + Send>( + pub async fn delete_vulnerability( &self, id: &str, - tx: TX, + connection: &C, ) -> Result { - let connection = self.db.connection(&tx); - let query = vulnerability::Entity::delete_by_id(id); - let result = query.exec(&connection).await?; + let result = query.exec(connection).await?; Ok(result.rows_affected) } diff --git a/modules/fundamental/src/vulnerability/service/test.rs b/modules/fundamental/src/vulnerability/service/test.rs index 90f35eeb3..b24e191e3 100644 --- a/modules/fundamental/src/vulnerability/service/test.rs +++ b/modules/fundamental/src/vulnerability/service/test.rs @@ -5,7 +5,6 @@ use crate::vulnerability::service::VulnerabilityService; use test_context::test_context; use test_log::test; use trustify_common::db::query::{q, Query}; -use trustify_common::db::Transactional; use trustify_common::model::Paginated; use trustify_common::purl::Purl; use trustify_test_context::TrustifyContext; @@ -13,7 +12,7 @@ use trustify_test_context::TrustifyContext; #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn all_vulnerabilities(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = VulnerabilityService::new(ctx.db.clone()); + let service = VulnerabilityService::new(); ctx.ingest_documents(["osv/RUSTSEC-2021-0079.json", "cve/CVE-2021-32714.json"]) .await?; @@ -23,7 +22,7 @@ async fn all_vulnerabilities(ctx: &TrustifyContext) -> Result<(), anyhow::Error> Query::default(), Paginated::default(), Default::default(), - (), + &ctx.db, ) .await?; @@ -42,13 +41,13 @@ async fn all_vulnerabilities(ctx: &TrustifyContext) -> Result<(), anyhow::Error> #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn statuses(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = VulnerabilityService::new(ctx.db.clone()); + let service = VulnerabilityService::new(); ctx.ingest_documents(["osv/RUSTSEC-2021-0079.json", "cve/CVE-2021-32714.json"]) .await?; let vuln = service - .fetch_vulnerability("CVE-2021-32714", Default::default(), Transactional::None) + .fetch_vulnerability("CVE-2021-32714", Default::default(), &ctx.db) .await?; assert!(vuln.is_some()); @@ -88,7 +87,7 @@ async fn statuses(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn statuses_too(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = VulnerabilityService::new(ctx.db.clone()); + let service = VulnerabilityService::new(); ctx.ingest_documents([ "cve/CVE-2024-29025.json", @@ -99,7 +98,7 @@ async fn statuses_too(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { .await?; let vuln = service - .fetch_vulnerability("CVE-2024-29025", Default::default(), Transactional::None) + .fetch_vulnerability("CVE-2024-29025", Default::default(), &ctx.db) .await?; assert!(vuln.is_some()); @@ -116,7 +115,7 @@ async fn statuses_too(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn commons_compress(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let vuln_service = VulnerabilityService::new(ctx.db.clone()); + let vuln_service = VulnerabilityService::new(); let sbom_service = SbomService::new(ctx.db.clone()); // Ingest a CVE declaring the vulnerability present in versions @@ -133,9 +132,7 @@ async fn commons_compress(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { let sat_id = ingest_results[1].id.clone(); - let sat_sbom = sbom_service - .fetch_sbom_details(sat_id, Transactional::None) - .await?; + let sat_sbom = sbom_service.fetch_sbom_details(sat_id, &ctx.db).await?; assert!(sat_sbom.is_some()); let sat_sbom = sat_sbom.unwrap(); @@ -157,9 +154,7 @@ async fn commons_compress(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { let quarkus_id = ingest_results[3].id.clone(); - let quarkus_sbom = sbom_service - .fetch_sbom_details(quarkus_id, Transactional::None) - .await?; + let quarkus_sbom = sbom_service.fetch_sbom_details(quarkus_id, &ctx.db).await?; assert!(quarkus_sbom.is_some()); @@ -169,7 +164,7 @@ async fn commons_compress(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert!(quarkus_sbom.advisories.is_empty()); let vuln = vuln_service - .fetch_vulnerability("CVE-2024-26308", Default::default(), Transactional::None) + .fetch_vulnerability("CVE-2024-26308", Default::default(), &ctx.db) .await? .unwrap(); @@ -209,9 +204,9 @@ async fn commons_compress(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn product_statuses(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let vuln_service = VulnerabilityService::new(ctx.db.clone()); + let vuln_service = VulnerabilityService::new(); let sbom_service = SbomService::new(ctx.db.clone()); - let purl_service = PurlService::new(ctx.db.clone()); + let purl_service = PurlService::new(); let ingest_results = ctx .ingest_documents([ @@ -222,9 +217,7 @@ async fn product_statuses(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { let quarkus_id = ingest_results[1].id.clone(); - let quarkus_sbom = sbom_service - .fetch_sbom_details(quarkus_id, Transactional::None) - .await?; + let quarkus_sbom = sbom_service.fetch_sbom_details(quarkus_id, &ctx.db).await?; assert!(quarkus_sbom.is_some()); @@ -239,7 +232,7 @@ async fn product_statuses(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert_eq!(quarkus_adv.vulnerability.identifier, "CVE-2023-0044"); let vuln = vuln_service - .fetch_vulnerability("CVE-2023-0044", Default::default(), Transactional::None) + .fetch_vulnerability("CVE-2023-0044", Default::default(), &ctx.db) .await?; assert!(vuln.is_some()); @@ -290,7 +283,7 @@ async fn product_statuses(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { .purl_by_purl( &Purl::try_from("pkg:maven/io.quarkus/quarkus-vertx-http@2.13.8.Final-redhat-00004")?, Default::default(), - Transactional::None, + &ctx.db, ) .await?; @@ -311,12 +304,12 @@ async fn product_statuses(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn delete_vulnerability(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = VulnerabilityService::new(ctx.db.clone()); + let service = VulnerabilityService::new(); ctx.ingest_documents(["cve/CVE-2024-29025.json"]).await?; let vuln = service - .fetch_vulnerability("CVE-2024-29025", Default::default(), ()) + .fetch_vulnerability("CVE-2024-29025", Default::default(), &ctx.db) .await? .expect("Vulnerability not found"); @@ -324,15 +317,15 @@ async fn delete_vulnerability(ctx: &TrustifyContext) -> Result<(), anyhow::Error let id = &vuln.advisories[0].head.head.identifier; - let affected = service.delete_vulnerability(id, ()).await?; + let affected = service.delete_vulnerability(id, &ctx.db).await?; assert_eq!(1, affected); assert!(service - .fetch_vulnerability("CVE-2024-29025", Default::default(), ()) + .fetch_vulnerability("CVE-2024-29025", Default::default(), &ctx.db) .await? .is_none()); - let affected = service.delete_vulnerability(id, ()).await?; + let affected = service.delete_vulnerability(id, &ctx.db).await?; assert_eq!(0, affected); Ok(()) @@ -341,7 +334,7 @@ async fn delete_vulnerability(ctx: &TrustifyContext) -> Result<(), anyhow::Error #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn vulnerability_queries(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let service = VulnerabilityService::new(ctx.db.clone()); + let service = VulnerabilityService::new(); ctx.ingest_documents([ "csaf/CVE-2023-20862.json", @@ -351,7 +344,7 @@ async fn vulnerability_queries(ctx: &TrustifyContext) -> Result<(), anyhow::Erro .await?; let vulns = service - .fetch_vulnerabilities(q(""), Paginated::default(), Default::default(), ()) + .fetch_vulnerabilities(q(""), Paginated::default(), Default::default(), &ctx.db) .await?; assert_eq!(5, vulns.items.len()); let vulns = service @@ -359,7 +352,7 @@ async fn vulnerability_queries(ctx: &TrustifyContext) -> Result<(), anyhow::Erro q("average_score>9"), Paginated::default(), Default::default(), - (), + &ctx.db, ) .await?; assert_eq!(1, vulns.items.len()); @@ -369,7 +362,7 @@ async fn vulnerability_queries(ctx: &TrustifyContext) -> Result<(), anyhow::Erro q("average_severity=critical"), Paginated::default(), Default::default(), - (), + &ctx.db, ) .await?; assert_eq!(1, vulns.items.len()); @@ -379,7 +372,7 @@ async fn vulnerability_queries(ctx: &TrustifyContext) -> Result<(), anyhow::Erro q("average_severity Result<(), anyhow::Erro q("average_severity>=high"), Paginated::default(), Default::default(), - (), + &ctx.db, ) .await?; assert_eq!(4, vulns.items.len()); let vulns = service - .fetch_vulnerabilities(q("20862"), Paginated::default(), Default::default(), ()) + .fetch_vulnerabilities( + q("20862"), + Paginated::default(), + Default::default(), + &ctx.db, + ) .await?; assert_eq!(1, vulns.items.len()); assert_eq!(vulns.items[0].head.identifier, "CVE-2023-20862"); diff --git a/modules/fundamental/src/weakness/model.rs b/modules/fundamental/src/weakness/model.rs index e2810e1a7..107b49652 100644 --- a/modules/fundamental/src/weakness/model.rs +++ b/modules/fundamental/src/weakness/model.rs @@ -1,6 +1,5 @@ use crate::Error; use serde::{Deserialize, Serialize}; -use trustify_common::db::ConnectionOrTransaction; use trustify_entity::weakness; use utoipa::ToSchema; @@ -17,22 +16,16 @@ pub struct WeaknessSummary { } impl WeaknessSummary { - pub async fn from_entities( - entities: &[weakness::Model], - tx: &ConnectionOrTransaction<'_>, - ) -> Result, Error> { + pub async fn from_entities(entities: &[weakness::Model]) -> Result, Error> { let mut summaries = Vec::new(); for each in entities { - summaries.push(Self::from_entity(each, tx).await?) + summaries.push(Self::from_entity(each).await?) } Ok(summaries) } - pub async fn from_entity( - entity: &weakness::Model, - _tx: &ConnectionOrTransaction<'_>, - ) -> Result { + pub async fn from_entity(entity: &weakness::Model) -> Result { Ok(Self { head: WeaknessHead { id: entity.id.clone(), @@ -59,10 +52,7 @@ pub struct WeaknessDetails { } impl WeaknessDetails { - pub async fn from_entity( - entity: &weakness::Model, - _tx: &ConnectionOrTransaction<'_>, - ) -> Result { + pub async fn from_entity(entity: &weakness::Model) -> Result { Ok(Self { head: WeaknessHead { id: entity.id.clone(), diff --git a/modules/fundamental/src/weakness/service/mod.rs b/modules/fundamental/src/weakness/service/mod.rs index 65e354844..edeec68e0 100644 --- a/modules/fundamental/src/weakness/service/mod.rs +++ b/modules/fundamental/src/weakness/service/mod.rs @@ -2,7 +2,7 @@ use crate::{ weakness::model::{WeaknessDetails, WeaknessSummary}, Error, }; -use sea_orm::{EntityTrait, TransactionTrait}; +use sea_orm::EntityTrait; use trustify_common::{ db::{ limiter::LimiterTrait, @@ -27,9 +27,6 @@ impl WeaknessService { query: Query, paginated: Paginated, ) -> Result, Error> { - let tx = self.db.begin().await?; - let tx = (&tx).into(); - let limiter = weakness::Entity::find().filtering(query)?.limiting( &self.db, paginated.offset, @@ -40,17 +37,14 @@ impl WeaknessService { let items = limiter.fetch().await?; Ok(PaginatedResults { - items: WeaknessSummary::from_entities(&items, &tx).await?, + items: WeaknessSummary::from_entities(&items).await?, total, }) } pub async fn get_weakness(&self, id: &str) -> Result, Error> { - let tx = self.db.begin().await?; - let tx = (&tx).into(); - if let Some(found) = weakness::Entity::find_by_id(id).one(&self.db).await? { - Ok(Some(WeaknessDetails::from_entity(&found, &tx).await?)) + Ok(Some(WeaknessDetails::from_entity(&found).await?)) } else { Ok(None) } diff --git a/modules/fundamental/tests/advisory/csaf/delete.rs b/modules/fundamental/tests/advisory/csaf/delete.rs index e79f291fe..9e68091bb 100644 --- a/modules/fundamental/tests/advisory/csaf/delete.rs +++ b/modules/fundamental/tests/advisory/csaf/delete.rs @@ -30,22 +30,22 @@ async fn simple(ctx: &TrustifyContext) -> anyhow::Result<()> { let service = AdvisoryService::new(ctx.db.clone()); service - .delete_advisory(r2.id.try_as_uid().expect("must be a UUID variant"), ()) + .delete_advisory(r2.id.try_as_uid().expect("must be a UUID variant"), &ctx.db) .await?; // now test, find only one, for either ignore or consider - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); let v = vuln - .fetch_vulnerability("CVE-2023-33201", Deprecation::Consider, ()) + .fetch_vulnerability("CVE-2023-33201", Deprecation::Consider, &ctx.db) .await? .expect("must exist"); assert_eq!(v.advisories.len(), 1); - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); let v = vuln - .fetch_vulnerability("CVE-2023-33201", Deprecation::Ignore, ()) + .fetch_vulnerability("CVE-2023-33201", Deprecation::Ignore, &ctx.db) .await? .expect("must exist"); @@ -70,14 +70,14 @@ async fn delete_check_vulns(ctx: &TrustifyContext) -> anyhow::Result<()> { let service = AdvisoryService::new(ctx.db.clone()); service - .delete_advisory(r2.id.try_as_uid().expect("must be a UUID variant"), ()) + .delete_advisory(r2.id.try_as_uid().expect("must be a UUID variant"), &ctx.db) .await?; // check info - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let purls = service - .purls(Default::default(), Default::default(), ()) + .purls(Default::default(), Default::default(), &ctx.db) .await?; // pkg:rpm/redhat/eap7-bouncycastle-util@1.76.0-4.redhat_00001.1.el9eap?arch=noarch @@ -105,7 +105,7 @@ async fn delete_check_vulns(ctx: &TrustifyContext) -> anyhow::Result<()> { // get vuln by purl let mut purl = service - .purl_by_uuid(&purl.head.uuid, Deprecation::Ignore, ()) + .purl_by_uuid(&purl.head.uuid, Deprecation::Ignore, &ctx.db) .await? .expect("must find something"); diff --git a/modules/fundamental/tests/advisory/csaf/reingest.rs b/modules/fundamental/tests/advisory/csaf/reingest.rs index 1cec40d20..838346176 100644 --- a/modules/fundamental/tests/advisory/csaf/reingest.rs +++ b/modules/fundamental/tests/advisory/csaf/reingest.rs @@ -26,9 +26,9 @@ async fn equal(ctx: &TrustifyContext) -> anyhow::Result<()> { // check info - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); let v = vuln - .fetch_vulnerability("CVE-2023-33201", Default::default(), ()) + .fetch_vulnerability("CVE-2023-33201", Default::default(), &ctx.db) .await? .expect("must exist"); @@ -51,9 +51,9 @@ async fn change_ps_num_advisories(ctx: &TrustifyContext) -> anyhow::Result<()> { // check info - non-deprecated - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); let v = vuln - .fetch_vulnerability("CVE-2023-33201", Deprecation::Ignore, ()) + .fetch_vulnerability("CVE-2023-33201", Deprecation::Ignore, &ctx.db) .await? .expect("must exist"); @@ -61,9 +61,9 @@ async fn change_ps_num_advisories(ctx: &TrustifyContext) -> anyhow::Result<()> { // check info - with-deprecated - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); let v = vuln - .fetch_vulnerability("CVE-2023-33201", Deprecation::Consider, ()) + .fetch_vulnerability("CVE-2023-33201", Deprecation::Consider, &ctx.db) .await? .expect("must exist"); @@ -86,9 +86,9 @@ async fn change_ps_list_vulns(ctx: &TrustifyContext) -> anyhow::Result<()> { // check info - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let purls = service - .purls(Default::default(), Default::default(), ()) + .purls(Default::default(), Default::default(), &ctx.db) .await?; // pkg:rpm/redhat/eap7-bouncycastle@1.76.0-4.redhat_00001.1.el9eap?arch=noarch @@ -121,7 +121,7 @@ async fn change_ps_list_vulns(ctx: &TrustifyContext) -> anyhow::Result<()> { // get vuln by purl let purl = service - .purl_by_uuid(&purl.head.uuid, Deprecation::Ignore, ()) + .purl_by_uuid(&purl.head.uuid, Deprecation::Ignore, &ctx.db) .await? .expect("must find something"); @@ -167,9 +167,9 @@ async fn change_ps_list_vulns_all(ctx: &TrustifyContext) -> anyhow::Result<()> { // check info - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let purls = service - .purls(Default::default(), Default::default(), ()) + .purls(Default::default(), Default::default(), &ctx.db) .await?; // pkg:rpm/redhat/eap7-bouncycastle-util@1.76.0-4.redhat_00001.1.el9eap?arch=noarch @@ -202,7 +202,7 @@ async fn change_ps_list_vulns_all(ctx: &TrustifyContext) -> anyhow::Result<()> { // get vuln by purl let mut purl = service - .purl_by_uuid(&purl.head.uuid, Deprecation::Consider, ()) + .purl_by_uuid(&purl.head.uuid, Deprecation::Consider, &ctx.db) .await? .expect("must find something"); diff --git a/modules/fundamental/tests/advisory/cve/delete.rs b/modules/fundamental/tests/advisory/cve/delete.rs index 1e690fad3..8ece2b3dd 100644 --- a/modules/fundamental/tests/advisory/cve/delete.rs +++ b/modules/fundamental/tests/advisory/cve/delete.rs @@ -13,7 +13,7 @@ use trustify_test_context::TrustifyContext; async fn withdrawn(ctx: &TrustifyContext) -> anyhow::Result<()> { let (r1, r2) = twice(ctx, |cve| cve, update_mark_rejected).await?; - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); // must be changed @@ -23,13 +23,13 @@ async fn withdrawn(ctx: &TrustifyContext) -> anyhow::Result<()> { let service = AdvisoryService::new(ctx.db.clone()); service - .delete_advisory(r2.id.try_as_uid().expect("must be a UUID variant"), ()) + .delete_advisory(r2.id.try_as_uid().expect("must be a UUID variant"), &ctx.db) .await?; // check info let v = vuln - .fetch_vulnerability("CVE-2021-32714", Deprecation::Ignore, ()) + .fetch_vulnerability("CVE-2021-32714", Deprecation::Ignore, &ctx.db) .await? .expect("must exist"); @@ -42,7 +42,7 @@ async fn withdrawn(ctx: &TrustifyContext) -> anyhow::Result<()> { // check with deprecated, should be the same result let v = vuln - .fetch_vulnerability("CVE-2021-32714", Deprecation::Consider, ()) + .fetch_vulnerability("CVE-2021-32714", Deprecation::Consider, &ctx.db) .await? .expect("must exist"); diff --git a/modules/fundamental/tests/advisory/cve/reingest.rs b/modules/fundamental/tests/advisory/cve/reingest.rs index 2e9a9ded8..28a80b66d 100644 --- a/modules/fundamental/tests/advisory/cve/reingest.rs +++ b/modules/fundamental/tests/advisory/cve/reingest.rs @@ -17,9 +17,9 @@ async fn equal(ctx: &TrustifyContext) -> anyhow::Result<()> { // check info - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); let v = vuln - .fetch_vulnerability("CVE-2021-32714", Default::default(), ()) + .fetch_vulnerability("CVE-2021-32714", Default::default(), &ctx.db) .await? .expect("must exist"); @@ -42,9 +42,9 @@ async fn withdrawn(ctx: &TrustifyContext) -> anyhow::Result<()> { // check without deprecated - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); let v = vuln - .fetch_vulnerability("CVE-2021-32714", Deprecation::Ignore, ()) + .fetch_vulnerability("CVE-2021-32714", Deprecation::Ignore, &ctx.db) .await? .expect("must exist"); @@ -56,9 +56,9 @@ async fn withdrawn(ctx: &TrustifyContext) -> anyhow::Result<()> { // check with deprecated - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); let v = vuln - .fetch_vulnerability("CVE-2021-32714", Deprecation::Consider, ()) + .fetch_vulnerability("CVE-2021-32714", Deprecation::Consider, &ctx.db) .await? .expect("must exist"); diff --git a/modules/fundamental/tests/advisory/osv/delete.rs b/modules/fundamental/tests/advisory/osv/delete.rs index 7f10b44a0..4cad77768 100644 --- a/modules/fundamental/tests/advisory/osv/delete.rs +++ b/modules/fundamental/tests/advisory/osv/delete.rs @@ -13,7 +13,7 @@ use trustify_test_context::TrustifyContext; async fn fixed(ctx: &TrustifyContext) -> anyhow::Result<()> { let (r1, r2) = twice(ctx, update_unmark_fixed, update_mark_fixed_again).await?; - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); // must be changed @@ -23,13 +23,13 @@ async fn fixed(ctx: &TrustifyContext) -> anyhow::Result<()> { let service = AdvisoryService::new(ctx.db.clone()); service - .delete_advisory(r2.id.try_as_uid().expect("must be a UUID variant"), ()) + .delete_advisory(r2.id.try_as_uid().expect("must be a UUID variant"), &ctx.db) .await?; // check info let v = vuln - .fetch_vulnerability("CVE-2020-5238", Deprecation::Ignore, ()) + .fetch_vulnerability("CVE-2020-5238", Deprecation::Ignore, &ctx.db) .await? .expect("must exist"); @@ -38,7 +38,7 @@ async fn fixed(ctx: &TrustifyContext) -> anyhow::Result<()> { // check with deprecated, should be the same result let v = vuln - .fetch_vulnerability("CVE-2020-5238", Deprecation::Consider, ()) + .fetch_vulnerability("CVE-2020-5238", Deprecation::Consider, &ctx.db) .await? .expect("must exist"); diff --git a/modules/fundamental/tests/advisory/osv/reingest.rs b/modules/fundamental/tests/advisory/osv/reingest.rs index 0610f2f5b..698d2868b 100644 --- a/modules/fundamental/tests/advisory/osv/reingest.rs +++ b/modules/fundamental/tests/advisory/osv/reingest.rs @@ -21,9 +21,9 @@ async fn equal(ctx: &TrustifyContext) -> anyhow::Result<()> { // check info - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); let v = vuln - .fetch_vulnerability("CVE-2020-5238", Default::default(), ()) + .fetch_vulnerability("CVE-2020-5238", Default::default(), &ctx.db) .await? .expect("must exist"); @@ -46,9 +46,9 @@ async fn withdrawn(ctx: &TrustifyContext) -> anyhow::Result<()> { // check without deprecated - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); let v = vuln - .fetch_vulnerability("CVE-2020-5238", Deprecation::Ignore, ()) + .fetch_vulnerability("CVE-2020-5238", Deprecation::Ignore, &ctx.db) .await? .expect("must exist"); @@ -58,9 +58,9 @@ async fn withdrawn(ctx: &TrustifyContext) -> anyhow::Result<()> { // check with deprecated - let vuln = VulnerabilityService::new(ctx.db.clone()); + let vuln = VulnerabilityService::new(); let v = vuln - .fetch_vulnerability("CVE-2020-5238", Deprecation::Consider, ()) + .fetch_vulnerability("CVE-2020-5238", Deprecation::Consider, &ctx.db) .await? .expect("must exist"); @@ -71,9 +71,9 @@ async fn withdrawn(ctx: &TrustifyContext) -> anyhow::Result<()> { // check status - let service = PurlService::new(ctx.db.clone()); + let service = PurlService::new(); let purls = service - .purls(Default::default(), Default::default(), ()) + .purls(Default::default(), Default::default(), &ctx.db) .await?; let purl = purls @@ -98,7 +98,7 @@ async fn withdrawn(ctx: &TrustifyContext) -> anyhow::Result<()> { // get vuln by purl let mut purl = service - .purl_by_uuid(&purl.head.uuid, Deprecation::Consider, ()) + .purl_by_uuid(&purl.head.uuid, Deprecation::Consider, &ctx.db) .await? .expect("must find something"); diff --git a/modules/fundamental/tests/dataset.rs b/modules/fundamental/tests/dataset.rs index 88268b73c..aca780890 100644 --- a/modules/fundamental/tests/dataset.rs +++ b/modules/fundamental/tests/dataset.rs @@ -63,7 +63,7 @@ async fn ingest(ctx: TrustifyContext) -> anyhow::Result<()> { let sbom = &result.files["spdx/quarkus-bom-2.13.8.Final-redhat-00004.json.bz2"]; assert!(matches!(sbom.id, Id::Uuid(_))); - let sbom_summary = service.fetch_sbom_summary(sbom.id.clone(), ()).await?; + let sbom_summary = service.fetch_sbom_summary(sbom.id.clone(), &ctx.db).await?; assert!(sbom_summary.is_some()); let sbom_summary = sbom_summary.unwrap(); assert_eq!(sbom_summary.head.name, "quarkus-bom"); diff --git a/modules/fundamental/tests/sbom/cyclonedx/cpe.rs b/modules/fundamental/tests/sbom/cyclonedx/cpe.rs index 26b9062c6..ffc76519b 100644 --- a/modules/fundamental/tests/sbom/cyclonedx/cpe.rs +++ b/modules/fundamental/tests/sbom/cyclonedx/cpe.rs @@ -15,7 +15,7 @@ async fn simple(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { .describes_packages( result.id.try_as_uid().expect("Must be a UID"), Default::default(), - (), + &ctx.db, ) .await?; @@ -40,7 +40,7 @@ async fn simple_ref(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { .describes_packages( result.id.try_as_uid().expect("Must be a UID"), Default::default(), - (), + &ctx.db, ) .await?; @@ -65,7 +65,7 @@ async fn simple_comp(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { .describes_packages( result.id.try_as_uid().expect("Must be a UID"), Default::default(), - (), + &ctx.db, ) .await?; @@ -82,7 +82,7 @@ async fn simple_comp(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { result.id.try_as_uid().expect("Must be a UID"), Default::default(), Default::default(), - (), + &ctx.db, ) .await?; diff --git a/modules/fundamental/tests/sbom/cyclonedx/mod.rs b/modules/fundamental/tests/sbom/cyclonedx/mod.rs index cfee1559e..8db78e049 100644 --- a/modules/fundamental/tests/sbom/cyclonedx/mod.rs +++ b/modules/fundamental/tests/sbom/cyclonedx/mod.rs @@ -4,7 +4,6 @@ use super::*; use std::str::FromStr; use test_context::test_context; use test_log::test; -use trustify_common::db::Transactional; use trustify_common::model::Paginated; use trustify_common::purl::Purl; use trustify_module_fundamental::purl::model::summary::purl::PurlSummary; @@ -19,7 +18,7 @@ async fn test_parse_cyclonedx(ctx: &TrustifyContext) -> Result<(), anyhow::Error "zookeeper-3.9.2-cyclonedx.json", |WithContext { service, sbom, .. }| async move { let described = service - .describes_packages(sbom.sbom.sbom_id, Default::default(), Transactional::None) + .describes_packages(sbom.sbom.sbom_id, Default::default(), &ctx.db) .await?; assert_eq!(1, described.items.len()); @@ -71,7 +70,7 @@ async fn test_parse_cyclonedx(ctx: &TrustifyContext) -> Result<(), anyhow::Error offset: 0, limit: 1, }, - (), + &ctx.db, ) .await?; @@ -99,7 +98,7 @@ where ctx, sbom, |data| Ok(Bom::parse_from_json(data)?), - |ctx, sbom, tx| Box::pin(async move { ctx.ingest_cyclonedx(sbom.clone(), &tx).await }), + |ctx, sbom, tx| Box::pin(async move { ctx.ingest_cyclonedx(sbom.clone(), tx).await }), |sbom| sbom::cyclonedx::Information(sbom).into(), f, ) diff --git a/modules/fundamental/tests/sbom/graph.rs b/modules/fundamental/tests/sbom/graph.rs index dcb23223e..07ddb2357 100644 --- a/modules/fundamental/tests/sbom/graph.rs +++ b/modules/fundamental/tests/sbom/graph.rs @@ -2,7 +2,6 @@ use std::convert::TryInto; use std::str::FromStr; use test_context::test_context; use test_log::test; -use trustify_common::db::Transactional; use trustify_common::hashing::Digests; use trustify_common::purl::Purl; use trustify_common::sbom::SbomLocator; @@ -23,7 +22,7 @@ async fn ingest_sboms(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { &Digests::digest("8"), "a", (), - Transactional::None, + &ctx.db, ) .await?; let sbom_v1_again = system @@ -32,7 +31,7 @@ async fn ingest_sboms(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { &Digests::digest("8"), "b", (), - Transactional::None, + &ctx.db, ) .await?; let sbom_v2 = system @@ -41,7 +40,7 @@ async fn ingest_sboms(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { &Digests::digest("9"), "c", (), - Transactional::None, + &ctx.db, ) .await?; @@ -51,7 +50,7 @@ async fn ingest_sboms(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { &Digests::digest("10"), "d", (), - Transactional::None, + &ctx.db, ) .await?; @@ -74,7 +73,7 @@ async fn ingest_and_fetch_sboms_describing_purls( &Digests::digest("8"), "a", (), - Transactional::None, + &ctx.db, ) .await?; let sbom_v2 = system @@ -83,7 +82,7 @@ async fn ingest_and_fetch_sboms_describing_purls( &Digests::digest("9"), "b", (), - Transactional::None, + &ctx.db, ) .await?; let sbom_v3 = system @@ -92,35 +91,35 @@ async fn ingest_and_fetch_sboms_describing_purls( &Digests::digest("10"), "c", (), - Transactional::None, + &ctx.db, ) .await?; sbom_v1 .ingest_describes_package( "pkg:maven/io.quarkus/quarkus-core@1.2.3".try_into()?, - Transactional::None, + &ctx.db, ) .await?; sbom_v2 .ingest_describes_package( "pkg:maven/io.quarkus/quarkus-core@1.2.3".try_into()?, - Transactional::None, + &ctx.db, ) .await?; sbom_v3 .ingest_describes_package( "pkg:maven/io.quarkus/quarkus-core@1.9.3".try_into()?, - Transactional::None, + &ctx.db, ) .await?; let found = system .locate_sboms( SbomLocator::Purl("pkg:maven/io.quarkus/quarkus-core@1.2.3".try_into()?), - Transactional::None, + &ctx.db, ) .await?; @@ -144,7 +143,7 @@ async fn ingest_and_locate_sboms_describing_cpes( &Digests::digest("8"), "a", (), - Transactional::None, + &ctx.db, ) .await?; let sbom_v2 = system @@ -153,7 +152,7 @@ async fn ingest_and_locate_sboms_describing_cpes( &Digests::digest("9"), "b", (), - Transactional::None, + &ctx.db, ) .await?; let sbom_v3 = system @@ -162,35 +161,26 @@ async fn ingest_and_locate_sboms_describing_cpes( &Digests::digest("10"), "c", (), - Transactional::None, + &ctx.db, ) .await?; sbom_v1 - .ingest_describes_cpe22( - "cpe:/a:redhat:quarkus:2.13::el8".parse()?, - Transactional::None, - ) + .ingest_describes_cpe22("cpe:/a:redhat:quarkus:2.13::el8".parse()?, &ctx.db) .await?; sbom_v2 - .ingest_describes_cpe22( - "cpe:/a:redhat:quarkus:2.13::el8".parse()?, - Transactional::None, - ) + .ingest_describes_cpe22("cpe:/a:redhat:quarkus:2.13::el8".parse()?, &ctx.db) .await?; sbom_v3 - .ingest_describes_cpe22( - "cpe:/a:redhat:not-quarkus:2.13::el8".parse()?, - Transactional::None, - ) + .ingest_describes_cpe22("cpe:/a:redhat:not-quarkus:2.13::el8".parse()?, &ctx.db) .await?; let found = system .locate_sboms( SbomLocator::Cpe("cpe:/a:redhat:quarkus:2.13::el8".parse()?), - Transactional::None, + &ctx.db, ) .await?; @@ -212,7 +202,7 @@ async fn transitive_dependency_of(ctx: &TrustifyContext) -> Result<(), anyhow::E &Digests::digest("8675309"), "a", (), - Transactional::None, + &ctx.db, ) .await?; @@ -221,7 +211,7 @@ async fn transitive_dependency_of(ctx: &TrustifyContext) -> Result<(), anyhow::E Purl::from_str("pkg:maven/io.quarkus/transitive-b@1.2.3")?, Relationship::DependencyOf, Purl::from_str("pkg:maven/io.quarkus/transitive-a@1.2.3")?, - Transactional::None, + &ctx.db, ) .await?; @@ -230,7 +220,7 @@ async fn transitive_dependency_of(ctx: &TrustifyContext) -> Result<(), anyhow::E Purl::from_str("pkg:maven/io.quarkus/transitive-c@1.2.3")?, Relationship::DependencyOf, Purl::from_str("pkg:maven/io.quarkus/transitive-b@1.2.3")?, - Transactional::None, + &ctx.db, ) .await?; @@ -239,7 +229,7 @@ async fn transitive_dependency_of(ctx: &TrustifyContext) -> Result<(), anyhow::E Purl::from_str("pkg:maven/io.quarkus/transitive-d@1.2.3")?, Relationship::DependencyOf, Purl::from_str("pkg:maven/io.quarkus/transitive-c@1.2.3")?, - Transactional::None, + &ctx.db, ) .await?; @@ -248,7 +238,7 @@ async fn transitive_dependency_of(ctx: &TrustifyContext) -> Result<(), anyhow::E Purl::from_str("pkg:maven/io.quarkus/transitive-e@1.2.3")?, Relationship::DependencyOf, Purl::from_str("pkg:maven/io.quarkus/transitive-c@1.2.3")?, - Transactional::None, + &ctx.db, ) .await?; @@ -257,7 +247,7 @@ async fn transitive_dependency_of(ctx: &TrustifyContext) -> Result<(), anyhow::E Purl::from_str("pkg:maven/io.quarkus/transitive-d@1.2.3")?, Relationship::DependencyOf, Purl::from_str("pkg:maven/io.quarkus/transitive-b@1.2.3")?, - Transactional::None, + &ctx.db, ) .await?; @@ -265,7 +255,7 @@ async fn transitive_dependency_of(ctx: &TrustifyContext) -> Result<(), anyhow::E .related_packages_transitively( &[Relationship::DependencyOf], &"pkg:maven/io.quarkus/transitive-a@1.2.3".try_into()?, - Transactional::None, + &ctx.db, ) .await?; @@ -286,7 +276,7 @@ async fn ingest_package_relates_to_package_dependency_of( &Digests::digest("8675309"), "a", (), - Transactional::None, + &ctx.db, ) .await?; @@ -295,7 +285,7 @@ async fn ingest_package_relates_to_package_dependency_of( Purl::from_str("pkg:maven/io.quarkus/quarkus-postgres@1.2.3")?, Relationship::DependencyOf, Purl::from_str("pkg:maven/io.quarkus/quarkus-core@1.2.3")?, - Transactional::None, + &ctx.db, ) .await?; @@ -305,7 +295,7 @@ async fn ingest_package_relates_to_package_dependency_of( &Digests::digest("8675308"), "b", (), - Transactional::None, + &ctx.db, ) .await?; @@ -314,7 +304,7 @@ async fn ingest_package_relates_to_package_dependency_of( Purl::from_str("pkg:maven/io.quarkus/quarkus-sqlite@1.2.3")?, Relationship::DependencyOf, Purl::from_str("pkg:maven/io.quarkus/quarkus-core@1.2.3")?, - Transactional::None, + &ctx.db, ) .await?; @@ -323,7 +313,7 @@ async fn ingest_package_relates_to_package_dependency_of( sbom1.sbom.sbom_id, Relationship::DependencyOf, "pkg:maven/io.quarkus/quarkus-core@1.2.3", - Transactional::None, + &ctx.db, ) .await?; @@ -346,7 +336,7 @@ async fn ingest_package_relates_to_package_dependency_of( sbom2.sbom.sbom_id, Relationship::DependencyOf, "pkg:maven/io.quarkus/quarkus-core@1.2.3", - Transactional::None, + &ctx.db, ) .await?; @@ -380,13 +370,13 @@ async fn sbom_vulnerabilities(ctx: &TrustifyContext) -> Result<(), anyhow::Error &Digests::digest("8675309"), "a", (), - Transactional::None, + &ctx.db, ) .await?; log::debug!("-------------------- A"); - sbom.ingest_describes_package("pkg:oci/my-app@1.2.3".try_into()?, Transactional::None) + sbom.ingest_describes_package("pkg:oci/my-app@1.2.3".try_into()?, &ctx.db) .await?; log::debug!("-------------------- B"); @@ -394,7 +384,7 @@ async fn sbom_vulnerabilities(ctx: &TrustifyContext) -> Result<(), anyhow::Error Purl::from_str("pkg:maven/io.quarkus/quarkus-core@1.2.3")?, Relationship::DependencyOf, Purl::from_str("pkg:oci/my-app@1.2.3")?, - Transactional::None, + &ctx.db, ) .await?; log::debug!("-------------------- C"); @@ -403,7 +393,7 @@ async fn sbom_vulnerabilities(ctx: &TrustifyContext) -> Result<(), anyhow::Error Purl::from_str("pkg:maven/io.quarkus/quarkus-postgres@1.2.3")?, Relationship::DependencyOf, Purl::from_str("pkg:maven/io.quarkus/quarkus-core@1.2.3")?, - Transactional::None, + &ctx.db, ) .await?; log::debug!("-------------------- D"); @@ -412,7 +402,7 @@ async fn sbom_vulnerabilities(ctx: &TrustifyContext) -> Result<(), anyhow::Error Purl::from_str("pkg:maven/postgres/postgres-driver@1.2.3")?, Relationship::DependencyOf, Purl::from_str("pkg:maven/io.quarkus/quarkus-postgres@1.2.3")?, - Transactional::None, + &ctx.db, ) .await?; @@ -422,12 +412,12 @@ async fn sbom_vulnerabilities(ctx: &TrustifyContext) -> Result<(), anyhow::Error ("source", "http://redhat.com/secdata/RHSA-1"), &Digests::digest("7"), (), - Transactional::None, + &ctx.db, ) .await?; let _advisory_vulnerability = advisory - .link_to_vulnerability("CVE-00000001", None, Transactional::None) + .link_to_vulnerability("CVE-00000001", None, &ctx.db) .await?; Ok(()) diff --git a/modules/fundamental/tests/sbom/mod.rs b/modules/fundamental/tests/sbom/mod.rs index 9069420e5..83fe6cd07 100644 --- a/modules/fundamental/tests/sbom/mod.rs +++ b/modules/fundamental/tests/sbom/mod.rs @@ -6,12 +6,10 @@ mod reingest; mod spdx; use cyclonedx_bom::prelude::Bom; +use sea_orm::{DatabaseTransaction, TransactionTrait}; use std::{future::Future, pin::Pin, time::Instant}; use tracing::{info_span, instrument, Instrument}; -use trustify_common::{ - db::{Database, Transactional}, - hashing::Digests, -}; +use trustify_common::{db::Database, hashing::Digests}; use trustify_module_fundamental::sbom::service::SbomService; use trustify_module_ingestor::{ graph::{ @@ -44,7 +42,7 @@ where for<'a> I: FnOnce( &'a SbomContext, B, - &'a Transactional, + &'a DatabaseTransaction, ) -> Pin> + 'a>>, C: FnOnce(&B) -> SbomInformation, F: FnOnce(WithContext) -> FFut, @@ -68,7 +66,7 @@ where let parse_time = start.elapsed(); - let tx = graph.transaction().await?; + let tx = db.begin().await?; let start = Instant::now(); let ctx = graph diff --git a/modules/fundamental/tests/sbom/reingest.rs b/modules/fundamental/tests/sbom/reingest.rs index fb296b669..bd54c0903 100644 --- a/modules/fundamental/tests/sbom/reingest.rs +++ b/modules/fundamental/tests/sbom/reingest.rs @@ -55,13 +55,13 @@ async fn quarkus(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert_ne!(result1.id, result2.id); let mut sbom1 = sbom - .fetch_sbom_details(result1.id, ()) + .fetch_sbom_details(result1.id, &ctx.db) .await? .expect("v1 must be found"); log::info!("SBOM1: {sbom1:?}"); let mut sbom2 = sbom - .fetch_sbom_details(result2.id, ()) + .fetch_sbom_details(result2.id, &ctx.db) .await? .expect("v2 must be found"); log::info!("SBOM2: {sbom2:?}"); @@ -89,7 +89,7 @@ async fn quarkus(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { Purl::from_str(purl).expect("must parse").qualifier_uuid(), Paginated::default(), Query::default(), - (), + &ctx.db, ) .await?; @@ -129,13 +129,13 @@ async fn nhc(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert_ne!(result1.id, result2.id); let mut sbom1 = sbom - .fetch_sbom_details(result1.id, ()) + .fetch_sbom_details(result1.id, &ctx.db) .await? .expect("v1 must be found"); log::info!("SBOM1: {sbom1:?}"); let mut sbom2 = sbom - .fetch_sbom_details(result2.id, ()) + .fetch_sbom_details(result2.id, &ctx.db) .await? .expect("v2 must be found"); log::info!("SBOM2: {sbom2:?}"); @@ -182,13 +182,13 @@ async fn nhc_same(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert_eq!(result1.id, result2.id); let mut sbom1 = sbom - .fetch_sbom_details(result1.id, ()) + .fetch_sbom_details(result1.id, &ctx.db) .await? .expect("v1 must be found"); log::info!("SBOM1: {sbom1:?}"); let mut sbom2 = sbom - .fetch_sbom_details(result2.id, ()) + .fetch_sbom_details(result2.id, &ctx.db) .await? .expect("v2 must be found"); log::info!("SBOM2: {sbom2:?}"); @@ -249,13 +249,13 @@ async fn nhc_same_content(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert_ne!(result1.id, result2.id); let mut sbom1 = sbom - .fetch_sbom_details(result1.id, ()) + .fetch_sbom_details(result1.id, &ctx.db) .await? .expect("v1 must be found"); log::info!("SBOM1: {sbom1:?}"); let mut sbom2 = sbom - .fetch_sbom_details(result2.id, ()) + .fetch_sbom_details(result2.id, &ctx.db) .await? .expect("v2 must be found"); log::info!("SBOM2: {sbom2:?}"); @@ -306,13 +306,13 @@ async fn syft_rerun(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert_ne!(result1.id, result2.id); let mut sbom1 = sbom - .fetch_sbom_details(result1.id, ()) + .fetch_sbom_details(result1.id, &ctx.db) .await? .expect("v1 must be found"); log::info!("SBOM1: {sbom1:?}"); let mut sbom2 = sbom - .fetch_sbom_details(result2.id, ()) + .fetch_sbom_details(result2.id, &ctx.db) .await? .expect("v2 must be found"); log::info!("SBOM2: {sbom2:?}"); diff --git a/modules/fundamental/tests/sbom/spdx.rs b/modules/fundamental/tests/sbom/spdx.rs index b81aeab69..4140317de 100644 --- a/modules/fundamental/tests/sbom/spdx.rs +++ b/modules/fundamental/tests/sbom/spdx.rs @@ -8,7 +8,7 @@ use std::str::FromStr; use test_context::test_context; use test_log::test; use tracing::instrument; -use trustify_common::{db::Transactional, purl::Purl}; +use trustify_common::purl::Purl; use trustify_entity::relationship::Relationship; use trustify_module_fundamental::{ purl::model::{summary::purl::PurlSummary, PurlHead}, @@ -25,7 +25,7 @@ async fn parse_spdx_quarkus(ctx: &TrustifyContext) -> Result<(), anyhow::Error> "quarkus/v1/quarkus-bom-2.13.8.Final-redhat-00004.json", |WithContext { service, sbom, .. }| async move { let described = service - .describes_packages(sbom.sbom.sbom_id, Default::default(), Transactional::None) + .describes_packages(sbom.sbom.sbom_id, Default::default(), &ctx.db) .await?; log::debug!("{:#?}", described); assert_eq!(1, described.items.len()); @@ -52,7 +52,7 @@ async fn parse_spdx_quarkus(ctx: &TrustifyContext) -> Result<(), anyhow::Error> sbom.sbom.sbom_id, Relationship::ContainedBy, first, - Transactional::None, + &ctx.db, ) .await?; @@ -73,7 +73,7 @@ async fn test_parse_spdx(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { "ubi9-9.2-755.1697625012.json", |WithContext { service, sbom, .. }| async move { let described = service - .describes_packages(sbom.sbom.sbom_id, Default::default(), Transactional::None) + .describes_packages(sbom.sbom.sbom_id, Default::default(), &ctx.db) .await?; assert_eq!(1, described.total); @@ -87,7 +87,7 @@ async fn test_parse_spdx(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { Which::Right, first, Some(Relationship::ContainedBy), - (), + &ctx.db, ) .await? .items; @@ -118,7 +118,7 @@ async fn ingest_spdx_broken_refs(ctx: &TrustifyContext) -> Result<(), anyhow::Er ); let result = sbom - .fetch_sboms(Default::default(), Default::default(), (), ()) + .fetch_sboms(Default::default(), Default::default(), (), &ctx.db) .await?; // there must be no traces, everything must be rolled back @@ -143,7 +143,7 @@ where }, |ctx, sbom, tx| { Box::pin(async move { - ctx.ingest_spdx(sbom.clone(), &Discard, &tx).await?; + ctx.ingest_spdx(sbom.clone(), &Discard, tx).await?; Ok(()) }) }, diff --git a/modules/fundamental/tests/sbom/spdx/corner_cases.rs b/modules/fundamental/tests/sbom/spdx/corner_cases.rs index c2c577894..099087e2d 100644 --- a/modules/fundamental/tests/sbom/spdx/corner_cases.rs +++ b/modules/fundamental/tests/sbom/spdx/corner_cases.rs @@ -1,6 +1,7 @@ #![allow(clippy::expect_used)] use anyhow::bail; +use sea_orm::ConnectionTrait; use strum::VariantArray; use test_context::test_context; use test_log::test; @@ -12,13 +13,14 @@ use trustify_module_ingestor::graph::purl::qualified_package::QualifiedPackageCo use trustify_module_ingestor::graph::sbom::SbomContext; use trustify_test_context::TrustifyContext; -async fn related_packages_transitively( - sbom: &SbomContext, -) -> Result, anyhow::Error> { +async fn related_packages_transitively<'a, C: ConnectionTrait>( + sbom: &'a SbomContext, + connection: &C, +) -> Result>, anyhow::Error> { let purl = Purl::try_from("pkg:cargo/A@0.0.0").expect("must parse"); let result = sbom - .related_packages_transitively(Relationship::VARIANTS, &purl, ()) + .related_packages_transitively(Relationship::VARIANTS, &purl, connection) .await?; Ok(result) @@ -37,28 +39,28 @@ async fn infinite_loop(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { let sbom = ctx .graph - .get_sbom_by_id(id, ()) + .get_sbom_by_id(id, &ctx.db) .await? .expect("must be found"); let packages = service - .fetch_sbom_packages(id, Default::default(), Default::default(), ()) + .fetch_sbom_packages(id, Default::default(), Default::default(), &ctx.db) .await?; assert_eq!(packages.total, 3); - let packages = related_packages_transitively(&sbom).await?; + let packages = related_packages_transitively(&sbom, &ctx.db).await?; assert_eq!(packages.len(), 3); let packages = service - .describes_packages(id, Default::default(), ()) + .describes_packages(id, Default::default(), &ctx.db) .await?; assert_eq!(packages.total, 1); let packages = service - .related_packages(id, None, SbomPackageReference::All, ()) + .related_packages(id, None, SbomPackageReference::All, &ctx.db) .await?; log::info!("Packages: {packages:#?}"); @@ -78,23 +80,23 @@ async fn double_ref(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { }; let sbom = ctx .graph - .get_sbom_by_id(id, ()) + .get_sbom_by_id(id, &ctx.db) .await? .expect("must be found"); let service = SbomService::new(ctx.db.clone()); let packages = service - .fetch_sbom_packages(id, Default::default(), Default::default(), ()) + .fetch_sbom_packages(id, Default::default(), Default::default(), &ctx.db) .await?; assert_eq!(packages.total, 3); - let packages = related_packages_transitively(&sbom).await?; + let packages = related_packages_transitively(&sbom, &ctx.db).await?; assert_eq!(packages.len(), 3); let packages = service - .related_packages(id, None, SbomPackageReference::All, ()) + .related_packages(id, None, SbomPackageReference::All, &ctx.db) .await?; log::info!("Packages: {packages:#?}"); @@ -114,23 +116,23 @@ async fn self_ref(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { }; let sbom = ctx .graph - .get_sbom_by_id(id, ()) + .get_sbom_by_id(id, &ctx.db) .await? .expect("must be found"); let service = SbomService::new(ctx.db.clone()); let packages = service - .fetch_sbom_packages(id, Default::default(), Default::default(), ()) + .fetch_sbom_packages(id, Default::default(), Default::default(), &ctx.db) .await?; assert_eq!(packages.total, 0); - let packages = related_packages_transitively(&sbom).await?; + let packages = related_packages_transitively(&sbom, &ctx.db).await?; assert_eq!(packages.len(), 0); let packages = service - .related_packages(id, None, SbomPackageReference::All, ()) + .related_packages(id, None, SbomPackageReference::All, &ctx.db) .await?; log::info!("Packages: {packages:#?}"); @@ -150,23 +152,23 @@ async fn self_ref_package(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { }; let sbom = ctx .graph - .get_sbom_by_id(id, ()) + .get_sbom_by_id(id, &ctx.db) .await? .expect("must be found"); let service = SbomService::new(ctx.db.clone()); let packages = service - .fetch_sbom_packages(id, Default::default(), Default::default(), ()) + .fetch_sbom_packages(id, Default::default(), Default::default(), &ctx.db) .await?; assert_eq!(packages.total, 1); - let packages = related_packages_transitively(&sbom).await?; + let packages = related_packages_transitively(&sbom, &ctx.db).await?; assert_eq!(packages.len(), 1); let packages = service - .related_packages(id, None, SbomPackageReference::All, ()) + .related_packages(id, None, SbomPackageReference::All, &ctx.db) .await?; log::info!("Packages: {packages:#?}"); @@ -174,7 +176,12 @@ async fn self_ref_package(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { assert_eq!(packages.len(), 1); let packages = service - .related_packages(id, None, SbomPackageReference::Package("SPDXRef-A"), ()) + .related_packages( + id, + None, + SbomPackageReference::Package("SPDXRef-A"), + &ctx.db, + ) .await?; log::info!("Packages: {packages:#?}"); diff --git a/modules/fundamental/tests/sbom/spdx/perf.rs b/modules/fundamental/tests/sbom/spdx/perf.rs index 6fb0572ad..52bcf930d 100644 --- a/modules/fundamental/tests/sbom/spdx/perf.rs +++ b/modules/fundamental/tests/sbom/spdx/perf.rs @@ -2,7 +2,7 @@ use super::*; use test_context::test_context; use test_log::test; use tracing::instrument; -use trustify_common::{db::Transactional, model::Paginated}; +use trustify_common::model::Paginated; use trustify_module_fundamental::sbom::model::SbomPackage; use trustify_test_context::TrustifyContext; @@ -15,7 +15,7 @@ async fn ingest_spdx_medium(ctx: &TrustifyContext) -> Result<(), anyhow::Error> "openshift-container-storage-4.8.z.json.xz", |WithContext { service, sbom, .. }| async move { let described = service - .describes_packages(sbom.sbom.sbom_id, Default::default(), ()) + .describes_packages(sbom.sbom.sbom_id, Default::default(), &ctx.db) .await?; log::debug!("{:#?}", described); @@ -39,7 +39,7 @@ async fn ingest_spdx_medium(ctx: &TrustifyContext) -> Result<(), anyhow::Error> offset: 0, limit: 1, }, - (), + &ctx.db, ) .await?; assert_eq!(1, packages.items.len()); @@ -61,7 +61,7 @@ async fn ingest_spdx_large(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { "openshift-4.13.json.xz", |WithContext { service, sbom, .. }| async move { let described = service - .describes_packages(sbom.sbom.sbom_id, Default::default(), Transactional::None) + .describes_packages(sbom.sbom.sbom_id, Default::default(), &ctx.db) .await?; log::debug!("{:#?}", described); assert_eq!(1, described.items.len()); @@ -85,7 +85,7 @@ async fn ingest_spdx_medium_cpes(ctx: &TrustifyContext) -> Result<(), anyhow::Er "rhel-br-9.2.0.json.xz", |WithContext { service, sbom, .. }| async move { let described = service - .describes_packages(sbom.sbom.sbom_id, Default::default(), ()) + .describes_packages(sbom.sbom.sbom_id, Default::default(), &ctx.db) .await?; log::debug!("{:#?}", described); @@ -109,7 +109,7 @@ async fn ingest_spdx_medium_cpes(ctx: &TrustifyContext) -> Result<(), anyhow::Er offset: 0, limit: 1, }, - (), + &ctx.db, ) .await?; assert_eq!(1, packages.items.len()); diff --git a/modules/graphql/src/advisory.rs b/modules/graphql/src/advisory.rs index 6b30b3b7f..ff8941301 100644 --- a/modules/graphql/src/advisory.rs +++ b/modules/graphql/src/advisory.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use async_graphql::{Context, FieldError, FieldResult, Object}; -use trustify_common::db::Transactional; +use trustify_common::db::Database; use trustify_entity::advisory::Model as Advisory; use trustify_module_ingestor::graph::Graph; use uuid::Uuid; @@ -12,8 +12,9 @@ pub struct AdvisoryQuery; #[Object] impl AdvisoryQuery { async fn get_advisory_by_id<'a>(&self, ctx: &Context<'a>, id: Uuid) -> FieldResult { + let db = ctx.data::>()?; let graph = ctx.data::>()?; - let advisory = graph.get_advisory_by_id(id, Transactional::None).await; + let advisory = graph.get_advisory_by_id(id, db.as_ref()).await; match advisory { Ok(Some(advisory)) => Ok(Advisory { @@ -36,10 +37,11 @@ impl AdvisoryQuery { } async fn get_advisories<'a>(&self, ctx: &Context<'a>) -> FieldResult> { + let db = ctx.data::>()?; let graph = ctx.data::>()?; let advisories = graph - .get_advisories(Default::default(), Transactional::None) + .get_advisories(Default::default(), db.as_ref()) .await .unwrap_or_default(); diff --git a/modules/graphql/src/organization.rs b/modules/graphql/src/organization.rs index cb8655186..e528dd375 100644 --- a/modules/graphql/src/organization.rs +++ b/modules/graphql/src/organization.rs @@ -1,7 +1,6 @@ -use std::sync::Arc; - use async_graphql::{Context, FieldError, FieldResult, Object}; -use trustify_common::db::Transactional; +use std::sync::Arc; +use trustify_common::db::Database; use trustify_entity::organization::Model as Organization; use trustify_module_ingestor::graph::Graph; @@ -15,10 +14,9 @@ impl OrganizationQuery { ctx: &Context<'a>, name: String, ) -> FieldResult { + let db = ctx.data::>()?; let graph = ctx.data::>()?; - let organization = graph - .get_organization_by_name(name, Transactional::None) - .await; + let organization = graph.get_organization_by_name(name, db.as_ref()).await; match organization { Ok(Some(organization)) => Ok(Organization { diff --git a/modules/graphql/src/sbom.rs b/modules/graphql/src/sbom.rs index 344027254..74badbcaf 100644 --- a/modules/graphql/src/sbom.rs +++ b/modules/graphql/src/sbom.rs @@ -1,9 +1,7 @@ -use std::sync::Arc; - use async_graphql::{Context, FieldError, FieldResult, Object}; -use trustify_common::db::Transactional; -use trustify_entity::labels::Labels; -use trustify_entity::sbom::Model as Sbom; +use std::sync::Arc; +use trustify_common::db::Database; +use trustify_entity::{labels::Labels, sbom::Model as Sbom}; use trustify_module_ingestor::graph::Graph; use uuid::Uuid; @@ -13,8 +11,9 @@ pub struct SbomQuery; #[Object] impl SbomQuery { async fn get_sbom_by_id<'a>(&self, ctx: &Context<'a>, id: Uuid) -> FieldResult { + let db = ctx.data::>()?; let graph = ctx.data::>()?; - let sbom = graph.locate_sbom_by_id(id, Transactional::None).await; + let sbom = graph.locate_sbom_by_id(id, db.as_ref()).await; match sbom { Ok(Some(sbom_context)) => Ok(Sbom { @@ -37,6 +36,7 @@ impl SbomQuery { ctx: &Context<'a>, labels: String, ) -> FieldResult> { + let db = ctx.data::>()?; let graph = ctx.data::>()?; let mut local_labels = Labels::new(); @@ -50,7 +50,7 @@ impl SbomQuery { } let sboms = graph - .locate_sboms_by_labels(local_labels, Transactional::None) + .locate_sboms_by_labels(local_labels, db.as_ref()) .await .unwrap_or_default(); diff --git a/modules/graphql/src/sbomstatus.rs b/modules/graphql/src/sbomstatus.rs index b9488bd69..23707e77a 100644 --- a/modules/graphql/src/sbomstatus.rs +++ b/modules/graphql/src/sbomstatus.rs @@ -1,7 +1,7 @@ use async_graphql::{Context, FieldResult, Object, SimpleObject}; use std::{ops::Deref, sync::Arc}; use trustify_common::{ - db::{self, Transactional}, + db::{self}, id::Id, }; use trustify_module_fundamental::{ @@ -30,7 +30,7 @@ impl SbomStatusQuery { let sbom_service = SbomService::new(db.deref().clone()); let sbom_details: Option = sbom_service - .fetch_sbom_details(Id::Uuid(id), Transactional::None) + .fetch_sbom_details(Id::Uuid(id), db.as_ref()) .await .unwrap_or_default(); diff --git a/modules/graphql/src/vulnerability.rs b/modules/graphql/src/vulnerability.rs index b21dd2463..488c1de7a 100644 --- a/modules/graphql/src/vulnerability.rs +++ b/modules/graphql/src/vulnerability.rs @@ -1,7 +1,6 @@ -use std::sync::Arc; - use async_graphql::{Context, FieldError, FieldResult, Object}; -use trustify_common::db::Transactional; +use std::sync::Arc; +use trustify_common::db::Database; use trustify_entity::vulnerability::Model as Vulnerability; use trustify_module_ingestor::graph::Graph; @@ -15,10 +14,9 @@ impl VulnerabilityQuery { ctx: &Context<'a>, identifier: String, ) -> FieldResult { + let db = ctx.data::>()?; let graph = ctx.data::>()?; - let vulnerability = graph - .get_vulnerability(&identifier, Transactional::None) - .await; + let vulnerability = graph.get_vulnerability(&identifier, db.as_ref()).await; match vulnerability { Ok(Some(vulnerability)) => Ok(Vulnerability { @@ -36,9 +34,10 @@ impl VulnerabilityQuery { } async fn get_vulnerabilities<'a>(&self, ctx: &Context<'a>) -> FieldResult> { + let db = ctx.data::>()?; let graph = ctx.data::>()?; let vulnerabilities = graph - .get_vulnerabilities(Transactional::None) + .get_vulnerabilities(db.as_ref()) .await .unwrap_or_default(); diff --git a/modules/ingestor/src/graph/advisory/advisory_vulnerability.rs b/modules/ingestor/src/graph/advisory/advisory_vulnerability.rs index 8c9c0e02e..9ce1e66fe 100644 --- a/modules/ingestor/src/graph/advisory/advisory_vulnerability.rs +++ b/modules/ingestor/src/graph/advisory/advisory_vulnerability.rs @@ -1,8 +1,11 @@ use crate::graph::{advisory::AdvisoryContext, error::Error, vulnerability::VulnerabilityContext}; -use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoIdentity, NotSet, QueryFilter, Set}; +use sea_orm::{ + ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, IntoIdentity, NotSet, QueryFilter, + Set, +}; use sea_query::{Condition, Expr, IntoCondition}; use tracing::instrument; -use trustify_common::{cpe::Cpe, db::Transactional, purl::Purl}; +use trustify_common::{cpe::Cpe, purl::Purl}; use trustify_cvss::cvss3::Cvss3Base; use trustify_entity::{ self as entity, cvss3::Severity, purl_status, status, version_range, @@ -130,25 +133,25 @@ impl<'g> From<(&AdvisoryContext<'g>, entity::advisory_vulnerability::Model)> } impl<'g> AdvisoryVulnerabilityContext<'g> { - pub async fn vulnerability>( + pub async fn vulnerability( &self, - tx: TX, + connection: &C, ) -> Result, Error> { Ok( vulnerability::Entity::find_by_id(&self.advisory_vulnerability.vulnerability_id) - .one(&self.advisory.graph.connection(&tx)) + .one(connection) .await? .map(|vuln| VulnerabilityContext::new(self.advisory.graph, vuln)), ) } /* - pub async fn get_fixed_package_version>( + pub async fn get_fixed_package_version( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result, Error> { - if let Some(package_version) = self.advisory.graph.get_package_version(purl, &tx).await? { + if let Some(package_version) = self.advisory.graph.get_package_version(purl, connection).await? { Ok(entity::fixed_package_version::Entity::find() .filter( entity::fixed_package_version::Column::AdvisoryId.eq(self.advisory.advisory.id), @@ -157,7 +160,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { entity::fixed_package_version::Column::PackageVersionId .eq(package_version.package_version.id), ) - .one(&self.advisory.graph.connection(&tx)) + .one(&self.advisory.graph.connection(connection)) .await? .map(|affected| FixedPackageVersionContext::new(self, affected))) } else { @@ -165,12 +168,12 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { } } - pub async fn get_not_affected_package_version>( + pub async fn get_not_affected_package_version( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result, Error> { - if let Some(package_version) = self.advisory.graph.get_package_version(purl, &tx).await? { + if let Some(package_version) = self.advisory.graph.get_package_version(purl, connection).await? { Ok(entity::not_affected_package_version::Entity::find() .filter( entity::not_affected_package_version::Column::AdvisoryId @@ -180,7 +183,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { entity::not_affected_package_version::Column::PackageVersionId .eq(package_version.package_version.id), ) - .one(&self.advisory.graph.connection(&tx)) + .one(&self.advisory.graph.connection(connection)) .await? .map(|not_affected_package_version| { NotAffectedPackageVersionContext::new(self, not_affected_package_version) @@ -190,17 +193,17 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { } } - pub async fn get_affected_package_range>( + pub async fn get_affected_package_range( &self, purl: &Purl, start: &str, end: &str, - tx: TX, + connection: &C, ) -> Result, Error> { if let Some(package_version_range) = self .advisory .graph - .get_package_version_range(purl, start, end, &tx) + .get_package_version_range(purl, start, end, connection) .await? { Ok(entity::affected_package_version_range::Entity::find() @@ -212,7 +215,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { entity::affected_package_version_range::Column::PackageVersionRangeId .eq(package_version_range.package_version_range.id), ) - .one(&self.advisory.graph.connection(&tx)) + .one(&self.advisory.graph.connection(connection)) .await? .map(|affected| AffectedPackageVersionRangeContext::new(self, affected))) } else { @@ -220,19 +223,19 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { } } - pub async fn ingest_not_affected_package_version>( + pub async fn ingest_not_affected_package_version( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result { - if let Some(found) = self.get_not_affected_package_version(purl, &tx).await? { + if let Some(found) = self.get_not_affected_package_version(purl, connection).await? { return Ok(found); } let package_version = self .advisory .graph - .ingest_package_version(purl, &tx) + .ingest_package_version(purl, connection) .await?; let entity = entity::not_affected_package_version::ActiveModel { @@ -244,23 +247,23 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { Ok(NotAffectedPackageVersionContext::new( self, - entity.insert(&self.advisory.graph.connection(&tx)).await?, + entity.insert(&self.advisory.graph.connection(connection)).await?, )) } - pub async fn ingest_fixed_package_version>( + pub async fn ingest_fixed_package_version( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result { - if let Some(found) = self.get_fixed_package_version(purl, &tx).await? { + if let Some(found) = self.get_fixed_package_version(purl, connection).await? { return Ok(found); } let package_version = self .advisory .graph - .ingest_package_version(purl, &tx) + .ingest_package_version(purl, connection) .await?; let entity = entity::fixed_package_version::ActiveModel { @@ -272,30 +275,28 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { Ok(FixedPackageVersionContext::new( self, - entity.insert(&self.advisory.graph.connection(&tx)).await?, + entity.insert(&self.advisory.graph.connection(connection)).await?, )) } */ - #[instrument(skip(self, tx), ret)] - pub async fn ingest_package_status>( + #[instrument(skip(self, connection), ret)] + pub async fn ingest_package_status( &self, cpe_context: Option, purl: &Purl, status: &str, info: VersionInfo, - tx: TX, + connection: &C, ) -> Result<(), Error> { - let connection = self.advisory.graph.connection(&tx); - let status = status::Entity::find() .filter(status::Column::Slug.eq(status)) - .one(&connection) + .one(connection) .await? .ok_or(Error::InvalidStatus(status.to_string()))?; - let package = self.advisory.graph.ingest_package(purl, &tx).await?; + let package = self.advisory.graph.ingest_package(purl, connection).await?; let package_status = purl_status::Entity::find() .filter(purl_status::Column::BasePurlId.eq(package.base_purl.id)) @@ -303,7 +304,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { .filter(purl_status::Column::StatusId.eq(status.id)) .left_join(version_range::Entity) .filter(info.clone().into_condition()) - .one(&connection) + .one(connection) .await?; if package_status.is_some() { @@ -312,7 +313,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { let version_range = info.into_active_model(); - let version_range = version_range.insert(&connection).await?; + let version_range = version_range.insert(connection).await?; let package_status = purl_status::ActiveModel { id: Default::default(), @@ -324,22 +325,22 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { context_cpe_id: NotSet, }; - package_status.insert(&connection).await?; + package_status.insert(connection).await?; Ok(()) } /* #[instrument(skip(self, tx), err)] - pub async fn ingest_affected_package_range>( + pub async fn ingest_affected_package_range( &self, purl: &Purl, start: &str, end: &str, - tx: TX, + connection: &C, ) -> Result { if let Some(found) = self - .get_affected_package_range(purl, start, end, &tx) + .get_affected_package_range(purl, start, end, connection) .await? { return Ok(found); @@ -348,7 +349,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { let package_version_range = self .advisory .graph - .ingest_package_version_range(purl, start, end, &tx) + .ingest_package_version_range(purl, start, end, connection) .await?; let entity = entity::affected_package_version_range::ActiveModel { @@ -360,15 +361,15 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { Ok(AffectedPackageVersionRangeContext::new( self, - entity.insert(&self.advisory.graph.connection(&tx)).await?, + entity.insert(&self.advisory.graph.connection(connection)).await?, )) } */ - pub async fn cvss3_scores>( + pub async fn cvss3_scores( &self, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(entity::cvss3::Entity::find() .filter(entity::cvss3::Column::AdvisoryId.eq(self.advisory_vulnerability.advisory_id)) @@ -376,17 +377,17 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { entity::cvss3::Column::VulnerabilityId .eq(self.advisory_vulnerability.vulnerability_id.clone()), ) - .all(&self.advisory.graph.connection(&tx)) + .all(connection) .await? .drain(..) .map(|e| e.into()) .collect()) } - pub async fn get_cvss3_score>( + pub async fn get_cvss3_score( &self, minor_version: u8, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(entity::cvss3::Entity::find() .filter(entity::cvss3::Column::AdvisoryId.eq(self.advisory_vulnerability.advisory_id)) @@ -395,18 +396,21 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { .eq(self.advisory_vulnerability.vulnerability_id.clone()), ) .filter(entity::cvss3::Column::MinorVersion.eq(minor_version as i32)) - .one(&self.advisory.graph.connection(&tx)) + .one(connection) .await? .map(|cvss| cvss.into())) } - #[instrument(skip(self, tx), err)] - pub async fn ingest_cvss3_score>( + #[instrument(skip(self, connection), err)] + pub async fn ingest_cvss3_score( &self, cvss3: Cvss3Base, - tx: TX, + connection: &C, ) -> Result { - if let Some(found) = self.get_cvss3_score(cvss3.minor_version, &tx).await? { + if let Some(found) = self + .get_cvss3_score(cvss3.minor_version, connection) + .await? + { return Ok(found); } @@ -426,10 +430,7 @@ impl<'g> AdvisoryVulnerabilityContext<'g> { severity: Set(Severity::from(cvss3.score().roundup().severity())), }; - Ok(model - .insert(&self.advisory.graph.connection(&tx)) - .await? - .into()) + Ok(model.insert(connection).await?.into()) } } @@ -441,7 +442,6 @@ mod test { use crate::graph::Graph; use test_context::test_context; use test_log::test; - use trustify_common::db::Transactional; use trustify_common::hashing::Digests; use trustify_test_context::TrustifyContext; @@ -450,8 +450,7 @@ mod test { async fn advisory_affected_vulnerability_assertions( ctx: TrustifyContext, ) -> Result<(), anyhow::Error> { - let db = ctx.db; - let system = Graph::new(db); + let system = Graph::new(ctx.db.clone()); let advisory = system .ingest_advisory( @@ -459,12 +458,12 @@ mod test { ("source", "http://db.com/rhsa-ghsa-2"), &Digests::digest("RHSA-GHSA-1"), (), - Transactional::None, + &ctx.db, ) .await?; let advisory_vulnerability = advisory - .link_to_vulnerability("CVE-42", None, Transactional::None) + .link_to_vulnerability("CVE-42", None, &ctx.db) .await?; advisory_vulnerability @@ -479,7 +478,7 @@ mod test { Version::Exclusive("1.2.0".to_string()), ), }, - Transactional::None, + &ctx.db, ) .await?; @@ -492,7 +491,7 @@ mod test { scheme: VersionScheme::Semver, spec: VersionSpec::Exact("1.1.9".to_string()), }, - Transactional::None, + &ctx.db, ) .await?; @@ -513,8 +512,7 @@ mod test { async fn advisory_not_affected_vulnerability_assertions( ctx: TrustifyContext, ) -> Result<(), anyhow::Error> { - let db = ctx.db; - let system = Graph::new(db); + let system = Graph::new(ctx.db.clone()); let advisory = system .ingest_advisory( @@ -522,12 +520,12 @@ mod test { ("source", "http://db.com/rhsa-ghsa-2"), &Digests::digest("RHSA-GHSA-1"), (), - Transactional::None, + &ctx.db, ) .await?; let advisory_vulnerability = advisory - .link_to_vulnerability("INTERAL-77", None, Transactional::None) + .link_to_vulnerability("INTERAL-77", None, &ctx.db) .await?; advisory_vulnerability @@ -542,7 +540,7 @@ mod test { Version::Exclusive("1.2.0".to_string()), ), }, - Transactional::None, + &ctx.db, ) .await?; @@ -555,7 +553,7 @@ mod test { scheme: VersionScheme::Semver, spec: VersionSpec::Exact("1.1.9".to_string()), }, - Transactional::None, + &ctx.db, ) .await?; diff --git a/modules/ingestor/src/graph/advisory/mod.rs b/modules/ingestor/src/graph/advisory/mod.rs index 287f30a6f..69cff897a 100644 --- a/modules/ingestor/src/graph/advisory/mod.rs +++ b/modules/ingestor/src/graph/advisory/mod.rs @@ -6,18 +6,15 @@ use crate::{ }; use hex::ToHex; use sea_orm::{ - ActiveModelTrait, ActiveValue::Set, ColumnTrait, EntityTrait, IntoActiveModel, ModelTrait, - QueryFilter, QuerySelect, RelationTrait, + ActiveModelTrait, ActiveValue::Set, ColumnTrait, ConnectionTrait, EntityTrait, IntoActiveModel, + ModelTrait, QueryFilter, QuerySelect, RelationTrait, }; use sea_query::{Condition, JoinType, OnConflict}; use semver::Version; use std::fmt::{Debug, Formatter}; use time::OffsetDateTime; use tracing::instrument; -use trustify_common::{ - db::{Transactional, UpdateDeprecatedAdvisory}, - hashing::Digests, -}; +use trustify_common::{db::UpdateDeprecatedAdvisory, hashing::Digests}; use trustify_entity::{self as entity, advisory, labels::Labels, source_document}; use uuid::Uuid; @@ -62,22 +59,22 @@ impl From<()> for AdvisoryInformation { } impl Graph { - pub async fn get_advisory_by_id>( + pub async fn get_advisory_by_id( &self, id: Uuid, - tx: TX, + connection: &C, ) -> Result, Error> { - Ok(entity::advisory::Entity::find_by_id(id) - .one(&self.connection(&tx)) + Ok(advisory::Entity::find_by_id(id) + .one(connection) .await? .map(|advisory| AdvisoryContext::new(self, advisory))) } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn get_advisory_by_digest>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn get_advisory_by_digest( &self, digest: &str, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(advisory::Entity::find() .join(JoinType::Join, advisory::Relation::SourceDocument.def()) @@ -87,33 +84,33 @@ impl Graph { .add(source_document::Column::Sha384.eq(digest.to_string())) .add(source_document::Column::Sha512.eq(digest.to_string())), ) - .one(&self.connection(&tx)) + .one(connection) .await? .map(|advisory| AdvisoryContext::new(self, advisory))) } - pub async fn get_advisories>( + pub async fn get_advisories( &self, deprecation: Deprecation, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(advisory::Entity::find() .with_deprecation(deprecation) - .all(&self.db.connection(&tx)) + .all(connection) .await? .into_iter() .map(|advisory| AdvisoryContext::new(self, advisory)) .collect()) } - #[instrument(skip(self, labels, information, tx), err(level=tracing::Level::INFO))] - pub async fn ingest_advisory>( + #[instrument(skip(self, labels, information, connection), err(level=tracing::Level::INFO))] + pub async fn ingest_advisory( &self, identifier: impl Into + Debug, labels: impl Into, digests: &Digests, information: impl Into, - tx: TX, + connection: &C, ) -> Result { let identifier = identifier.into(); let labels = labels.into(); @@ -128,13 +125,13 @@ impl Graph { version, } = information.into(); - if let Some(found) = self.get_advisory_by_digest(&sha256, &tx).await? { + if let Some(found) = self.get_advisory_by_digest(&sha256, connection).await? { // we already have the exact same document. return Ok(found); } let organization = if let Some(issuer) = issuer { - Some(self.ingest_organization(issuer, (), &tx).await?) + Some(self.ingest_organization(issuer, (), connection).await?) } else { None }; @@ -147,7 +144,7 @@ impl Graph { size: Set(digests.size as i64), }; - let doc = doc_model.insert(&self.connection(&tx)).await?; + let doc = doc_model.insert(connection).await?; // insert @@ -167,13 +164,11 @@ impl Graph { source_document_id: Set(Some(doc.id)), }; - let db = self.connection(&tx); - - let result = model.insert(&db).await?; + let result = model.insert(connection).await?; // update deprecation marker - UpdateDeprecatedAdvisory::execute(&db, &result.identifier).await?; + UpdateDeprecatedAdvisory::execute(connection, &result.identifier).await?; // done @@ -204,14 +199,14 @@ impl<'g> AdvisoryContext<'g> { Self { graph, advisory } } - pub async fn set_published_at>( + pub async fn set_published_at( &self, published_at: OffsetDateTime, - tx: TX, + connection: &C, ) -> Result<(), Error> { let mut entity = self.advisory.clone().into_active_model(); entity.published = Set(Some(published_at)); - entity.save(&self.graph.connection(&tx)).await?; + entity.save(connection).await?; Ok(()) } @@ -219,14 +214,14 @@ impl<'g> AdvisoryContext<'g> { self.advisory.published } - pub async fn set_modified_at>( + pub async fn set_modified_at( &self, modified_at: OffsetDateTime, - tx: TX, + connection: &C, ) -> Result<(), Error> { let mut entity = self.advisory.clone().into_active_model(); entity.modified = Set(Some(modified_at)); - entity.save(&self.graph.connection(&tx)).await?; + entity.save(connection).await?; Ok(()) } @@ -234,14 +229,14 @@ impl<'g> AdvisoryContext<'g> { self.advisory.modified } - pub async fn set_withdrawn_at>( + pub async fn set_withdrawn_at( &self, withdrawn_at: OffsetDateTime, - tx: TX, + connection: &C, ) -> Result<(), Error> { let mut entity = self.advisory.clone().into_active_model(); entity.withdrawn = Set(Some(withdrawn_at)); - entity.save(&self.graph.connection(&tx)).await?; + entity.save(connection).await?; Ok(()) } @@ -249,27 +244,27 @@ impl<'g> AdvisoryContext<'g> { self.advisory.withdrawn } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn get_vulnerability>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn get_vulnerability( &self, identifier: &str, - tx: TX, + connection: &C, ) -> Result>, Error> { Ok(self .advisory .find_related(entity::advisory_vulnerability::Entity) .filter(entity::advisory_vulnerability::Column::VulnerabilityId.eq(identifier)) - .one(&self.graph.connection(&tx)) + .one(connection) .await? .map(|vuln| (self, vuln).into())) } - #[instrument(skip(self, information, tx), err)] - pub async fn link_to_vulnerability>( + #[instrument(skip(self, information, connection), err)] + pub async fn link_to_vulnerability( &self, identifier: &str, information: Option, - tx: TX, + connection: &C, ) -> Result { let entity = entity::advisory_vulnerability::ActiveModel { advisory_id: Set(self.advisory.id), @@ -302,20 +297,20 @@ impl<'g> AdvisoryContext<'g> { ]) .to_owned(), ) - .exec_with_returning(&self.graph.connection(&tx)) + .exec_with_returning(connection) .await?; Ok((self, entity).into()) } - pub async fn vulnerabilities>( + pub async fn vulnerabilities( &self, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(self .advisory .find_related(entity::advisory_vulnerability::Entity) - .all(&self.graph.connection(&tx)) + .all(connection) .await? .into_iter() .map(|e| (self, e).into()) @@ -332,7 +327,6 @@ mod test { use test_log::test; use time::macros::datetime; use time::OffsetDateTime; - use trustify_common::db::Transactional; use trustify_common::hashing::Digests; use trustify_entity::labels::Labels; use trustify_test_context::TrustifyContext; @@ -340,8 +334,7 @@ mod test { #[test_context(TrustifyContext, skip_teardown)] #[test(tokio::test)] async fn ingest_advisories(ctx: TrustifyContext) -> Result<(), anyhow::Error> { - let db = ctx.db; - let system = Graph::new(db); + let system = Graph::new(ctx.db.clone()); let advisory1 = system .ingest_advisory( @@ -349,7 +342,7 @@ mod test { Labels::from_one("source", "http://db.com/rhsa-ghsa-2"), &Digests::digest("RHSA-GHSA-1_1"), (), - Transactional::None, + &ctx.db, ) .await?; @@ -359,7 +352,7 @@ mod test { Labels::from_one("source", "http://db.com/rhsa-ghsa-2"), &Digests::digest("RHSA-GHSA-1_1"), (), - Transactional::None, + &ctx.db, ) .await?; @@ -369,7 +362,7 @@ mod test { Labels::from_one("source", "http://db.com/rhsa-ghsa-2"), &Digests::digest("RHSA-GHSA-1_2"), (), - Transactional::None, + &ctx.db, ) .await?; @@ -382,8 +375,7 @@ mod test { #[test_context(TrustifyContext, skip_teardown)] #[test(tokio::test)] async fn ingest_advisory_cve(ctx: TrustifyContext) -> Result<(), anyhow::Error> { - let db = ctx.db; - let system = Graph::new(db); + let system = Graph::new(ctx.db.clone()); let advisory = system .ingest_advisory( @@ -391,21 +383,21 @@ mod test { Labels::from_one("source", "http://db.com/rhsa-ghsa-2"), &Digests::digest("RHSA-GHSA-1"), (), - Transactional::None, + &ctx.db, ) .await?; advisory - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None, &ctx.db) .await?; advisory - .link_to_vulnerability("CVE-123", None, Transactional::None) + .link_to_vulnerability("CVE-123", None, &ctx.db) .await?; advisory - .link_to_vulnerability("CVE-456", None, Transactional::None) + .link_to_vulnerability("CVE-456", None, &ctx.db) .await?; - let vulns = advisory.vulnerabilities(()).await?; + let vulns = advisory.vulnerabilities(&ctx.db).await?; assert_eq!(vulns.len(), 2); @@ -431,8 +423,7 @@ mod test { } } - let db = ctx.db; - let system = Graph::new(db); + let system = Graph::new(ctx.db.clone()); let a1 = system .ingest_advisory( @@ -440,7 +431,7 @@ mod test { (), &Digests::digest("RHSA-1"), Info("RHSA", datetime!(2024-01-02 00:00:00 UTC)), - Transactional::None, + &ctx.db, ) .await? .advisory @@ -452,7 +443,7 @@ mod test { (), &Digests::digest("RHSA-2"), Info("RHSA", datetime!(2024-01-03 00:00:00 UTC)), - Transactional::None, + &ctx.db, ) .await? .advisory @@ -464,13 +455,15 @@ mod test { (), &Digests::digest("RHSA-3"), Info("RHSA", datetime!(2024-01-01 00:00:00 UTC)), - Transactional::None, + &ctx.db, ) .await? .advisory .id; - let mut advs = system.get_advisories(Deprecation::Consider, ()).await?; + let mut advs = system + .get_advisories(Deprecation::Consider, &ctx.db) + .await?; advs.sort_unstable_by(|a, b| a.advisory.modified.cmp(&b.advisory.modified)); let deps = advs .iter() diff --git a/modules/ingestor/src/graph/cpe.rs b/modules/ingestor/src/graph/cpe.rs index 5dac90989..97ee80937 100644 --- a/modules/ingestor/src/graph/cpe.rs +++ b/modules/ingestor/src/graph/cpe.rs @@ -6,59 +6,56 @@ use std::{ fmt::{Debug, Formatter}, }; use tracing::instrument; -use trustify_common::{ - cpe::Cpe, - db::{chunk::EntityChunkedIter, Transactional}, -}; +use trustify_common::{cpe::Cpe, db::chunk::EntityChunkedIter}; use trustify_entity::cpe; use uuid::Uuid; impl Graph { - pub async fn get_cpe, TX: AsRef>( + pub async fn get_cpe( &self, - cpe: C, - tx: TX, + cpe: impl Into, + connection: &C, ) -> Result, Error> { let cpe = cpe.into(); let query = cpe::Entity::find_by_id(cpe.uuid()); - if let Some(found) = query.one(&self.connection(&tx)).await? { + if let Some(found) = query.one(connection).await? { Ok(Some((self, found).into())) } else { Ok(None) } } - pub async fn get_cpe_by_query>( + pub async fn get_cpe_by_query( &self, query: SelectStatement, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(cpe::Entity::find() .filter(cpe::Column::Id.in_subquery(query)) - .all(&self.connection(&tx)) + .all(connection) .await? .into_iter() .map(|cpe22| (self, cpe22).into()) .collect()) } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn ingest_cpe22(&self, cpe: C, tx: TX) -> Result - where - C: Into + Debug, - TX: AsRef, - { + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn ingest_cpe22( + &self, + cpe: impl Into + Debug, + connection: &C, + ) -> Result { let cpe = cpe.into(); - if let Some(found) = self.get_cpe(cpe.clone(), &tx).await? { + if let Some(found) = self.get_cpe(cpe.clone(), connection).await? { return Ok(found); } let entity: cpe::ActiveModel = cpe.into(); - Ok((self, entity.insert(&self.connection(&tx)).await?).into()) + Ok((self, entity.insert(connection).await?).into()) } } @@ -129,13 +126,12 @@ mod test { #[test_context(TrustifyContext, skip_teardown)] #[test(tokio::test)] async fn ingest_cpe(ctx: TrustifyContext) -> Result<(), anyhow::Error> { - let db = ctx.db; - let graph = Graph::new(db); + let graph = Graph::new(ctx.db.clone()); let cpe = Cpe::from_str("cpe:/a:redhat:enterprise_linux:9::crb")?; - let c1 = graph.ingest_cpe22(cpe.clone(), ()).await?; - let c2 = graph.ingest_cpe22(cpe, ()).await?; + let c1 = graph.ingest_cpe22(cpe.clone(), &ctx.db).await?; + let c2 = graph.ingest_cpe22(cpe, &ctx.db).await?; assert_eq!(c1.cpe.id, c2.cpe.id); diff --git a/modules/ingestor/src/graph/mod.rs b/modules/ingestor/src/graph/mod.rs index 21f5eb7b3..40edc18f4 100644 --- a/modules/ingestor/src/graph/mod.rs +++ b/modules/ingestor/src/graph/mod.rs @@ -7,10 +7,8 @@ pub mod purl; pub mod sbom; pub mod vulnerability; -use sea_orm::{DbErr, TransactionTrait}; +use sea_orm::DbErr; use std::fmt::Debug; -use tracing::instrument; -use trustify_common::db::{ConnectionOrTransaction, Transactional}; #[derive(Debug, Clone)] pub struct Graph { @@ -29,34 +27,4 @@ impl Graph { pub fn new(db: trustify_common::db::Database) -> Self { Self { db } } - - /// Create a `Transactional::Some(_)` with a new transaction. - /// - /// The transaction will be rolled-back unless explicitly `commit()`'d before - /// it drops. - #[instrument] - pub async fn transaction(&self) -> Result { - Ok(Transactional::Some(self.db.begin().await?)) - } - - pub fn connection<'db, TX: AsRef>( - &'db self, - tx: &'db TX, - ) -> ConnectionOrTransaction { - match tx.as_ref() { - Transactional::None => ConnectionOrTransaction::Connection(&self.db), - Transactional::Some(tx) => ConnectionOrTransaction::Transaction(tx), - } - } - - pub async fn close(self) -> anyhow::Result<()> { - self.db.close().await - } - - /// Ping the database. - /// - /// Intended to be used for health checks. - pub async fn ping(&self) -> anyhow::Result<()> { - self.db.ping().await - } } diff --git a/modules/ingestor/src/graph/organization.rs b/modules/ingestor/src/graph/organization.rs index 2c7b6d6e5..0c7c2dfac 100644 --- a/modules/ingestor/src/graph/organization.rs +++ b/modules/ingestor/src/graph/organization.rs @@ -1,7 +1,6 @@ -use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, Set}; +use sea_orm::{ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter, Set}; use std::fmt::Debug; use tracing::instrument; -use trustify_common::db::Transactional; use trustify_entity::organization; use crate::graph::{error::Error, Graph}; @@ -42,48 +41,48 @@ impl<'g> OrganizationContext<'g> { } impl Graph { - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn get_organizations( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn get_organizations( &self, - tx: impl AsRef, + connection: &C, ) -> Result, Error> { Ok(organization::Entity::find() - .all(&self.connection(&tx)) + .all(connection) .await? .into_iter() .map(|organization| OrganizationContext::new(self, organization)) .collect()) } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn get_organization_by_name>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn get_organization_by_name( &self, name: impl Into + Debug, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(organization::Entity::find() .filter(organization::Column::Name.eq(name.into())) - .one(&self.connection(&tx)) + .one(connection) .await? .map(|organization| OrganizationContext::new(self, organization))) } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn ingest_organization>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn ingest_organization( &self, name: impl Into + Debug, information: impl Into + Debug, - tx: TX, + connection: &C, ) -> Result { let name = name.into(); let information = information.into(); - if let Some(found) = self.get_organization_by_name(&name, &tx).await? { + if let Some(found) = self.get_organization_by_name(&name, connection).await? { if information.has_data() { let mut entity = organization::ActiveModel::from(found.organization); entity.website = Set(information.website); entity.cpe_key = Set(information.cpe_key); - let model = entity.update(&self.connection(&tx)).await?; + let model = entity.update(connection).await?; Ok(OrganizationContext::new(found.graph, model)) } else { Ok(found) @@ -98,7 +97,7 @@ impl Graph { Ok(OrganizationContext::new( self, - entity.insert(&self.connection(&tx)).await?, + entity.insert(connection).await?, )) } } diff --git a/modules/ingestor/src/graph/product/mod.rs b/modules/ingestor/src/graph/product/mod.rs index 3ca40e3fd..21ccfb932 100644 --- a/modules/ingestor/src/graph/product/mod.rs +++ b/modules/ingestor/src/graph/product/mod.rs @@ -1,10 +1,12 @@ pub mod product_version; use entity::organization; -use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, ModelTrait, QueryFilter, Set}; +use sea_orm::{ + ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, ModelTrait, QueryFilter, Set, +}; use std::fmt::Debug; use tracing::instrument; -use trustify_common::{cpe::Cpe, db::Transactional}; +use trustify_common::cpe::Cpe; use trustify_entity as entity; use trustify_entity::product; use uuid::Uuid; @@ -26,19 +28,19 @@ impl<'g> ProductContext<'g> { Self { graph, product } } - pub async fn ingest_product_version>( + pub async fn ingest_product_version( &self, version: String, sbom_id: Option, - tx: TX, + connection: &C, ) -> Result, Error> { - if let Some(found) = self.get_version(version.clone(), &tx).await? { + if let Some(found) = self.get_version(version.clone(), connection).await? { let product_version = ProductVersionContext::new(self, found.product_version.clone()); if let Some(id) = sbom_id { // If sbom is not yet set, link to the SBOM and update the context if found.product_version.sbom_id.is_none() { - Ok(product_version.link_to_sbom(id, &tx).await?) + Ok(product_version.link_to_sbom(id, connection).await?) } else { Ok(product_version) } @@ -53,25 +55,23 @@ impl<'g> ProductContext<'g> { version: Set(version.clone()), }; - let product_version = - ProductVersionContext::new(self, model.insert(&self.graph.connection(&tx)).await?); + let product_version = ProductVersionContext::new(self, model.insert(connection).await?); // If there's an sbom_id, link to the SBOM and update the context if let Some(id) = sbom_id { - Ok(product_version.link_to_sbom(id, &tx).await?) + Ok(product_version.link_to_sbom(id, connection).await?) } else { Ok(product_version) } } } - pub async fn ingest_product_version_range>( + pub async fn ingest_product_version_range( &self, info: VersionInfo, cpe_key: Option, - tx: TX, + connection: &C, ) -> Result { - let connection = &self.graph.connection(&tx); let version_range = info.into_active_model(); let version_range = version_range.insert(connection).await?; @@ -85,14 +85,14 @@ impl<'g> ProductContext<'g> { Ok(model.insert(connection).await?) } - pub async fn get_vendor>( + pub async fn get_vendor( &self, - tx: TX, + connection: &C, ) -> Result, Error> { match self .product .find_related(organization::Entity) - .one(&self.graph.connection(&tx)) + .one(connection) .await? { Some(org) => Ok(Some(OrganizationContext::new(self.graph, org))), @@ -100,16 +100,16 @@ impl<'g> ProductContext<'g> { } } - pub async fn get_version>( + pub async fn get_version( &self, version: String, - tx: TX, + connection: &C, ) -> Result, Error> { match self .product .find_related(entity::product_version::Entity) .filter(entity::product_version::Column::Version.eq(version)) - .one(&self.graph.connection(&tx)) + .one(connection) .await? { Some(ver) => Ok(Some(ProductVersionContext::new(self, ver))), @@ -137,12 +137,12 @@ impl From<()> for ProductInformation { } impl Graph { - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn ingest_product>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn ingest_product( &self, name: impl Into + Debug, information: impl Into + Debug, - tx: TX, + connection: &C, ) -> Result { let name = name.into(); let information = information.into(); @@ -153,7 +153,7 @@ impl Graph { let entity = if let Some(vendor) = information.vendor { if let Some(found) = self - .get_product_by_organization(vendor.clone(), &name, &tx) + .get_product_by_organization(vendor.clone(), &name, connection) .await? { return Ok(found); @@ -166,7 +166,7 @@ impl Graph { cpe_key: organization_cpe_key, website: None, }; - let org = self.ingest_organization(vendor, org, &tx).await?; + let org = self.ingest_organization(vendor, org, connection).await?; product::ActiveModel { id: Default::default(), @@ -184,51 +184,48 @@ impl Graph { } }; - Ok(ProductContext::new( - self, - entity.insert(&self.connection(&tx)).await?, - )) + Ok(ProductContext::new(self, entity.insert(connection).await?)) } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] pub async fn get_products( &self, - tx: impl AsRef, + connection: &impl ConnectionTrait, ) -> Result, Error> { Ok(product::Entity::find() - .all(&self.connection(&tx)) + .all(connection) .await? .into_iter() .map(|product| ProductContext::new(self, product)) .collect()) } - #[instrument(skip(self, tx), err)] - pub async fn get_product_by_name>( + #[instrument(skip(self, connection), err)] + pub async fn get_product_by_name( &self, name: impl Into + Debug, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(product::Entity::find() .filter(product::Column::Name.eq(name.into())) - .one(&self.connection(&tx)) + .one(connection) .await? .map(|product| ProductContext::new(self, product))) } - #[instrument(skip(self, tx), err)] - pub async fn get_product_by_organization>( + #[instrument(skip(self, connection), err)] + pub async fn get_product_by_organization( &self, org: impl Into + Debug, name: impl Into + Debug, - tx: TX, + connection: &C, ) -> Result, Error> { - if let Some(found) = self.get_organization_by_name(org, &tx).await? { + if let Some(found) = self.get_organization_by_name(org, connection).await? { Ok(found .organization .find_related(product::Entity) .filter(product::Column::Name.eq(name.into())) - .one(&self.connection(&tx)) + .one(connection) .await? .map(|product| ProductContext::new(self, product))) } else { diff --git a/modules/ingestor/src/graph/product/product_version.rs b/modules/ingestor/src/graph/product/product_version.rs index 98216329b..54502c434 100644 --- a/modules/ingestor/src/graph/product/product_version.rs +++ b/modules/ingestor/src/graph/product/product_version.rs @@ -2,8 +2,7 @@ use std::fmt::{Debug, Formatter}; use crate::graph::{error::Error, sbom::SbomContext}; use entity::{product_version, sbom}; -use sea_orm::{ActiveModelTrait, EntityTrait, Set}; -use trustify_common::db::Transactional; +use sea_orm::{ActiveModelTrait, ConnectionTrait, EntityTrait, Set}; use trustify_entity as entity; use uuid::Uuid; @@ -33,29 +32,27 @@ impl<'g> ProductVersionContext<'g> { } } - pub async fn link_to_sbom>( + pub async fn link_to_sbom( mut self, sbom_id: Uuid, - tx: TX, + connection: &C, ) -> Result, Error> { let mut product_version: product_version::ActiveModel = self.product_version.clone().into(); product_version.sbom_id = Set(Some(sbom_id)); - let ver = product_version - .update(&self.product.graph.connection(&tx)) - .await?; + let ver = product_version.update(connection).await?; self.product_version.sbom_id = ver.sbom_id; Ok(self) } - pub async fn get_sbom>( + pub async fn get_sbom( &self, - tx: TX, + connection: &C, ) -> Result, Error> { match self.product_version.sbom_id { Some(sbom_id) => Ok(sbom::Entity::find_by_id(sbom_id) - .one(&self.product.graph.connection(&tx)) + .one(connection) .await? .map(|sbom| SbomContext::new(self.product.graph, sbom))), None => Ok(None), diff --git a/modules/ingestor/src/graph/purl/mod.rs b/modules/ingestor/src/graph/purl/mod.rs index e37c96839..fb8c2252f 100644 --- a/modules/ingestor/src/graph/purl/mod.rs +++ b/modules/ingestor/src/graph/purl/mod.rs @@ -7,12 +7,14 @@ pub mod qualified_package; use crate::graph::{error::Error, Graph}; use package_version::PackageVersionContext; use qualified_package::QualifiedPackageContext; -use sea_orm::{prelude::Uuid, ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, Set}; +use sea_orm::{ + prelude::Uuid, ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter, Set, +}; use sea_query::SelectStatement; use std::fmt::{Debug, Formatter}; use tracing::instrument; use trustify_common::{ - db::{limiter::LimiterTrait, Transactional}, + db::limiter::LimiterTrait, model::{Paginated, PaginatedResults}, purl::{Purl, PurlErr}, }; @@ -25,42 +27,44 @@ impl Graph { /// /// The `pkg` parameter does not necessarily require the presence of qualifiers, but /// is assumed to be *complete*. - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn ingest_qualified_package>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn ingest_qualified_package( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result { - let package = self.ingest_package(purl, &tx).await?; - let package_version = package.ingest_package_version(purl, &tx).await?; - package_version.ingest_qualified_package(purl, &tx).await + let package = self.ingest_package(purl, connection).await?; + let package_version = package.ingest_package_version(purl, connection).await?; + package_version + .ingest_qualified_package(purl, connection) + .await } /// Ensure the fetch knows about and contains a record for a *versioned* package. /// /// This method will ensure the package being referenced is also ingested. - pub async fn ingest_package_version>( + pub async fn ingest_package_version( &self, pkg: &Purl, - tx: TX, + connection: &C, ) -> Result { - if let Some(found) = self.get_package_version(pkg, &tx).await? { + if let Some(found) = self.get_package_version(pkg, connection).await? { return Ok(found); } - let package = self.ingest_package(pkg, &tx).await?; + let package = self.ingest_package(pkg, connection).await?; - package.ingest_package_version(pkg, &tx).await + package.ingest_package_version(pkg, connection).await } /// Ensure the fetch knows about and contains a record for a *versionless* package. /// /// This method will ensure the package being referenced is also ingested. - pub async fn ingest_package>( + pub async fn ingest_package( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result { - if let Some(found) = self.get_package(purl, &tx).await? { + if let Some(found) = self.get_package(purl, connection).await? { Ok(found) } else { let model = entity::base_purl::ActiveModel { @@ -70,40 +74,39 @@ impl Graph { name: Set(purl.name.clone()), }; - Ok(PackageContext::new( - self, - model.insert(&self.connection(&tx)).await?, - )) + Ok(PackageContext::new(self, model.insert(connection).await?)) } } /// Retrieve a *fully-qualified* package entry, if it exists. /// /// Non-mutating to the fetch. - pub async fn get_qualified_package>( + pub async fn get_qualified_package( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result, Error> { - if let Some(package_version) = self.get_package_version(purl, &tx).await? { - package_version.get_qualified_package(purl, &tx).await + if let Some(package_version) = self.get_package_version(purl, connection).await? { + package_version + .get_qualified_package(purl, connection) + .await } else { Ok(None) } } - pub async fn get_qualified_package_by_id>( + pub async fn get_qualified_package_by_id( &self, id: Uuid, - tx: TX, + connection: &C, ) -> Result, Error> { let found = entity::qualified_purl::Entity::find_by_id(id) - .one(&self.connection(&tx)) + .one(connection) .await?; if let Some(qualified_package) = found { if let Some(package_version) = self - .get_package_version_by_id(qualified_package.versioned_purl_id, tx) + .get_package_version_by_id(qualified_package.versioned_purl_id, connection) .await? { Ok(Some(QualifiedPackageContext::new( @@ -118,22 +121,22 @@ impl Graph { } } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn get_qualified_packages_by_query>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn get_qualified_packages_by_query( &self, query: SelectStatement, - tx: TX, + connection: &C, ) -> Result, Error> { let found = entity::qualified_purl::Entity::find() .filter(entity::qualified_purl::Column::Id.in_subquery(query)) - .all(&self.connection(&tx)) + .all(connection) .await?; let mut package_versions = Vec::new(); for base in &found { if let Some(package_version) = self - .get_package_version_by_id(base.versioned_purl_id, &tx) + .get_package_version_by_id(base.versioned_purl_id, connection) .await? { let qualified_package = @@ -148,30 +151,30 @@ impl Graph { /// Retrieve a *versioned* package entry, if it exists. /// /// Non-mutating to the fetch. - pub async fn get_package_version>( + pub async fn get_package_version( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result>, Error> { - if let Some(pkg) = self.get_package(purl, &tx).await? { - pkg.get_package_version(purl, &tx).await + if let Some(pkg) = self.get_package(purl, connection).await? { + pkg.get_package_version(purl, connection).await } else { Ok(None) } } - #[instrument(skip(self, tx), err)] - pub async fn get_package_version_by_id>( + #[instrument(skip(self, connection), err)] + pub async fn get_package_version_by_id( &self, id: Uuid, - tx: TX, + connection: &C, ) -> Result, Error> { if let Some(package_version) = entity::versioned_purl::Entity::find_by_id(id) - .one(&self.connection(&tx)) + .one(connection) .await? { if let Some(package) = self - .get_package_by_id(package_version.base_purl_id, &tx) + .get_package_by_id(package_version.base_purl_id, connection) .await? { Ok(Some(PackageVersionContext::new(&package, package_version))) @@ -186,10 +189,10 @@ impl Graph { /// Retrieve a *versionless* package entry, if it exists. /// /// Non-mutating to the fetch. - pub async fn get_package>( + pub async fn get_package( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(entity::base_purl::Entity::find() .filter(entity::base_purl::Column::Type.eq(&purl.ty)) @@ -199,19 +202,19 @@ impl Graph { entity::base_purl::Column::Namespace.is_null() }) .filter(entity::base_purl::Column::Name.eq(&purl.name)) - .one(&self.connection(&tx)) + .one(connection) .await? .map(|package| PackageContext::new(self, package))) } - #[instrument(skip(self, tx), err)] - pub async fn get_package_by_id>( + #[instrument(skip(self, connection), err)] + pub async fn get_package_by_id( &self, id: Uuid, - tx: TX, + connection: &C, ) -> Result, Error> { if let Some(found) = entity::base_purl::Entity::find_by_id(id) - .one(&self.connection(&tx)) + .one(connection) .await? { Ok(Some(PackageContext::new(self, found))) @@ -243,13 +246,13 @@ impl<'g> PackageContext<'g> { } /// Ensure the fetch knows about and contains a record for a *version* of this package. - pub async fn ingest_package_version>( + pub async fn ingest_package_version( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result, Error> { if let Some(version) = &purl.version { - if let Some(found) = self.get_package_version(purl, &tx).await? { + if let Some(found) = self.get_package_version(purl, connection).await? { Ok(found) } else { let model = entity::versioned_purl::ActiveModel { @@ -260,7 +263,7 @@ impl<'g> PackageContext<'g> { Ok(PackageVersionContext::new( self, - model.insert(&self.graph.connection(&tx)).await?, + model.insert(connection).await?, )) } } else { @@ -271,15 +274,15 @@ impl<'g> PackageContext<'g> { /// Retrieve a *version* package entry for this package, if it exists. /// /// Non-mutating to the fetch. - pub async fn get_package_version>( + pub async fn get_package_version( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result>, Error> { Ok(entity::versioned_purl::Entity::find() .filter(entity::versioned_purl::Column::BasePurlId.eq(self.base_purl.id)) .filter(entity::versioned_purl::Column::Version.eq(purl.version.clone())) - .one(&self.graph.connection(&tx)) + .one(connection) .await .map(|package_version| { package_version @@ -290,29 +293,27 @@ impl<'g> PackageContext<'g> { /// Retrieve known versions of this package. /// /// Non-mutating to the fetch. - pub async fn get_versions>( + pub async fn get_versions( &self, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(entity::versioned_purl::Entity::find() .filter(entity::versioned_purl::Column::BasePurlId.eq(self.base_purl.id)) - .all(&self.graph.connection(&tx)) + .all(connection) .await? .drain(0..) .map(|each| PackageVersionContext::new(self, each)) .collect()) } - pub async fn get_versions_paginated>( + pub async fn get_versions_paginated( &self, paginated: Paginated, - tx: TX, + connection: &C, ) -> Result, Error> { - let connection = self.graph.connection(&tx); - let limiter = entity::versioned_purl::Entity::find() .filter(entity::versioned_purl::Column::BasePurlId.eq(self.base_purl.id)) - .limiting(&connection, paginated.limit, paginated.offset); + .limiting(connection, paginated.limit, paginated.offset); Ok(PaginatedResults { total: limiter.total().await?, @@ -341,7 +342,6 @@ mod tests { use test_context::test_context; use test_log::test; - use trustify_common::db::Transactional; use trustify_common::model::Paginated; use trustify_common::purl::Purl; use trustify_entity::qualified_purl; @@ -354,28 +354,18 @@ mod tests { #[test_context(TrustifyContext, skip_teardown)] #[test(tokio::test)] async fn ingest_packages(ctx: TrustifyContext) -> Result<(), anyhow::Error> { - let db = ctx.db; - let system = Graph::new(db); + let system = Graph::new(ctx.db.clone()); let pkg1 = system - .ingest_package( - &"pkg:maven/io.quarkus/quarkus-core".try_into()?, - Transactional::None, - ) + .ingest_package(&"pkg:maven/io.quarkus/quarkus-core".try_into()?, &ctx.db) .await?; let pkg2 = system - .ingest_package( - &"pkg:maven/io.quarkus/quarkus-core".try_into()?, - Transactional::None, - ) + .ingest_package(&"pkg:maven/io.quarkus/quarkus-core".try_into()?, &ctx.db) .await?; let pkg3 = system - .ingest_package( - &"pkg:maven/io.quarkus/quarkus-addons".try_into()?, - Transactional::None, - ) + .ingest_package(&"pkg:maven/io.quarkus/quarkus-addons".try_into()?, &ctx.db) .await?; assert_eq!(pkg1.base_purl.id, pkg2.base_purl.id,); @@ -390,14 +380,10 @@ mod tests { async fn ingest_package_versions_missing_version( ctx: TrustifyContext, ) -> Result<(), anyhow::Error> { - let db = ctx.db; - let system = Graph::new(db); + let system = Graph::new(ctx.db.clone()); let result = system - .ingest_package_version( - &"pkg:maven/io.quarkus/quarkus-addons".try_into()?, - Transactional::None, - ) + .ingest_package_version(&"pkg:maven/io.quarkus/quarkus-addons".try_into()?, &ctx.db) .await; assert!(result.is_err()); @@ -408,27 +394,26 @@ mod tests { #[test_context(TrustifyContext, skip_teardown)] #[test(tokio::test)] async fn ingest_package_versions(ctx: TrustifyContext) -> Result<(), anyhow::Error> { - let db = ctx.db; - let system = Graph::new(db); + let system = Graph::new(ctx.db.clone()); let pkg1 = system .ingest_package_version( &"pkg:maven/io.quarkus/quarkus-core@1.2.3".try_into()?, - Transactional::None, + &ctx.db, ) .await?; let pkg2 = system .ingest_package_version( &"pkg:maven/io.quarkus/quarkus-core@1.2.3".try_into()?, - Transactional::None, + &ctx.db, ) .await?; let pkg3 = system .ingest_package_version( &"pkg:maven/io.quarkus/quarkus-core@4.5.6".try_into()?, - Transactional::None, + &ctx.db, ) .await?; @@ -444,8 +429,7 @@ mod tests { #[test_context(TrustifyContext, skip_teardown)] #[test(tokio::test)] async fn get_versions_paginated(ctx: TrustifyContext) -> Result<(), anyhow::Error> { - let db = ctx.db; - let system = Graph::new(db); + let system = Graph::new(ctx.db.clone()); const TOTAL_ITEMS: u64 = 200; let _page_size = NonZeroU64::new(50).unwrap(); @@ -453,20 +437,15 @@ mod tests { for v in 0..TOTAL_ITEMS { let version = format!("pkg:maven/io.quarkus/quarkus-core@{v}").try_into()?; - let _ = system - .ingest_package_version(&version, Transactional::None) - .await?; + let _ = system.ingest_package_version(&version, &ctx.db).await?; } let pkg = system - .get_package( - &"pkg:maven/io.quarkus/quarkus-core".try_into()?, - Transactional::None, - ) + .get_package(&"pkg:maven/io.quarkus/quarkus-core".try_into()?, &ctx.db) .await? .unwrap(); - let all_versions = pkg.get_versions(Transactional::None).await?; + let all_versions = pkg.get_versions(&ctx.db).await?; assert_eq!(TOTAL_ITEMS, all_versions.len() as u64); @@ -476,7 +455,7 @@ mod tests { offset: 50, limit: 50, }, - Transactional::None, + &ctx.db, ) .await?; @@ -489,7 +468,7 @@ mod tests { offset: 100, limit: 50, }, - Transactional::None, + &ctx.db, ) .await?; @@ -504,24 +483,23 @@ mod tests { async fn ingest_qualified_packages_transactionally( ctx: TrustifyContext, ) -> Result<(), anyhow::Error> { - let db = ctx.db; - let system = Graph::new(db.clone()); + let system = Graph::new(ctx.db.clone()); let tx_system = system.clone(); - db.transaction(|_tx| { + ctx.db.transaction(|tx| { Box::pin(async move { let pkg1 = tx_system .ingest_qualified_package( &"pkg:oci/ubi9-container@sha256:2f168398c538b287fd705519b83cd5b604dc277ef3d9f479c28a2adb4d830a49?repository_url=registry.redhat.io/ubi9&tag=9.2-755.1697625012".try_into()?, - &Transactional::None, + tx, ) .await?; let pkg2 = tx_system .ingest_qualified_package( &"pkg:oci/ubi9-container@sha256:2f168398c538b287fd705519b83cd5b604dc277ef3d9f479c28a2adb4d830a49?repository_url=registry.redhat.io/ubi9&tag=9.2-755.1697625012".try_into()?, - &Transactional::None, + tx, ) .await?; @@ -537,34 +515,33 @@ mod tests { #[test_context(TrustifyContext, skip_teardown)] #[test(tokio::test)] async fn ingest_qualified_packages(ctx: TrustifyContext) -> Result<(), anyhow::Error> { - let db = ctx.db; - let system = Graph::new(db); + let system = Graph::new(ctx.db.clone()); let pkg1 = system .ingest_qualified_package( &"pkg:maven/io.quarkus/quarkus-core@1.2.3".try_into()?, - &Transactional::None, + &&ctx.db, ) .await?; let pkg2 = system .ingest_qualified_package( &"pkg:maven/io.quarkus/quarkus-core@1.2.3".try_into()?, - &Transactional::None, + &&ctx.db, ) .await?; let pkg3 = system .ingest_qualified_package( &"pkg:maven/io.quarkus/quarkus-core@1.2.3?type=jar".try_into()?, - &Transactional::None, + &&ctx.db, ) .await?; let pkg4 = system .ingest_qualified_package( &"pkg:maven/io.quarkus/quarkus-core@1.2.3?type=jar".try_into()?, - &Transactional::None, + &&ctx.db, ) .await?; @@ -588,8 +565,7 @@ mod tests { #[test_context(TrustifyContext, skip_teardown)] #[test(tokio::test)] async fn query_qualified_packages(ctx: TrustifyContext) -> Result<(), anyhow::Error> { - let db = ctx.db; - let graph = Graph::new(db); + let graph = Graph::new(ctx.db.clone()); for i in [ "pkg:maven/io.quarkus/quarkus-core@1.2.3", @@ -597,7 +573,7 @@ mod tests { "pkg:maven/io.quarkus/quarkus-core@1.2.3?type=pom", ] { graph - .ingest_qualified_package(&i.try_into()?, &Transactional::None) + .ingest_qualified_package(&i.try_into()?, &&ctx.db) .await?; } @@ -615,7 +591,7 @@ mod tests { )) .into_query(); let result = graph - .get_qualified_packages_by_query(select, Transactional::None) + .get_qualified_packages_by_query(select, &ctx.db) .await?; log::debug!("{result:?}"); diff --git a/modules/ingestor/src/graph/purl/package_version.rs b/modules/ingestor/src/graph/purl/package_version.rs index 20c28486c..b7035c100 100644 --- a/modules/ingestor/src/graph/purl/package_version.rs +++ b/modules/ingestor/src/graph/purl/package_version.rs @@ -4,9 +4,9 @@ use crate::graph::{ error::Error, purl::{qualified_package::QualifiedPackageContext, PackageContext}, }; -use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, Set}; +use sea_orm::{ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter, Set}; use std::fmt::{Debug, Formatter}; -use trustify_common::{db::Transactional, purl::Purl}; +use trustify_common::purl::Purl; use trustify_entity::{self as entity, qualified_purl::Qualifiers, versioned_purl}; /// Live context for a package version. @@ -30,12 +30,12 @@ impl<'g> PackageVersionContext<'g> { } } - pub async fn ingest_qualified_package>( + pub async fn ingest_qualified_package( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result, Error> { - if let Some(found) = self.get_qualified_package(purl, &tx).await? { + if let Some(found) = self.get_qualified_package(purl, connection).await? { return Ok(found); } let cp = purl.clone().into(); @@ -47,24 +47,22 @@ impl<'g> PackageVersionContext<'g> { purl: Set(cp), }; - let qualified_package = qualified_package - .insert(&self.package.graph.connection(&tx)) - .await?; + let qualified_package = qualified_package.insert(connection).await?; Ok(QualifiedPackageContext::new(self, qualified_package)) } - pub async fn get_qualified_package>( + pub async fn get_qualified_package( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result>, Error> { let found = entity::qualified_purl::Entity::find() .filter(entity::qualified_purl::Column::VersionedPurlId.eq(self.package_version.id)) .filter( entity::qualified_purl::Column::Qualifiers.eq(Qualifiers(purl.qualifiers.clone())), ) - .one(&self.package.graph.connection(&tx)) + .one(connection) .await?; Ok(found.map(|model| QualifiedPackageContext::new(self, model))) @@ -73,14 +71,14 @@ impl<'g> PackageVersionContext<'g> { /// Retrieve known variants of this package version. /// /// Non-mutating to the fetch. - pub async fn get_variants>( + pub async fn get_variants( &self, _pkg: Purl, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(entity::qualified_purl::Entity::find() .filter(entity::qualified_purl::Column::VersionedPurlId.eq(self.package_version.id)) - .all(&self.package.graph.connection(&tx)) + .all(connection) .await? .into_iter() .map(|base| QualifiedPackageContext::new(self, base)) diff --git a/modules/ingestor/src/graph/purl/qualified_package.rs b/modules/ingestor/src/graph/purl/qualified_package.rs index 37bb0abd8..e3031b9ae 100644 --- a/modules/ingestor/src/graph/purl/qualified_package.rs +++ b/modules/ingestor/src/graph/purl/qualified_package.rs @@ -1,15 +1,14 @@ //! Support for a *fully-qualified* package. -use crate::graph::error::Error; -use crate::graph::purl::package_version::PackageVersionContext; -use crate::graph::sbom::SbomContext; -use std::collections::BTreeMap; -use std::fmt::{Debug, Formatter}; -use std::hash::{Hash, Hasher}; -use trustify_common::db::Transactional; +use crate::graph::{error::Error, purl::package_version::PackageVersionContext, sbom::SbomContext}; +use sea_orm::ConnectionTrait; +use std::{ + collections::BTreeMap, + fmt::{Debug, Formatter}, + hash::{Hash, Hasher}, +}; use trustify_common::purl::Purl; -use trustify_entity as entity; -use trustify_entity::qualified_purl; +use trustify_entity::{self as entity, qualified_purl}; #[derive(Clone)] pub struct QualifiedPackageContext<'g> { @@ -59,9 +58,9 @@ impl<'g> QualifiedPackageContext<'g> { qualified_package, } } - pub async fn sboms_containing>( + pub async fn sboms_containing( &self, - _tx: TX, + _connection: &C, ) -> Result, Error> { /* Ok(entity::sbom::Entity::find() diff --git a/modules/ingestor/src/graph/sbom/clearly_defined.rs b/modules/ingestor/src/graph/sbom/clearly_defined.rs index 6f70753b4..9ecde9ee4 100644 --- a/modules/ingestor/src/graph/sbom/clearly_defined.rs +++ b/modules/ingestor/src/graph/sbom/clearly_defined.rs @@ -1,23 +1,20 @@ use crate::graph::purl::creator::PurlCreator; use crate::graph::sbom::{LicenseCreator, LicenseInfo, SbomContext, SbomInformation}; -use sea_orm::{EntityTrait, Set}; +use sea_orm::{ConnectionTrait, EntityTrait, Set}; use sea_query::OnConflict; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use tracing::instrument; -use trustify_common::db::Transactional; use trustify_common::purl::Purl; use trustify_entity::purl_license_assertion; impl SbomContext { - #[instrument(skip(tx, curation), err)] - pub async fn ingest_clearly_defined_curation>( + #[instrument(skip(db, curation), err)] + pub async fn ingest_clearly_defined_curation( &self, curation: Curation, - tx: TX, + db: &C, ) -> Result<(), anyhow::Error> { - let db = &self.graph.db.connection(&tx); - let mut purls = PurlCreator::new(); let mut licenses = LicenseCreator::new(); diff --git a/modules/ingestor/src/graph/sbom/cyclonedx.rs b/modules/ingestor/src/graph/sbom/cyclonedx.rs index 04d1b9e87..c7b85672c 100644 --- a/modules/ingestor/src/graph/sbom/cyclonedx.rs +++ b/modules/ingestor/src/graph/sbom/cyclonedx.rs @@ -5,14 +5,15 @@ use crate::graph::{ purl::creator::PurlCreator, sbom::{PackageCreator, PackageReference, RelationshipCreator, SbomContext, SbomInformation}, }; -use cyclonedx_bom::models::license::{LicenseChoice, LicenseIdentifier}; -use cyclonedx_bom::prelude::{Bom, Component, Components}; +use cyclonedx_bom::{ + models::license::{LicenseChoice, LicenseIdentifier}, + prelude::{Bom, Component, Components}, +}; use sea_orm::ConnectionTrait; -use std::collections::HashMap; -use std::str::FromStr; +use std::{collections::HashMap, str::FromStr}; use time::{format_description::well_known::Iso8601, OffsetDateTime}; use tracing::instrument; -use trustify_common::{cpe::Cpe, db::Transactional, purl::Purl}; +use trustify_common::{cpe::Cpe, purl::Purl}; use trustify_entity::relationship::Relationship; use uuid::Uuid; @@ -88,14 +89,12 @@ impl<'a> From> for SbomInformation { } impl SbomContext { - #[instrument(skip(tx, sbom), ret)] - pub async fn ingest_cyclonedx>( + #[instrument(skip(connection, sbom), ret)] + pub async fn ingest_cyclonedx( &self, mut sbom: Bom, - tx: TX, + connection: &C, ) -> Result<(), anyhow::Error> { - let db = &self.graph.db.connection(&tx); - let mut license_creator = LicenseCreator::new(); let mut creator = Creator::new(self.sbom.sbom_id); @@ -121,12 +120,12 @@ impl SbomContext { vendor: component.publisher.clone().map(|p| p.to_string()), cpe: product_cpe, }, - &tx, + connection, ) .await?; if let Some(ver) = component.version.clone() { - pr.ingest_product_version(ver.to_string(), Some(self.sbom.sbom_id), &tx) + pr.ingest_product_version(ver.to_string(), Some(self.sbom.sbom_id), connection) .await?; } @@ -189,8 +188,8 @@ impl SbomContext { // create - license_creator.create(db).await?; - creator.create(db).await?; + license_creator.create(connection).await?; + creator.create(connection).await?; // done diff --git a/modules/ingestor/src/graph/sbom/mod.rs b/modules/ingestor/src/graph/sbom/mod.rs index ae92deeae..37459ffd8 100644 --- a/modules/ingestor/src/graph/sbom/mod.rs +++ b/modules/ingestor/src/graph/sbom/mod.rs @@ -22,8 +22,8 @@ use cpe::uri::OwnedUri; use entity::{product, product_version}; use hex::ToHex; use sea_orm::{ - prelude::Uuid, ActiveModelTrait, ColumnTrait, EntityTrait, ModelTrait, QueryFilter, - QuerySelect, QueryTrait, RelationTrait, Select, SelectColumns, Set, + prelude::Uuid, ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, ModelTrait, + QueryFilter, QuerySelect, QueryTrait, RelationTrait, Select, SelectColumns, Set, }; use sea_query::{ extension::postgres::PgExpr, Alias, Condition, Expr, Func, JoinType, Query, SimpleExpr, @@ -35,9 +35,7 @@ use std::{ }; use time::OffsetDateTime; use tracing::instrument; -use trustify_common::{ - cpe::Cpe, db::Transactional, hashing::Digests, purl::Purl, sbom::SbomLocator, -}; +use trustify_common::{cpe::Cpe, hashing::Digests, purl::Purl, sbom::SbomLocator}; use trustify_entity::{ self as entity, labels::Labels, license, package_relates_to_package, purl_license_assertion, relationship::Relationship, sbom, sbom_node, sbom_package, sbom_package_cpe_ref, @@ -65,51 +63,48 @@ impl From<()> for SbomInformation { type SelectEntity = Select; impl Graph { - pub async fn get_sbom_by_id>( + pub async fn get_sbom_by_id( &self, id: Uuid, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(sbom::Entity::find_by_id(id) - .one(&self.connection(&tx)) + .one(connection) .await? .map(|sbom| SbomContext::new(self, sbom))) } - #[instrument(skip(tx))] - pub async fn get_sbom_by_digest>( + #[instrument(skip(connection))] + pub async fn get_sbom_by_digest( &self, digest: &str, - tx: TX, + connection: &C, ) -> Result, Error> { - Ok(entity::sbom::Entity::find() - .join( - JoinType::LeftJoin, - entity::sbom::Relation::SourceDocument.def(), - ) + Ok(sbom::Entity::find() + .join(JoinType::LeftJoin, sbom::Relation::SourceDocument.def()) .filter( Condition::any() .add(source_document::Column::Sha256.eq(digest.to_string())) .add(source_document::Column::Sha384.eq(digest.to_string())) .add(source_document::Column::Sha512.eq(digest.to_string())), ) - .one(&self.connection(&tx)) + .one(connection) .await? .map(|sbom| SbomContext::new(self, sbom))) } - #[instrument(skip(tx, info), err(level=tracing::Level::INFO))] - pub async fn ingest_sbom>( + #[instrument(skip(connection, info), err(level=tracing::Level::INFO))] + pub async fn ingest_sbom( &self, labels: impl Into + Debug, digests: &Digests, document_id: &str, info: impl Into, - tx: TX, + connection: &C, ) -> Result { let sha256 = digests.sha256.encode_hex::(); - if let Some(found) = self.get_sbom_by_digest(&sha256, &tx).await? { + if let Some(found) = self.get_sbom_by_digest(&sha256, connection).await? { return Ok(found); } @@ -121,8 +116,6 @@ impl Graph { data_licenses, } = info.into(); - let connection = self.db.connection(&tx); - let sbom_id = Uuid::now_v7(); let doc_model = source_document::ActiveModel { @@ -133,7 +126,7 @@ impl Graph { size: Set(digests.size as i64), }; - let doc = doc_model.insert(&connection).await?; + let doc = doc_model.insert(connection).await?; let model = sbom::ActiveModel { sbom_id: Set(sbom_id), @@ -155,8 +148,8 @@ impl Graph { name: Set(name), }; - let result = model.insert(&connection).await?; - node_model.insert(&connection).await?; + let result = model.insert(connection).await?; + node_model.insert(connection).await?; Ok(SbomContext::new(self, result)) } @@ -169,116 +162,116 @@ impl Graph { /// /// If the requested SBOM does not exist in the fetch, it will not exist /// after this query either. This function is *non-mutating*. - pub async fn locate_sbom>( + pub async fn locate_sbom( &self, sbom_locator: SbomLocator, - tx: TX, + connection: &C, ) -> Result, Error> { match sbom_locator { - SbomLocator::Id(id) => self.locate_sbom_by_id(id, tx).await, - SbomLocator::Sha256(sha256) => self.locate_sbom_by_sha256(&sha256, tx).await, - SbomLocator::Purl(purl) => self.locate_sbom_by_purl(&purl, tx).await, - SbomLocator::Cpe(cpe) => self.locate_sbom_by_cpe22(&cpe, tx).await, + SbomLocator::Id(id) => self.locate_sbom_by_id(id, connection).await, + SbomLocator::Sha256(sha256) => self.locate_sbom_by_sha256(&sha256, connection).await, + SbomLocator::Purl(purl) => self.locate_sbom_by_purl(&purl, connection).await, + SbomLocator::Cpe(cpe) => self.locate_sbom_by_cpe22(&cpe, connection).await, } } - pub async fn locate_sboms>( + pub async fn locate_sboms( &self, sbom_locator: SbomLocator, - tx: TX, + connection: &C, ) -> Result, Error> { match sbom_locator { SbomLocator::Id(id) => { - if let Some(sbom) = self.locate_sbom_by_id(id, tx).await? { + if let Some(sbom) = self.locate_sbom_by_id(id, connection).await? { Ok(vec![sbom]) } else { Ok(vec![]) } } - SbomLocator::Sha256(sha256) => self.locate_sboms_by_sha256(&sha256, tx).await, - SbomLocator::Purl(purl) => self.locate_sboms_by_purl(&purl, tx).await, - SbomLocator::Cpe(cpe) => self.locate_sboms_by_cpe22(cpe, tx).await, + SbomLocator::Sha256(sha256) => self.locate_sboms_by_sha256(&sha256, connection).await, + SbomLocator::Purl(purl) => self.locate_sboms_by_purl(&purl, connection).await, + SbomLocator::Cpe(cpe) => self.locate_sboms_by_cpe22(cpe, connection).await, } } - async fn locate_one_sbom>( + async fn locate_one_sbom( &self, query: SelectEntity, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(query - .one(&self.connection(&tx)) + .one(connection) .await? .map(|sbom| SbomContext::new(self, sbom))) } - pub async fn locate_many_sboms>( + pub async fn locate_many_sboms( &self, query: SelectEntity, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(query - .all(&self.connection(&tx)) + .all(connection) .await? .into_iter() .map(|sbom| SbomContext::new(self, sbom)) .collect()) } - pub async fn locate_sbom_by_id>( + pub async fn locate_sbom_by_id( &self, id: Uuid, - tx: TX, + connection: &C, ) -> Result, Error> { let _query = sbom::Entity::find_by_id(id); Ok(sbom::Entity::find_by_id(id) - .one(&self.connection(&tx)) + .one(connection) .await? .map(|sbom| SbomContext::new(self, sbom))) } - pub async fn locate_sboms_by_labels>( + pub async fn locate_sboms_by_labels( &self, labels: Labels, - tx: TX, + connection: &C, ) -> Result, Error> { self.locate_many_sboms( sbom::Entity::find().filter(Expr::col(sbom::Column::Labels).contains(labels)), - tx, + connection, ) .await } - async fn locate_sbom_by_sha256>( + async fn locate_sbom_by_sha256( &self, sha256: &str, - tx: TX, + connection: &C, ) -> Result, Error> { self.locate_one_sbom( sbom::Entity::find() .join(JoinType::Join, sbom::Relation::SourceDocument.def()) .filter(source_document::Column::Sha256.eq(sha256.to_string())), - tx, + connection, ) .await } - async fn locate_sboms_by_sha256>( + async fn locate_sboms_by_sha256( &self, sha256: &str, - tx: TX, + connection: &C, ) -> Result, Error> { self.locate_many_sboms( sbom::Entity::find() .join(JoinType::Join, sbom::Relation::SourceDocument.def()) .filter(source_document::Column::Sha256.eq(sha256.to_string())), - tx, + connection, ) .await } fn query_by_purl(package: QualifiedPackageContext) -> Select { - entity::sbom::Entity::find() + sbom::Entity::find() .join_rev(JoinType::Join, sbom_package::Relation::Sbom.def()) .join_rev( JoinType::Join, @@ -288,7 +281,7 @@ impl Graph { } fn query_by_cpe(cpe: CpeContext) -> Select { - entity::sbom::Entity::find() + sbom::Entity::find() .join_rev(JoinType::Join, sbom_package::Relation::Sbom.def()) .join_rev( JoinType::Join, @@ -297,58 +290,60 @@ impl Graph { .filter(sbom_package_cpe_ref::Column::CpeId.eq(cpe.cpe.id)) } - async fn locate_sbom_by_purl>( + async fn locate_sbom_by_purl( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result, Error> { - let package = self.get_qualified_package(purl, &tx).await?; + let package = self.get_qualified_package(purl, connection).await?; if let Some(package) = package { - self.locate_one_sbom(Self::query_by_purl(package), &tx) + self.locate_one_sbom(Self::query_by_purl(package), connection) .await } else { Ok(None) } } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - async fn locate_sboms_by_purl>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + async fn locate_sboms_by_purl( &self, purl: &Purl, - tx: TX, + connection: &C, ) -> Result, Error> { - let package = self.get_qualified_package(purl, &tx).await?; + let package = self.get_qualified_package(purl, connection).await?; if let Some(package) = package { - self.locate_many_sboms(Self::query_by_purl(package), &tx) + self.locate_many_sboms(Self::query_by_purl(package), connection) .await } else { Ok(vec![]) } } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - async fn locate_sbom_by_cpe22>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + async fn locate_sbom_by_cpe22( &self, cpe: &Cpe, - tx: TX, + connection: &C, ) -> Result, Error> { - if let Some(cpe) = self.get_cpe(cpe.clone(), &tx).await? { - self.locate_one_sbom(Self::query_by_cpe(cpe), &tx).await + if let Some(cpe) = self.get_cpe(cpe.clone(), connection).await? { + self.locate_one_sbom(Self::query_by_cpe(cpe), connection) + .await } else { Ok(None) } } - #[instrument(skip(self, tx), err)] - async fn locate_sboms_by_cpe22(&self, cpe: C, tx: TX) -> Result, Error> - where - C: Into + Debug, - TX: AsRef, - { - if let Some(cpe) = self.get_cpe(cpe, &tx).await? { - self.locate_many_sboms(Self::query_by_cpe(cpe), &tx).await + #[instrument(skip(self, connection), err)] + async fn locate_sboms_by_cpe22( + &self, + cpe: impl Into + Debug, + connection: &C, + ) -> Result, Error> { + if let Some(cpe) = self.get_cpe(cpe, connection).await? { + self.locate_many_sboms(Self::query_by_cpe(cpe), connection) + .await } else { Ok(vec![]) } @@ -423,14 +418,16 @@ impl SbomContext { } } - pub async fn ingest_purl_license_assertion>( + pub async fn ingest_purl_license_assertion( &self, purl: &Purl, license: &str, - tx: TX, + connection: &C, ) -> Result<(), Error> { - let connection = self.graph.connection(&tx); - let purl = self.graph.ingest_qualified_package(purl, &tx).await?; + let purl = self + .graph + .ingest_qualified_package(purl, connection) + .await?; let license_info = LicenseInfo { license: license.to_string(), @@ -440,7 +437,7 @@ impl SbomContext { let (spdx_licenses, spdx_exceptions) = license_info.spdx_info(); let license = license::Entity::find_by_id(license_info.uuid()) - .one(&connection) + .one(connection) .await?; let license = if let Some(license) = license { @@ -460,7 +457,7 @@ impl SbomContext { Set(Some(spdx_exceptions)) }, } - .insert(&connection) + .insert(connection) .await? }; @@ -471,7 +468,7 @@ impl SbomContext { .eq(purl.package_version.package_version.id), ) .filter(purl_license_assertion::Column::SbomId.eq(self.sbom.sbom_id)) - .one(&connection) + .one(connection) .await?; if assertion.is_none() { @@ -481,7 +478,7 @@ impl SbomContext { versioned_purl_id: Set(purl.package_version.package_version.id), sbom_id: Set(self.sbom.sbom_id), } - .insert(&connection) + .insert(connection) .await?; } @@ -510,10 +507,10 @@ impl SbomContext { } /// Get the PURLs which describe an SBOM - #[instrument(skip(tx), err)] - pub async fn describes_purls>( + #[instrument(skip(connection), err)] + pub async fn describes_purls( &self, - tx: TX, + connection: &C, ) -> Result, Error> { let describes = self.query_describes_packages(); @@ -523,16 +520,16 @@ impl SbomContext { .join(JoinType::Join, sbom_package::Relation::Purl.def()) .select_column(sbom_package_purl_ref::Column::QualifiedPurlId) .into_query(), - tx, + connection, ) .await } /// Get the CPEs which describe an SBOM - #[instrument(skip(tx), err)] - pub async fn describes_cpe22s>( + #[instrument(skip(connection), err)] + pub async fn describes_cpe22s( &self, - tx: TX, + connection: &C, ) -> Result, Error> { let describes = self.query_describes_packages(); @@ -542,16 +539,16 @@ impl SbomContext { .join(JoinType::Join, sbom_package::Relation::Cpe.def()) .select_column(sbom_package_cpe_ref::Column::CpeId) .into_query(), - tx, + connection, ) .await } /* #[instrument(skip(tx), err)] - pub async fn packages>( + pub async fn packages( &self, - tx: TX, + connection: &C, ) -> Result, Error> { self.graph .get_qualified_packages_by_query( @@ -572,13 +569,13 @@ impl SbomContext { /// The packages will be created if they don't yet exist. /// /// **NOTE:** This is a convenience function, creating relationships for tests. It is terribly slow. - #[instrument(skip(tx), err)] - pub async fn ingest_package_relates_to_package<'a, TX: AsRef>( + #[instrument(skip(connection), err)] + pub async fn ingest_package_relates_to_package<'a, C: ConnectionTrait>( &'a self, left: impl Into + Debug, relationship: Relationship, right: impl Into + Debug, - tx: TX, + connection: &C, ) -> Result<(), Error> { let left = left.into(); let right = right.into(); @@ -597,7 +594,7 @@ impl SbomContext { ) } RelationshipReference::Cpe(cpe) => { - let cpe_ctx = self.graph.ingest_cpe22(cpe.clone(), &tx).await?; + let cpe_ctx = self.graph.ingest_cpe22(cpe.clone(), connection).await?; (Some(cpe.to_string()), vec![], vec![cpe_ctx.cpe.id]) } }; @@ -612,12 +609,12 @@ impl SbomContext { ) } RelationshipReference::Cpe(cpe) => { - let cpe_ctx = self.graph.ingest_cpe22(cpe.clone(), &tx).await?; + let cpe_ctx = self.graph.ingest_cpe22(cpe.clone(), connection).await?; (Some(cpe.to_string()), vec![], vec![cpe_ctx.cpe.id]) } }; - creator.create(&self.graph.connection(&tx)).await?; + creator.create(connection).await?; // create the nodes @@ -628,7 +625,7 @@ impl SbomContext { None, left_purls, left_cpes, - &tx, + connection, ) .await?; } @@ -640,7 +637,7 @@ impl SbomContext { None, right_purls, right_cpes, - &tx, + connection, ) .await?; } @@ -652,38 +649,38 @@ impl SbomContext { let mut relationships = RelationshipCreator::new(self.sbom.sbom_id); relationships.relate(left_node_id, relationship, right_node_id); - relationships.create(&self.graph.db.connection(&tx)).await?; + relationships.create(connection).await?; Ok(()) } - #[instrument(skip(self, tx), err)] - pub async fn ingest_describes_package>( + #[instrument(skip(self, connection), err)] + pub async fn ingest_describes_package( &self, package: Purl, - tx: TX, + connection: &C, ) -> anyhow::Result<()> { self.ingest_package_relates_to_package( RelationshipReference::Root, Relationship::DescribedBy, RelationshipReference::Purl(package), - tx, + connection, ) .await?; Ok(()) } - #[instrument(skip(self, tx), err)] - pub async fn ingest_describes_cpe22>( + #[instrument(skip(self, connection), err)] + pub async fn ingest_describes_cpe22( &self, cpe: Cpe, - tx: TX, + connection: &C, ) -> anyhow::Result<()> { self.ingest_package_relates_to_package( RelationshipReference::Root, Relationship::DescribedBy, RelationshipReference::Cpe(cpe), - tx, + connection, ) .await?; Ok(()) @@ -693,15 +690,15 @@ impl SbomContext { /// /// **NOTE:** This function ingests a single package, and is terribly slow. /// Use the [`PackageCreator`] for creating more than one. - #[instrument(skip(self, tx), err)] - async fn ingest_package>( + #[instrument(skip(self, connection), err)] + async fn ingest_package( &self, node_id: String, name: String, version: Option, purls: Vec<(Uuid, Uuid)>, cpes: Vec, - tx: TX, + connection: &C, ) -> Result<(), Error> { let mut creator = PackageCreator::new(self.sbom.sbom_id); @@ -714,21 +711,21 @@ impl SbomContext { .chain(cpes.into_iter().map(PackageReference::Cpe)); creator.add(node_id, name, version, refs, iter::empty()); - creator.create(&self.graph.db.connection(&tx)).await?; + creator.create(connection).await?; // done Ok(()) } - #[instrument(skip(self, tx), err)] - pub async fn related_packages_transitively>( + #[instrument(skip(self, connection), err)] + pub async fn related_packages_transitively( &self, relationships: &[Relationship], pkg: &Purl, - tx: TX, + connection: &C, ) -> Result, Error> { - let pkg = self.graph.get_qualified_package(pkg, &tx).await?; + let pkg = self.graph.get_qualified_package(pkg, connection).await?; if let Some(pkg) = pkg { let rels: SimpleExpr = relationships @@ -754,7 +751,7 @@ impl SbomContext { QualifiedPackageTransitive, ) .to_owned(), - &tx, + connection, ) .await?) } else { @@ -762,31 +759,27 @@ impl SbomContext { } } - pub async fn link_to_product<'a, TX: AsRef>( + pub async fn link_to_product<'a, C: ConnectionTrait>( &self, product_version: ProductVersionContext<'a>, - tx: TX, + connection: &C, ) -> Result, Error> { let mut entity = product_version::ActiveModel::from(product_version.product_version); entity.sbom_id = Set(Some(self.sbom.sbom_id)); - let model = entity.update(&self.graph.connection(&tx)).await?; + let model = entity.update(connection).await?; Ok(ProductVersionContext::new(&product_version.product, model)) } - pub async fn get_product>( + pub async fn get_product( &self, - tx: TX, + connection: &C, ) -> Result, Error> { if let Some(vers) = product_version::Entity::find() .filter(product_version::Column::SbomId.eq(self.sbom.sbom_id)) - .one(&self.graph.connection(&tx)) + .one(connection) .await? { - if let Some(prod) = vers - .find_related(product::Entity) - .one(&self.graph.connection(&tx)) - .await? - { + if let Some(prod) = vers.find_related(product::Entity).one(connection).await? { Ok(Some(ProductVersionContext::new( &ProductContext::new(&self.graph, prod), vers, diff --git a/modules/ingestor/src/graph/sbom/spdx.rs b/modules/ingestor/src/graph/sbom/spdx.rs index 522c2c3e8..a4a9e9d10 100644 --- a/modules/ingestor/src/graph/sbom/spdx.rs +++ b/modules/ingestor/src/graph/sbom/spdx.rs @@ -11,12 +11,13 @@ use crate::{ service::Error, }; use sbom_walker::report::{check, ReportSink}; +use sea_orm::ConnectionTrait; use serde_json::Value; use spdx_rs::models::{RelationshipType, SPDX}; use std::{collections::HashMap, str::FromStr}; use time::OffsetDateTime; use tracing::instrument; -use trustify_common::{cpe::Cpe, db::Transactional, purl::Purl}; +use trustify_common::{cpe::Cpe, purl::Purl}; use trustify_entity::relationship::Relationship; pub struct Information<'a>(pub &'a SPDX); @@ -49,12 +50,12 @@ impl<'a> From> for SbomInformation { } impl SbomContext { - #[instrument(skip(tx, sbom_data, warnings), ret)] - pub async fn ingest_spdx>( + #[instrument(skip(db, sbom_data, warnings), ret)] + pub async fn ingest_spdx( &self, sbom_data: SPDX, warnings: &dyn ReportSink, - tx: TX, + db: &C, ) -> Result<(), Error> { // pre-flight checks @@ -191,12 +192,12 @@ impl SbomContext { vendor: package.package_supplier.clone(), cpe: product_cpe, }, - &tx, + db, ) .await?; if let Some(ver) = package.package_version.clone() { - pr.ingest_product_version(ver, Some(self.sbom.sbom_id), &tx) + pr.ingest_product_version(ver, Some(self.sbom.sbom_id), db) .await?; } } @@ -211,15 +212,11 @@ impl SbomContext { files.add(file.file_spdx_identifier, file.file_name); } - // get database connection - - let db = self.graph.connection(&tx); - // create all purls and CPEs - licenses.create(&db).await?; - purls.create(&db).await?; - cpes.create(&db).await?; + licenses.create(db).await?; + purls.create(db).await?; + cpes.create(db).await?; // validate relationships before inserting @@ -235,9 +232,9 @@ impl SbomContext { // create packages, files, and relationships - packages.create(&db).await?; - files.create(&db).await?; - relationships.create(&db).await?; + packages.create(db).await?; + files.create(db).await?; + relationships.create(db).await?; // done diff --git a/modules/ingestor/src/graph/vulnerability/mod.rs b/modules/ingestor/src/graph/vulnerability/mod.rs index f08a2cfd7..bf7cab020 100644 --- a/modules/ingestor/src/graph/vulnerability/mod.rs +++ b/modules/ingestor/src/graph/vulnerability/mod.rs @@ -3,13 +3,14 @@ use crate::common::{Deprecation, DeprecationExt}; use crate::{graph::advisory::AdvisoryContext, graph::error::Error, graph::Graph}; use sea_orm::{ - ActiveValue::Set, ColumnTrait, EntityTrait, ModelTrait, QueryFilter, QuerySelect, RelationTrait, + ActiveValue::Set, ColumnTrait, ConnectionTrait, EntityTrait, ModelTrait, QueryFilter, + QuerySelect, RelationTrait, }; use sea_query::{JoinType, OnConflict}; use std::fmt::{Debug, Formatter}; use time::OffsetDateTime; use tracing::instrument; -use trustify_common::db::{chunk::EntityChunkedIter, Transactional}; +use trustify_common::db::chunk::EntityChunkedIter; use trustify_entity::{advisory, advisory_vulnerability, vulnerability, vulnerability_description}; use uuid::Uuid; @@ -48,15 +49,14 @@ impl From<()> for VulnerabilityInformation { } impl Graph { - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn ingest_vulnerability( + #[instrument(skip(self, db), err(level=tracing::Level::INFO))] + pub async fn ingest_vulnerability( &self, identifier: &str, information: impl Into + Debug, - tx: impl AsRef, + db: &C, ) -> Result { let information = information.into(); - let db = self.connection(&tx); let mut on_conflict = OnConflict::columns([vulnerability::Column::Id]); let on_conflict = match information.has_data() { @@ -88,31 +88,31 @@ impl Graph { let result = vulnerability::Entity::insert(entity) .on_conflict(on_conflict) - .exec_with_returning(&db) + .exec_with_returning(db) .await?; Ok(VulnerabilityContext::new(self, result)) } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn get_vulnerability( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn get_vulnerability( &self, identifier: &str, - tx: impl AsRef, + connection: &C, ) -> Result, Error> { Ok(vulnerability::Entity::find_by_id(identifier) - .one(&self.connection(&tx)) + .one(connection) .await? .map(|vuln| VulnerabilityContext::new(self, vuln))) } - #[instrument(skip(self, tx), err(level=tracing::Level::INFO))] - pub async fn get_vulnerabilities>( + #[instrument(skip(self, connection), err(level=tracing::Level::INFO))] + pub async fn get_vulnerabilities( &self, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(vulnerability::Entity::find() - .all(&self.connection(&tx)) + .all(connection) .await? .into_iter() .map(|vulnerability| VulnerabilityContext::new(self, vulnerability)) @@ -140,10 +140,10 @@ impl VulnerabilityContext { } } - pub async fn advisories>( + pub async fn advisories( &self, deprecation: Deprecation, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(advisory::Entity::find() .with_deprecation(deprecation) @@ -154,29 +154,29 @@ impl VulnerabilityContext { .filter( advisory_vulnerability::Column::VulnerabilityId.eq(self.vulnerability.id.clone()), ) - .all(&self.graph.connection(&tx)) + .all(connection) .await? .drain(0..) .map(|advisory| AdvisoryContext::new(&self.graph, advisory)) .collect()) } - pub async fn add_description>( + pub async fn add_description( &self, advisory: Uuid, lang: &str, description: &str, - tx: TX, + connection: &C, ) -> Result<(), Error> { - self.add_descriptions(advisory, [(lang, description)], tx) + self.add_descriptions(advisory, [(lang, description)], connection) .await } - pub async fn add_descriptions>( + pub async fn add_descriptions( &self, advisory: Uuid, descriptions: impl IntoIterator, - tx: TX, + connection: &C, ) -> Result<(), Error> { let entries = descriptions.into_iter().map(|(lang, description)| { vulnerability_description::ActiveModel { @@ -188,43 +188,39 @@ impl VulnerabilityContext { } }); - let db = self.graph.connection(&tx); - for batch in &entries.chunked() { vulnerability_description::Entity::insert_many(batch) - .exec(&db) + .exec(connection) .await?; } Ok(()) } - pub async fn drop_descriptions_for_advisory>( + pub async fn drop_descriptions_for_advisory( &self, advisory: Uuid, - tx: TX, + connection: &C, ) -> Result<(), Error> { - let db = self.graph.connection(&tx); - vulnerability_description::Entity::delete_many() .filter(vulnerability_description::Column::AdvisoryId.eq(advisory)) - .exec(&db) + .exec(connection) .await?; Ok(()) } // Get all descriptions for a vulnerability - pub async fn descriptions>( + pub async fn descriptions( &self, lang: &str, - tx: TX, + connection: &C, ) -> Result, Error> { Ok(self .vulnerability .find_related(vulnerability_description::Entity) .filter(vulnerability_description::Column::Lang.eq(lang)) - .all(&self.graph.connection(&tx)) + .all(connection) .await? .drain(..) .map(|e| e.description) @@ -238,32 +234,22 @@ mod tests { use crate::graph::Graph; use test_context::test_context; use test_log::test; - use trustify_common::db::Transactional; use trustify_common::hashing::Digests; use trustify_test_context::TrustifyContext; #[test_context(TrustifyContext, skip_teardown)] #[test(tokio::test)] async fn ingest_cves(ctx: TrustifyContext) -> Result<(), anyhow::Error> { - let db = ctx.db; - let system = Graph::new(db); + let system = Graph::new(ctx.db.clone()); - let cve1 = system - .ingest_vulnerability("CVE-123", (), Transactional::None) - .await?; - let cve2 = system - .ingest_vulnerability("CVE-123", (), Transactional::None) - .await?; - let cve3 = system - .ingest_vulnerability("CVE-456", (), Transactional::None) - .await?; + let cve1 = system.ingest_vulnerability("CVE-123", (), &ctx.db).await?; + let cve2 = system.ingest_vulnerability("CVE-123", (), &ctx.db).await?; + let cve3 = system.ingest_vulnerability("CVE-456", (), &ctx.db).await?; assert_eq!(cve1.vulnerability.id, cve2.vulnerability.id); assert_ne!(cve1.vulnerability.id, cve3.vulnerability.id); - let not_found = system - .get_vulnerability("CVE-NOT_FOUND", Transactional::None) - .await?; + let not_found = system.get_vulnerability("CVE-NOT_FOUND", &ctx.db).await?; assert!(not_found.is_none()); @@ -280,7 +266,7 @@ mod tests { ("source", "http://ghsa.io/GHSA-1"), &Digests::digest("GHSA-1"), (), - Transactional::None, + &ctx.db, ) .await?; @@ -291,7 +277,7 @@ mod tests { ("source", "http://rhsa.io/RHSA-1"), &Digests::digest("RHSA-1"), (), - Transactional::None, + &ctx.db, ) .await?; @@ -302,34 +288,29 @@ mod tests { ("source", "http://snyk.io/SNYK-1"), &Digests::digest("SNYK-1"), (), - Transactional::None, + &ctx.db, ) .await?; advisory1 - .link_to_vulnerability("CVE-8675309", None, Transactional::None) + .link_to_vulnerability("CVE-8675309", None, &ctx.db) .await?; advisory2 - .link_to_vulnerability("CVE-8675309", None, Transactional::None) + .link_to_vulnerability("CVE-8675309", None, &ctx.db) .await?; ctx.graph - .ingest_vulnerability("CVE-8675309", (), Transactional::None) + .ingest_vulnerability("CVE-8675309", (), &ctx.db) .await?; - let cve = ctx - .graph - .get_vulnerability("CVE-8675309", Transactional::None) - .await?; + let cve = ctx.graph.get_vulnerability("CVE-8675309", &ctx.db).await?; assert!(cve.is_some(), "there should be some CVE"); let cve = cve.unwrap(); - let linked_advisories = cve - .advisories(Default::default(), Transactional::None) - .await?; + let linked_advisories = cve.advisories(Default::default(), &ctx.db).await?; assert_eq!(2, linked_advisories.len()); diff --git a/modules/ingestor/src/service/advisory/csaf/creator.rs b/modules/ingestor/src/service/advisory/csaf/creator.rs index b927d264c..5221d059f 100644 --- a/modules/ingestor/src/service/advisory/csaf/creator.rs +++ b/modules/ingestor/src/service/advisory/csaf/creator.rs @@ -16,13 +16,10 @@ use sea_orm::{ActiveValue::Set, ColumnTrait, ConnectionTrait, EntityTrait, Query use sea_query::IntoCondition; use std::collections::{hash_map::Entry, HashMap, HashSet}; use tracing::instrument; -use trustify_common::{ - cpe::Cpe, - db::{chunk::EntityChunkedIter, Transactional}, - purl::Purl, +use trustify_common::{cpe::Cpe, db::chunk::EntityChunkedIter, purl::Purl}; +use trustify_entity::{ + product_status, purl_status, status, version_range, version_scheme::VersionScheme, }; -use trustify_entity::version_scheme::VersionScheme; -use trustify_entity::{product_status, purl_status, status, version_range}; use uuid::Uuid; #[derive(Debug, Eq, Hash, PartialEq)] @@ -101,13 +98,11 @@ impl<'a> StatusCreator<'a> { } #[instrument(skip_all, ret)] - pub async fn create>( + pub async fn create( &mut self, graph: &Graph, - tx: TX, + connection: &C, ) -> Result<(), Error> { - let connection = &graph.connection(&tx); - let mut checked = HashMap::new(); let mut product_statuses = Vec::new(); let mut purls = PurlCreator::new(); @@ -133,14 +128,14 @@ impl<'a> StatusCreator<'a> { vendor: product.vendor.clone(), cpe: product.cpe.clone(), }, - &tx, + connection, ) .await?; // Ingest product range let product_version_range = match product.version { Some(ref ver) => Some( - pr.ingest_product_version_range(ver.clone(), None, &tx) + pr.ingest_product_version_range(ver.clone(), None, connection) .await?, ), None => None, diff --git a/modules/ingestor/src/service/advisory/csaf/loader.rs b/modules/ingestor/src/service/advisory/csaf/loader.rs index 0400cc08e..0bfa837cd 100644 --- a/modules/ingestor/src/service/advisory/csaf/loader.rs +++ b/modules/ingestor/src/service/advisory/csaf/loader.rs @@ -15,12 +15,13 @@ use csaf::{ Csaf, }; use sbom_walker::report::ReportSink; +use sea_orm::{ConnectionTrait, TransactionTrait}; use semver::Version; use std::fmt::Debug; use std::str::FromStr; use time::OffsetDateTime; use tracing::instrument; -use trustify_common::{db::Transactional, hashing::Digests, id::Id}; +use trustify_common::{hashing::Digests, id::Id}; use trustify_cvss::cvss3::Cvss3Base; use trustify_entity::labels::Labels; @@ -88,7 +89,7 @@ impl<'g> CsafLoader<'g> { ) -> Result { let warnings = Warnings::new(); - let tx = self.graph.transaction().await?; + let tx = self.graph.db.begin().await?; let advisory_id = gen_identifier(&csaf); let labels = labels.into().add("type", "csaf"); @@ -118,19 +119,21 @@ impl<'g> CsafLoader<'g> { cve=vulnerability.cve ) )] - async fn ingest_vulnerability>( + async fn ingest_vulnerability( &self, csaf: &Csaf, advisory: &AdvisoryContext<'_>, vulnerability: &Vulnerability, report: &dyn ReportSink, - tx: TX, + connection: &C, ) -> Result<(), Error> { let Some(cve_id) = &vulnerability.cve else { return Ok(()); }; - self.graph.ingest_vulnerability(cve_id, (), &tx).await?; + self.graph + .ingest_vulnerability(cve_id, (), connection) + .await?; let advisory_vulnerability = advisory .link_to_vulnerability( @@ -148,12 +151,12 @@ impl<'g> CsafLoader<'g> { }), cwes: vulnerability.cwe.as_ref().map(|cwe| vec![cwe.id.clone()]), }), - &tx, + connection, ) .await?; if let Some(product_status) = &vulnerability.product_status { - self.ingest_product_statuses(csaf, &advisory_vulnerability, product_status, &tx) + self.ingest_product_statuses(csaf, &advisory_vulnerability, product_status, connection) .await?; } @@ -163,7 +166,7 @@ impl<'g> CsafLoader<'g> { Ok(cvss3) => { log::debug!("{cvss3:?}"); advisory_vulnerability - .ingest_cvss3_score(cvss3, &tx) + .ingest_cvss3_score(cvss3, connection) .await?; } Err(err) => { @@ -179,12 +182,12 @@ impl<'g> CsafLoader<'g> { } #[instrument(skip_all, err)] - async fn ingest_product_statuses>( + async fn ingest_product_statuses( &self, csaf: &Csaf, advisory_vulnerability: &AdvisoryVulnerabilityContext<'_>, product_status: &ProductStatus, - tx: TX, + connection: &C, ) -> Result<(), Error> { let mut creator = StatusCreator::new( csaf, @@ -199,7 +202,7 @@ impl<'g> CsafLoader<'g> { creator.add_all(&product_status.known_not_affected, "not_affected"); creator.add_all(&product_status.known_affected, "affected"); - creator.create(self.graph, tx).await?; + creator.create(self.graph, connection).await?; Ok(()) } @@ -213,14 +216,12 @@ mod test { use crate::graph::Graph; use test_context::test_context; use test_log::test; - use trustify_common::db::Transactional; use trustify_test_context::{document, TrustifyContext}; #[test_context(TrustifyContext)] #[test(tokio::test)] async fn loader(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - let db = &ctx.db; - let graph = Graph::new(db.clone()); + let graph = Graph::new(ctx.db.clone()); let (csaf, digests): (Csaf, _) = document("csaf/CVE-2023-20862.json").await?; let loader = CsafLoader::new(&graph); @@ -228,13 +229,11 @@ mod test { .load(("file", "CVE-2023-20862.json"), csaf, &digests) .await?; - let loaded_vulnerability = graph - .get_vulnerability("CVE-2023-20862", Transactional::None) - .await?; + let loaded_vulnerability = graph.get_vulnerability("CVE-2023-20862", &ctx.db).await?; assert!(loaded_vulnerability.is_some()); let loaded_advisory = graph - .get_advisory_by_digest(&digests.sha256.encode_hex::(), Transactional::None) + .get_advisory_by_digest(&digests.sha256.encode_hex::(), &ctx.db) .await?; assert!(loaded_advisory.is_some()); @@ -242,7 +241,7 @@ mod test { assert!(loaded_advisory.advisory.issuer_id.is_some()); - let loaded_advisory_vulnerabilities = loaded_advisory.vulnerabilities(()).await?; + let loaded_advisory_vulnerabilities = loaded_advisory.vulnerabilities(&ctx.db).await?; assert_eq!(1, loaded_advisory_vulnerabilities.len()); // let loaded_advisory_vulnerability = &loaded_advisory_vulnerabilities[0]; @@ -277,12 +276,12 @@ mod test { // )); let advisory_vuln = loaded_advisory - .get_vulnerability("CVE-2023-20862", ()) + .get_vulnerability("CVE-2023-20862", &ctx.db) .await?; assert!(advisory_vuln.is_some()); let advisory_vuln = advisory_vuln.unwrap(); - let scores = advisory_vuln.cvss3_scores(()).await?; + let scores = advisory_vuln.cvss3_scores(&ctx.db).await?; assert_eq!(1, scores.len()); let score = scores[0]; @@ -297,20 +296,17 @@ mod test { #[test_context(TrustifyContext, skip_teardown)] #[test(tokio::test)] async fn multiple_vulnerabilities(ctx: TrustifyContext) -> Result<(), anyhow::Error> { - let db = ctx.db; - let graph = Graph::new(db); + let graph = Graph::new(ctx.db.clone()); let loader = CsafLoader::new(&graph); let (csaf, digests): (Csaf, _) = document("csaf/rhsa-2024_3666.json").await?; loader.load(("source", "test"), csaf, &digests).await?; - let loaded_vulnerability = graph - .get_vulnerability("CVE-2024-23672", Transactional::None) - .await?; + let loaded_vulnerability = graph.get_vulnerability("CVE-2024-23672", &ctx.db).await?; assert!(loaded_vulnerability.is_some()); let loaded_advisory = graph - .get_advisory_by_digest(&digests.sha256.encode_hex::(), Transactional::None) + .get_advisory_by_digest(&digests.sha256.encode_hex::(), &ctx.db) .await?; assert!(loaded_advisory.is_some()); @@ -318,16 +314,16 @@ mod test { assert!(loaded_advisory.advisory.issuer_id.is_some()); - let loaded_advisory_vulnerabilities = loaded_advisory.vulnerabilities(()).await?; + let loaded_advisory_vulnerabilities = loaded_advisory.vulnerabilities(&ctx.db).await?; assert_eq!(2, loaded_advisory_vulnerabilities.len()); let advisory_vuln = loaded_advisory - .get_vulnerability("CVE-2024-23672", ()) + .get_vulnerability("CVE-2024-23672", &ctx.db) .await?; assert!(advisory_vuln.is_some()); let advisory_vuln = advisory_vuln.unwrap(); - let scores = advisory_vuln.cvss3_scores(()).await?; + let scores = advisory_vuln.cvss3_scores(&ctx.db).await?; assert_eq!(1, scores.len()); let score = scores[0]; @@ -341,20 +337,17 @@ mod test { #[test_context(TrustifyContext, skip_teardown)] #[test(tokio::test)] async fn product_status(ctx: TrustifyContext) -> Result<(), anyhow::Error> { - let db = ctx.db; - let graph = Graph::new(db); + let graph = Graph::new(ctx.db.clone()); let loader = CsafLoader::new(&graph); let (csaf, digests): (Csaf, _) = document("csaf/cve-2023-0044.json").await?; loader.load(("source", "test"), csaf, &digests).await?; - let loaded_vulnerability = graph - .get_vulnerability("CVE-2023-0044", Transactional::None) - .await?; + let loaded_vulnerability = graph.get_vulnerability("CVE-2023-0044", &ctx.db).await?; assert!(loaded_vulnerability.is_some()); let loaded_advisory = graph - .get_advisory_by_digest(&digests.sha256.encode_hex::(), Transactional::None) + .get_advisory_by_digest(&digests.sha256.encode_hex::(), &ctx.db) .await?; assert!(loaded_advisory.is_some()); @@ -362,16 +355,16 @@ mod test { assert!(loaded_advisory.advisory.issuer_id.is_some()); - let loaded_advisory_vulnerabilities = loaded_advisory.vulnerabilities(()).await?; + let loaded_advisory_vulnerabilities = loaded_advisory.vulnerabilities(&ctx.db).await?; assert_eq!(1, loaded_advisory_vulnerabilities.len()); let advisory_vuln = loaded_advisory - .get_vulnerability("CVE-2023-0044", ()) + .get_vulnerability("CVE-2023-0044", &ctx.db) .await?; assert!(advisory_vuln.is_some()); let advisory_vuln = advisory_vuln.unwrap(); - let scores = advisory_vuln.cvss3_scores(()).await?; + let scores = advisory_vuln.cvss3_scores(&ctx.db).await?; assert_eq!(1, scores.len()); let score = scores[0]; diff --git a/modules/ingestor/src/service/advisory/cve/loader.rs b/modules/ingestor/src/service/advisory/cve/loader.rs index 98d58dfd1..d80d26edf 100644 --- a/modules/ingestor/src/service/advisory/cve/loader.rs +++ b/modules/ingestor/src/service/advisory/cve/loader.rs @@ -14,6 +14,7 @@ use cve::{ common::{Description, Product, Status, VersionRange}, Cve, Timestamp, }; +use sea_orm::TransactionTrait; use std::fmt::Debug; use time::OffsetDateTime; use tracing::instrument; @@ -49,7 +50,7 @@ impl<'g> CveLoader<'g> { let id = cve.id(); let labels = labels.into().add("type", "cve"); - let tx = self.graph.transaction().await?; + let tx = self.graph.db.begin().await?; let VulnerabilityDetails { org_name, @@ -334,7 +335,6 @@ mod test { use test_context::test_context; use test_log::test; use time::macros::datetime; - use trustify_common::db::Transactional; use trustify_common::purl::Purl; use trustify_test_context::{document, TrustifyContext}; @@ -345,11 +345,11 @@ mod test { let (cve, digests): (Cve, _) = document("mitre/CVE-2024-28111.json").await?; - let loaded_vulnerability = graph.get_vulnerability("CVE-2024-28111", ()).await?; + let loaded_vulnerability = graph.get_vulnerability("CVE-2024-28111", &ctx.db).await?; assert!(loaded_vulnerability.is_none()); let loaded_advisory = graph - .get_advisory_by_digest(&digests.sha256.encode_hex::(), Transactional::None) + .get_advisory_by_digest(&digests.sha256.encode_hex::(), &ctx.db) .await?; assert!(loaded_advisory.is_none()); @@ -358,7 +358,7 @@ mod test { .load(("file", "CVE-2024-28111.json"), cve, &digests) .await?; - let loaded_vulnerability = graph.get_vulnerability("CVE-2024-28111", ()).await?; + let loaded_vulnerability = graph.get_vulnerability("CVE-2024-28111", &ctx.db).await?; assert!(loaded_vulnerability.is_some()); let loaded_vulnerability = loaded_vulnerability.unwrap(); assert_eq!( @@ -367,11 +367,11 @@ mod test { ); let loaded_advisory = graph - .get_advisory_by_digest(&digests.sha256.encode_hex::(), Transactional::None) + .get_advisory_by_digest(&digests.sha256.encode_hex::(), &ctx.db) .await?; assert!(loaded_advisory.is_some()); - let descriptions = loaded_vulnerability.descriptions("en", ()).await?; + let descriptions = loaded_vulnerability.descriptions("en", &ctx.db).await?; assert_eq!(1, descriptions.len()); assert!(descriptions[0] .starts_with("Canarytokens helps track activity and actions on a network")); @@ -394,7 +394,7 @@ mod test { let purl = graph .get_package( &Purl::from_str("pkg:maven/org.apache.commons/commons-compress")?, - Transactional::None, + &ctx.db, ) .await?; diff --git a/modules/ingestor/src/service/advisory/osv/loader.rs b/modules/ingestor/src/service/advisory/osv/loader.rs index e924c66ac..520997096 100644 --- a/modules/ingestor/src/service/advisory/osv/loader.rs +++ b/modules/ingestor/src/service/advisory/osv/loader.rs @@ -17,9 +17,10 @@ use crate::{ }; use osv::schema::{Event, Range, RangeType, ReferenceType, SeverityType, Vulnerability}; use sbom_walker::report::ReportSink; +use sea_orm::{ConnectionTrait, TransactionTrait}; use std::{fmt::Debug, str::FromStr}; use tracing::instrument; -use trustify_common::{db::Transactional, hashing::Digests, id::Id, purl::Purl, time::ChronoExt}; +use trustify_common::{hashing::Digests, id::Id, purl::Purl, time::ChronoExt}; use trustify_cvss::cvss3::Cvss3Base; use trustify_entity::{labels::Labels, version_scheme::VersionScheme}; @@ -46,7 +47,7 @@ impl<'g> OsvLoader<'g> { let issuer = issuer.or(detect_organization(&osv)); - let tx = self.graph.transaction().await?; + let tx = self.graph.db.begin().await?; let cve_ids = osv.aliases.iter().flat_map(|aliases| { aliases @@ -162,7 +163,7 @@ impl<'g> OsvLoader<'g> { } } - purl_creator.create(&self.graph.connection(&tx)).await?; + purl_creator.create(&tx).await?; tx.commit().await?; @@ -175,12 +176,12 @@ impl<'g> OsvLoader<'g> { } /// create package statues based on listed versions -async fn create_package_status_versions( +async fn create_package_status_versions( advisory_vuln: &AdvisoryVulnerabilityContext<'_>, purl: &Purl, range: &Range, versions: impl IntoIterator, - tx: impl AsRef, + connection: &C, ) -> Result<(), Error> { // the list of versions, sorted by the range type let versions = versions.into_iter().cloned().collect::>(); @@ -200,12 +201,12 @@ async fn create_package_status_versions( start, Some(version), &versions, - &tx, + connection, ) .await?; } - ingest_exact(advisory_vuln, purl, "fixed", version, &tx).await?; + ingest_exact(advisory_vuln, purl, "fixed", version, connection).await?; } Event::Limit(_) => {} // for non_exhaustive @@ -214,14 +215,23 @@ async fn create_package_status_versions( } if let Some(start) = start { - ingest_range_from(advisory_vuln, purl, "affected", start, None, &versions, &tx).await?; + ingest_range_from( + advisory_vuln, + purl, + "affected", + start, + None, + &versions, + connection, + ) + .await?; } Ok(()) } /// Ingest all from a start to an end -async fn ingest_range_from( +async fn ingest_range_from( advisory_vuln: &AdvisoryVulnerabilityContext<'_>, purl: &Purl, status: &str, @@ -229,12 +239,12 @@ async fn ingest_range_from( // exclusive end end: Option<&str>, versions: &[impl AsRef], - tx: impl AsRef, + connection: &C, ) -> Result<(), Error> { let versions = match_versions(versions, start, end); for version in versions { - ingest_exact(advisory_vuln, purl, status, version, &tx).await?; + ingest_exact(advisory_vuln, purl, status, version, connection).await?; } Ok(()) @@ -275,12 +285,12 @@ fn match_versions<'v>( } /// Ingest an exact version -async fn ingest_exact( +async fn ingest_exact( advisory_vuln: &AdvisoryVulnerabilityContext<'_>, purl: &Purl, status: &str, version: &str, - tx: impl AsRef, + connection: &C, ) -> Result<(), Error> { Ok(advisory_vuln .ingest_package_status( @@ -291,17 +301,17 @@ async fn ingest_exact( scheme: VersionScheme::Generic, spec: VersionSpec::Exact(version.to_string()), }, - &tx, + connection, ) .await?) } /// create a package status from a semver range -async fn create_package_status_semver( +async fn create_package_status_semver( advisory_vuln: &AdvisoryVulnerabilityContext<'_>, purl: &Purl, range: &Range, - tx: impl AsRef, + connection: &C, ) -> Result<(), Error> { let parsed_range = events_to_range(&range.events); @@ -331,7 +341,7 @@ async fn create_package_status_semver( scheme: VersionScheme::Semver, spec, }, - &tx, + connection, ) .await?; } @@ -346,7 +356,7 @@ async fn create_package_status_semver( scheme: VersionScheme::Semver, spec: VersionSpec::Exact(fixed.clone()), }, - &tx, + connection, ) .await? } @@ -390,20 +400,16 @@ fn events_to_range(events: &[Event]) -> (Option, Option) { #[cfg(test)] mod test { + use super::*; + use crate::graph::Graph; + use crate::service::advisory::osv::loader::OsvLoader; use hex::ToHex; use osv::schema::Vulnerability; use rstest::rstest; use test_context::test_context; use test_log::test; - - use crate::graph::Graph; - use trustify_common::db::Transactional; use trustify_test_context::{document, TrustifyContext}; - use crate::service::advisory::osv::loader::OsvLoader; - - use super::*; - #[test_context(TrustifyContext)] #[test(tokio::test)] async fn loader(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { @@ -412,13 +418,11 @@ mod test { let (osv, digests): (Vulnerability, _) = document("osv/RUSTSEC-2021-0079.json").await?; - let loaded_vulnerability = graph - .get_vulnerability("CVE-2021-32714", Transactional::None) - .await?; + let loaded_vulnerability = graph.get_vulnerability("CVE-2021-32714", &ctx.db).await?; assert!(loaded_vulnerability.is_none()); let loaded_advisory = graph - .get_advisory_by_digest(&digests.sha256.encode_hex::(), Transactional::None) + .get_advisory_by_digest(&digests.sha256.encode_hex::(), &ctx.db) .await?; assert!(loaded_advisory.is_none()); @@ -427,13 +431,11 @@ mod test { .load(("file", "RUSTSEC-2021-0079.json"), osv, &digests, None) .await?; - let loaded_vulnerability = graph - .get_vulnerability("CVE-2021-32714", Transactional::None) - .await?; + let loaded_vulnerability = graph.get_vulnerability("CVE-2021-32714", &ctx.db).await?; assert!(loaded_vulnerability.is_some()); let loaded_advisory = graph - .get_advisory_by_digest(&digests.sha256.encode_hex::(), Transactional::None) + .get_advisory_by_digest(&digests.sha256.encode_hex::(), &ctx.db) .await?; assert!(loaded_advisory.is_some()); @@ -441,17 +443,17 @@ mod test { assert!(loaded_advisory.advisory.issuer_id.is_some()); - let loaded_advisory_vulnerabilities = loaded_advisory.vulnerabilities(()).await?; + let loaded_advisory_vulnerabilities = loaded_advisory.vulnerabilities(&ctx.db).await?; assert_eq!(1, loaded_advisory_vulnerabilities.len()); let _loaded_advisory_vulnerability = &loaded_advisory_vulnerabilities[0]; let advisory_vuln = loaded_advisory - .get_vulnerability("CVE-2021-32714", ()) + .get_vulnerability("CVE-2021-32714", &ctx.db) .await?; assert!(advisory_vuln.is_some()); let advisory_vuln = advisory_vuln.unwrap(); - let scores = advisory_vuln.cvss3_scores(()).await?; + let scores = advisory_vuln.cvss3_scores(&ctx.db).await?; assert_eq!(1, scores.len()); let score = scores[0]; @@ -461,7 +463,7 @@ mod test { ); assert!(loaded_advisory - .get_vulnerability("CVE-8675309", ()) + .get_vulnerability("CVE-8675309", &ctx.db) .await? .is_none()); diff --git a/modules/ingestor/src/service/mod.rs b/modules/ingestor/src/service/mod.rs index 8170007bc..a749a0166 100644 --- a/modules/ingestor/src/service/mod.rs +++ b/modules/ingestor/src/service/mod.rs @@ -198,21 +198,24 @@ impl IngestorService { match fmt { Format::SPDX | Format::CycloneDX => { - let analysis_service = AnalysisService::new(self.graph.db.clone()); + let analysis_service = AnalysisService::new(); if result.id.to_string().starts_with("urn:uuid:") { match analysis_service // TODO: today we chop off 'urn:uuid:' prefix using .split_off on result.id - .load_graphs(vec![result.id.to_string().split_off("urn:uuid:".len())], ()) + .load_graphs( + vec![result.id.to_string().split_off("urn:uuid:".len())], + &self.graph.db, + ) .await { Ok(_) => log::debug!( - "Analysis graph for sbom: {} loaded successfully.", - result.id.value() - ), + "Analysis graph for sbom: {} loaded successfully.", + result.id.value() + ), Err(e) => log::warn!( - "Error loading sbom {} into analysis graph : {}", - result.id.value(), - e - ), + "Error loading sbom {} into analysis graph : {}", + result.id.value(), + e + ), } } } diff --git a/modules/ingestor/src/service/sbom/clearly_defined.rs b/modules/ingestor/src/service/sbom/clearly_defined.rs index 17025f46b..a6e09bbaf 100644 --- a/modules/ingestor/src/service/sbom/clearly_defined.rs +++ b/modules/ingestor/src/service/sbom/clearly_defined.rs @@ -2,7 +2,7 @@ use crate::{graph::sbom::SbomInformation, graph::Graph, model::IngestResult, ser use anyhow::anyhow; use hex::ToHex; use jsonpath_rust::JsonPath; -use sea_orm::EntityTrait; +use sea_orm::{EntityTrait, TransactionTrait}; use std::str::FromStr; use trustify_common::{ hashing::Digests, @@ -65,7 +65,7 @@ impl<'g> ClearlyDefinedLoader<'g> { }); if let Some(document_id) = document_id { - let tx = self.graph.transaction().await?; + let tx = self.graph.db.begin().await?; let sbom = self .graph diff --git a/modules/ingestor/src/service/sbom/clearly_defined_curation.rs b/modules/ingestor/src/service/sbom/clearly_defined_curation.rs index b709c80e1..b0d259aeb 100644 --- a/modules/ingestor/src/service/sbom/clearly_defined_curation.rs +++ b/modules/ingestor/src/service/sbom/clearly_defined_curation.rs @@ -1,6 +1,7 @@ use crate::{ graph::sbom::clearly_defined::Curation, graph::Graph, model::IngestResult, service::Error, }; +use sea_orm::TransactionTrait; use tracing::instrument; use trustify_common::{hashing::Digests, id::Id}; use trustify_entity::labels::Labels; @@ -21,7 +22,7 @@ impl<'g> ClearlyDefinedCurationLoader<'g> { curation: Curation, digests: &Digests, ) -> Result { - let tx = self.graph.transaction().await?; + let tx = self.graph.db.begin().await?; let sbom = self .graph diff --git a/modules/ingestor/src/service/sbom/cyclonedx.rs b/modules/ingestor/src/service/sbom/cyclonedx.rs index 230102a17..f29c937a7 100644 --- a/modules/ingestor/src/service/sbom/cyclonedx.rs +++ b/modules/ingestor/src/service/sbom/cyclonedx.rs @@ -4,6 +4,7 @@ use crate::{ service::Error, }; use cyclonedx_bom::prelude::Bom; +use sea_orm::TransactionTrait; use tracing::instrument; use trustify_common::{hashing::Digests, id::Id}; use trustify_entity::labels::Labels; @@ -32,7 +33,7 @@ impl<'g> CyclonedxLoader<'g> { sbom.serial_number, ); - let tx = self.graph.transaction().await?; + let tx = self.graph.db.begin().await?; let document_id = sbom .serial_number diff --git a/modules/ingestor/src/service/sbom/spdx.rs b/modules/ingestor/src/service/sbom/spdx.rs index 7e52f9532..42b295e40 100644 --- a/modules/ingestor/src/service/sbom/spdx.rs +++ b/modules/ingestor/src/service/sbom/spdx.rs @@ -6,6 +6,7 @@ use crate::{ model::IngestResult, service::{Error, Warnings}, }; +use sea_orm::TransactionTrait; use serde_json::Value; use tracing::instrument; use trustify_common::{hashing::Digests, id::Id}; @@ -36,7 +37,7 @@ impl<'g> SpdxLoader<'g> { spdx.document_creation_information.document_name ); - let tx = self.graph.transaction().await?; + let tx = self.graph.db.begin().await?; let labels = labels.add("type", "spdx"); diff --git a/modules/ingestor/src/service/weakness/mod.rs b/modules/ingestor/src/service/weakness/mod.rs index aa1e70f3a..beceb66a4 100644 --- a/modules/ingestor/src/service/weakness/mod.rs +++ b/modules/ingestor/src/service/weakness/mod.rs @@ -134,6 +134,7 @@ impl<'d> CweCatalogLoader<'d> { .exec(&tx) .await?; } + tx.commit().await?; } } diff --git a/modules/ingestor/tests/db.rs b/modules/ingestor/tests/db.rs index 6b0a0af4f..b1a264418 100644 --- a/modules/ingestor/tests/db.rs +++ b/modules/ingestor/tests/db.rs @@ -140,7 +140,9 @@ async fn create_set(ctx: &TrustifyContext) -> anyhow::Result<()> { withdrawn: None, version: None, }; - graph.ingest_advisory(d, (), &digests, info, ()).await?; + graph + .ingest_advisory(d, (), &digests, info, &ctx.db) + .await?; } } diff --git a/modules/ingestor/tests/reingest/csaf.rs b/modules/ingestor/tests/reingest/csaf.rs index 4e2e37337..66687a1b0 100644 --- a/modules/ingestor/tests/reingest/csaf.rs +++ b/modules/ingestor/tests/reingest/csaf.rs @@ -36,31 +36,31 @@ async fn reingest(ctx: TrustifyContext) -> anyhow::Result<()> { }; let adv = ctx .graph - .get_advisory_by_id(id, ()) + .get_advisory_by_id(id, &ctx.db) .await? .expect("must be found"); - assert_eq!(adv.vulnerabilities(()).await?.len(), 1); + assert_eq!(adv.vulnerabilities(&ctx.db).await?.len(), 1); - let all = adv.vulnerabilities(&()).await?; + let all = adv.vulnerabilities(&ctx.db).await?; assert_eq!(all.len(), 1); assert_eq!( all[0].advisory_vulnerability.vulnerability_id, "CVE-2023-33201" ); - let all = ctx.graph.get_vulnerabilities(()).await?; + let all = ctx.graph.get_vulnerabilities(&ctx.db).await?; assert_eq!(all.len(), 1); let vuln = ctx .graph - .get_vulnerability("CVE-2023-33201", ()) + .get_vulnerability("CVE-2023-33201", &ctx.db) .await? .expect("Must be found"); assert_eq!(vuln.vulnerability.id, "CVE-2023-33201"); - let descriptions = vuln.descriptions("en", ()).await?; + let descriptions = vuln.descriptions("en", &ctx.db).await?; assert_eq!(descriptions.len(), 0); Ok(()) diff --git a/modules/ingestor/tests/reingest/cve.rs b/modules/ingestor/tests/reingest/cve.rs index 9c99854e2..abda56fa3 100644 --- a/modules/ingestor/tests/reingest/cve.rs +++ b/modules/ingestor/tests/reingest/cve.rs @@ -18,11 +18,11 @@ async fn reingest(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { }; let adv = ctx .graph - .get_advisory_by_id(id, ()) + .get_advisory_by_id(id, &ctx.db) .await? .expect("must be found"); - let mut adv_vulns = adv.vulnerabilities(()).await?; + let mut adv_vulns = adv.vulnerabilities(&ctx.db).await?; assert_eq!(adv_vulns.len(), 1); let adv_vuln = adv_vulns.pop().unwrap(); assert_eq!( @@ -30,16 +30,16 @@ async fn reingest(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { Some(datetime!(2021-05-12 0:00:00 UTC)) ); - let vulns = ctx.graph.get_vulnerabilities(()).await?; + let vulns = ctx.graph.get_vulnerabilities(&ctx.db).await?; assert_eq!(vulns.len(), 1); let vuln = ctx .graph - .get_vulnerability("CVE-2021-32714", ()) + .get_vulnerability("CVE-2021-32714", &ctx.db) .await? .expect("Must be found"); - let descriptions = vuln.descriptions("en", ()).await?; + let descriptions = vuln.descriptions("en", &ctx.db).await?; assert_eq!(descriptions.len(), 1); Ok(()) diff --git a/modules/ingestor/tests/reingest/osv.rs b/modules/ingestor/tests/reingest/osv.rs index c7f5c0244..7c1d2b4ac 100644 --- a/modules/ingestor/tests/reingest/osv.rs +++ b/modules/ingestor/tests/reingest/osv.rs @@ -104,31 +104,31 @@ async fn assert_common( }; let adv = ctx .graph - .get_advisory_by_id(id, ()) + .get_advisory_by_id(id, &ctx.db) .await? .expect("must be found"); - assert_eq!(adv.vulnerabilities(()).await?.len(), 1); + assert_eq!(adv.vulnerabilities(&ctx.db).await?.len(), 1); - let all = adv.vulnerabilities(&()).await?; + let all = adv.vulnerabilities(&ctx.db).await?; assert_eq!(all.len(), 1); assert_eq!( all[0].advisory_vulnerability.vulnerability_id, expected_vuln_id ); - let all = ctx.graph.get_vulnerabilities(()).await?; + let all = ctx.graph.get_vulnerabilities(&ctx.db).await?; assert_eq!(all.len(), 1); let vuln = ctx .graph - .get_vulnerability(expected_vuln_id, ()) + .get_vulnerability(expected_vuln_id, &ctx.db) .await? .expect("Must be found"); assert_eq!(vuln.vulnerability.id, expected_vuln_id); - let descriptions = vuln.descriptions("en", ()).await?; + let descriptions = vuln.descriptions("en", &ctx.db).await?; assert_eq!(descriptions.len(), 0); Ok(()) diff --git a/modules/ingestor/tests/reingest/spdx.rs b/modules/ingestor/tests/reingest/spdx.rs index 7577192f4..eda88774f 100644 --- a/modules/ingestor/tests/reingest/spdx.rs +++ b/modules/ingestor/tests/reingest/spdx.rs @@ -18,13 +18,13 @@ async fn reingest(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { }; let sbom = ctx .graph - .get_sbom_by_id(id, ()) + .get_sbom_by_id(id, &ctx.db) .await? .expect("must be found"); // check CPEs - let cpes = sbom.describes_cpe22s(()).await?; + let cpes = sbom.describes_cpe22s(&ctx.db).await?; assert_eq!( cpes.into_iter() .map(|cpe| CpeDto::from(cpe.cpe)) @@ -35,7 +35,7 @@ async fn reingest(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { // check purls - let purls = sbom.describes_purls(()).await?; + let purls = sbom.describes_purls(&ctx.db).await?; assert_eq!( purls .into_iter() @@ -46,22 +46,25 @@ async fn reingest(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { // get product - let product = sbom.get_product(()).await?.expect("must have a product"); + let product = sbom + .get_product(&ctx.db) + .await? + .expect("must have a product"); assert_eq!(product.product.product.name, "quarkus-bom"); - let products = ctx.graph.get_products(()).await?; + let products = ctx.graph.get_products(&ctx.db).await?; assert_eq!(products.len(), 1); // get orgs, expect one - let orgs = ctx.graph.get_organizations(()).await?; + let orgs = ctx.graph.get_organizations(&ctx.db).await?; assert_eq!(orgs.len(), 1); // get all sboms, expect one let sboms = ctx .graph - .locate_many_sboms(sbom::Entity::find(), ()) + .locate_many_sboms(sbom::Entity::find(), &ctx.db) .await?; assert_eq!(sboms.len(), 1);