From ce9e552844482e4230f995c1ed7ce3fdefe3458b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Ml=C3=A1dek?= Date: Thu, 23 Dec 2021 11:10:16 +0100 Subject: [PATCH] rewrite database module as a struct instead of bare fns --- src/database/hierarchies.rs | 66 +++-- src/database/macros.rs | 6 +- src/database/mod.rs | 483 ++++++++++++++++++------------------ src/filesystem.rs | 116 ++++----- src/main.rs | 11 +- src/routes.rs | 84 ++++--- 6 files changed, 369 insertions(+), 397 deletions(-) diff --git a/src/database/hierarchies.rs b/src/database/hierarchies.rs index 9c82390..6e25e2c 100644 --- a/src/database/hierarchies.rs +++ b/src/database/hierarchies.rs @@ -2,8 +2,6 @@ use std::convert::TryFrom; use std::sync::{Arc, Mutex}; use anyhow::{anyhow, Result}; -use diesel::sqlite::Sqlite; -use diesel::Connection; use log::trace; use lru::LruCache; use serde_json::Value; @@ -15,7 +13,8 @@ use crate::database::constants::{ }; use crate::database::entry::{Entry, EntryValue}; use crate::database::lang::{EntryQuery, Query, QueryComponent, QueryPart}; -use crate::database::{insert_entry, query, DbPool}; + +use super::UpEndConnection; #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct UNode(String); @@ -95,9 +94,8 @@ impl PointerEntries for Vec { } } -pub fn list_roots>(connection: &C) -> Result> { - let all_directories: Vec = query( - connection, +pub fn list_roots(connection: &UpEndConnection) -> Result> { + let all_directories: Vec = connection.query( Query::SingleQuery(QueryPart::Matches(EntryQuery { entity: QueryComponent::Any, attribute: QueryComponent::Exact(IS_OF_TYPE_ATTR.to_string()), @@ -106,8 +104,7 @@ pub fn list_roots>(connection: &C) -> Result = query( - connection, + let directories_with_parents: Vec
= connection.query( Query::SingleQuery(QueryPart::Matches(EntryQuery { entity: QueryComponent::Any, attribute: QueryComponent::Exact(HIER_HAS_ATTR.to_string()), @@ -130,8 +127,8 @@ lazy_static! { static ref FETCH_CREATE_LOCK: Mutex<()> = Mutex::new(()); } -pub fn fetch_or_create_dir>( - connection: &C, +pub fn fetch_or_create_dir( + connection: &UpEndConnection, parent: Option
, directory: UNode, create: bool, @@ -146,8 +143,7 @@ pub fn fetch_or_create_dir>( _lock = FETCH_CREATE_LOCK.lock().unwrap(); } - let matching_directories = query( - connection, + let matching_directories = connection.query( Query::SingleQuery(QueryPart::Matches(EntryQuery { entity: QueryComponent::Any, attribute: QueryComponent::Exact(String::from(LABEL_ATTR)), @@ -160,8 +156,7 @@ pub fn fetch_or_create_dir>( .map(|e: Entry| e.entity); let parent_has: Vec
= match parent.clone() { - Some(parent) => query( - connection, + Some(parent) => connection.query( Query::SingleQuery(QueryPart::Matches(EntryQuery { entity: QueryComponent::Exact(parent), attribute: QueryComponent::Exact(String::from(HIER_HAS_ATTR)), @@ -188,14 +183,14 @@ pub fn fetch_or_create_dir>( attribute: String::from(IS_OF_TYPE_ATTR), value: EntryValue::Address(HIER_ADDR.clone()), }; - insert_entry(connection, type_entry)?; + connection.insert_entry(type_entry)?; let directory_entry = Entry { entity: new_directory_address.clone(), attribute: String::from(LABEL_ATTR), value: EntryValue::Value(Value::String(directory.as_ref().clone())), }; - insert_entry(connection, directory_entry)?; + connection.insert_entry(directory_entry)?; if let Some(parent) = parent { let has_entry = Entry { @@ -203,7 +198,7 @@ pub fn fetch_or_create_dir>( attribute: String::from(HIER_HAS_ATTR), value: EntryValue::Address(new_directory_address.clone()), }; - insert_entry(connection, has_entry)?; + connection.insert_entry(has_entry)?; } Ok(new_directory_address) @@ -219,8 +214,8 @@ pub fn fetch_or_create_dir>( } } -pub fn resolve_path>( - connection: &C, +pub fn resolve_path( + connection: &UpEndConnection, path: &UHierPath, create: bool, ) -> Result> { @@ -243,8 +238,8 @@ pub fn resolve_path>( pub type ResolveCache = LruCache<(Option
, UNode), Address>; -pub fn resolve_path_cached>( - connection: &C, +pub fn resolve_path_cached( + connection: &UpEndConnection, path: &UHierPath, create: bool, cache: &Arc>, @@ -272,10 +267,10 @@ pub fn resolve_path_cached>( Ok(result) } -pub fn initialize_hier(pool: &DbPool) -> Result<()> { - insert_entry(&pool.get()?, Entry::try_from(&*HIER_INVARIANT)?)?; - upend_insert_addr!(&pool.get()?, HIER_ADDR, IS_OF_TYPE_ATTR, TYPE_ADDR); - upend_insert_val!(&pool.get()?, HIER_ADDR, TYPE_HAS_ATTR, HIER_HAS_ATTR); +pub fn initialize_hier(connection: &UpEndConnection) -> Result<()> { + connection.insert_entry(Entry::try_from(&*HIER_INVARIANT)?)?; + upend_insert_addr!(&connection, HIER_ADDR, IS_OF_TYPE_ATTR, TYPE_ADDR); + upend_insert_val!(&connection, HIER_ADDR, TYPE_HAS_ATTR, HIER_HAS_ATTR); Ok(()) } @@ -283,7 +278,7 @@ pub fn initialize_hier(pool: &DbPool) -> Result<()> { mod tests { use anyhow::Result; - use crate::database::open_upend; + use crate::database::UpEndDatabase; use tempdir::TempDir; use super::*; @@ -331,10 +326,11 @@ mod tests { fn test_path_manipulation() { // Initialize database let temp_dir = TempDir::new("upend-test").unwrap(); - let open_result = open_upend(&temp_dir, None, true).unwrap(); + let open_result = UpEndDatabase::open(&temp_dir, None, true).unwrap(); + let connection = open_result.db.connection().unwrap(); let foo_result = fetch_or_create_dir( - &open_result.pool.get().unwrap(), + &connection, None, UNode("foo".to_string()), true, @@ -343,7 +339,7 @@ mod tests { let foo_result = foo_result.unwrap(); let bar_result = fetch_or_create_dir( - &open_result.pool.get().unwrap(), + &connection, None, UNode("bar".to_string()), true, @@ -352,7 +348,7 @@ mod tests { let bar_result = bar_result.unwrap(); let baz_result = fetch_or_create_dir( - &open_result.pool.get().unwrap(), + &connection, Some(bar_result.clone()), UNode("baz".to_string()), true, @@ -360,11 +356,11 @@ mod tests { assert!(baz_result.is_ok()); let baz_result = baz_result.unwrap(); - let roots = list_roots(&open_result.pool.get().unwrap()); + let roots = list_roots(&connection); assert_eq!(roots.unwrap(), [foo_result, bar_result.clone()]); let resolve_result = resolve_path( - &open_result.pool.get().unwrap(), + &connection, &"bar/baz".parse().unwrap(), false, ); @@ -376,21 +372,21 @@ mod tests { ); let resolve_result = resolve_path( - &open_result.pool.get().unwrap(), + &connection, &"bar/baz/bax".parse().unwrap(), false, ); assert!(resolve_result.is_err()); let resolve_result = resolve_path( - &open_result.pool.get().unwrap(), + &connection, &"bar/baz/bax".parse().unwrap(), true, ); assert!(resolve_result.is_ok()); let bax_result = fetch_or_create_dir( - &open_result.pool.get().unwrap(), + &connection, Some(baz_result.clone()), UNode("bax".to_string()), false, diff --git a/src/database/macros.rs b/src/database/macros.rs index 103dffa..e49cc3f 100644 --- a/src/database/macros.rs +++ b/src/database/macros.rs @@ -1,7 +1,6 @@ macro_rules! upend_insert_val { ($db_connection:expr, $entity:expr, $attribute:expr, $value:expr) => {{ - insert_entry( - $db_connection, + $db_connection.insert_entry( Entry { entity: $entity.clone(), attribute: String::from($attribute), @@ -13,8 +12,7 @@ macro_rules! upend_insert_val { macro_rules! upend_insert_addr { ($db_connection:expr, $entity:expr, $attribute:expr, $addr:expr) => {{ - insert_entry( - $db_connection, + $db_connection.insert_entry( Entry { entity: $entity.clone(), attribute: String::from($attribute), diff --git a/src/database/mod.rs b/src/database/mod.rs index a49c1c5..e3b0206 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -21,9 +21,9 @@ use anyhow::{anyhow, Result}; use chrono::NaiveDateTime; use diesel::debug_query; use diesel::prelude::*; -use diesel::r2d2::{self, ConnectionManager}; +use diesel::r2d2::{self, ConnectionManager, PooledConnection}; use diesel::result::{DatabaseErrorKind, Error}; -use diesel::sqlite::{Sqlite, SqliteConnection}; +use diesel::sqlite::SqliteConnection; use hierarchies::initialize_hier; use log::{debug, trace}; use std::convert::TryFrom; @@ -31,184 +31,6 @@ use std::fs; use std::path::{Path, PathBuf}; use std::time::Duration; -pub fn insert_file>( - connection: &C, - file: models::NewFile, -) -> Result { - use crate::database::inner::schema::files; - - debug!( - "Inserting {} ({})...", - &file.path, - Address::Hash(Hash((&file.hash).clone())) - ); - - Ok(diesel::insert_into(files::table) - .values(file) - .execute(connection)?) -} - -pub fn retrieve_file>( - connection: &C, - obj_hash: Hash, -) -> Result> { - use crate::database::inner::schema::files::dsl::*; - - let matches = files - .filter(valid.eq(true)) - .filter(hash.eq(obj_hash.0)) - .load::(connection)?; - - Ok(matches) -} - -pub fn retrieve_all_files>( - connection: &C, -) -> Result> { - use crate::database::inner::schema::files::dsl::*; - let matches = files.load::(connection)?; - Ok(matches) -} - -pub fn get_latest_files>( - connection: &C, - count: i64, -) -> Result> { - use crate::database::inner::schema::files::dsl::*; - - let matches = files - .order_by(added.desc()) - .limit(count) - .load::(connection)?; - - Ok(matches) -} - -pub fn file_update_mtime>( - connection: &C, - file_id: i32, - m_time: Option, -) -> Result { - use crate::database::inner::schema::files::dsl::*; - - debug!("Setting file ID {}'s mtime = {:?}", file_id, m_time); - - Ok(diesel::update(files.filter(id.eq(file_id))) - .set(mtime.eq(m_time)) - .execute(connection)?) -} - -pub fn file_set_valid>( - connection: &C, - file_id: i32, - is_valid: bool, -) -> Result { - use crate::database::inner::schema::files::dsl::*; - - debug!("Setting file ID {} to valid = {}", file_id, is_valid); - - Ok(diesel::update(files.filter(id.eq(file_id))) - .set(valid.eq(is_valid)) - .execute(connection)?) -} - -pub fn retrieve_object>( - connection: &C, - object_address: Address, -) -> Result> { - use crate::database::inner::schema::data::dsl::*; - - let primary = data - .filter(entity.eq(object_address.encode()?)) - .or_filter(value.eq(EntryValue::Address(object_address).to_string()?)) - .load::(connection)?; - - let entries = primary - .iter() - .map(Entry::try_from) - .collect::>>()?; - - let secondary = data - .filter( - entity.eq_any( - entries - .iter() - .map(|e| e.address()) - .filter_map(Result::ok) - .map(|addr| addr.encode()) - .collect::>>>()?, - ), - ) - .load::(connection)?; - - let secondary_entries = secondary - .iter() - .map(Entry::try_from) - .collect::>>()?; - - Ok([entries, secondary_entries].concat()) -} - -pub fn remove_object>( - connection: &C, - object_address: Address, -) -> Result { - use crate::database::inner::schema::data::dsl::*; - - debug!("Deleting {}!", object_address); - - let matches = data - .filter(identity.eq(object_address.encode()?)) - .or_filter(entity.eq(object_address.encode()?)) - .or_filter(value.eq(EntryValue::Address(object_address).to_string()?)); - - Ok(diesel::delete(matches).execute(connection)?) -} - -pub fn query>(connection: &C, query: Query) -> Result> { - use crate::database::inner::schema::data::dsl::*; - - trace!("Querying: {:?}", query); - - let db_query = data.filter(query.to_sqlite_predicates()?); - - trace!("DB query: {}", debug_query(&db_query)); - - let matches = db_query.load::(connection)?; - - let entries = matches - .iter() - .map(Entry::try_from) - .filter_map(Result::ok) - .collect(); - - Ok(entries) -} - -pub fn insert_entry>( - connection: &C, - entry: Entry, -) -> Result
{ - debug!("Inserting: {}", entry); - - let insert_entry = models::Entry::try_from(&entry)?; - - let entry = Entry::try_from(&insert_entry)?; - - let result = diesel::insert_into(data::table) - .values(insert_entry) - .execute(connection); - - if let Some(error) = result.err() { - match error { - Error::DatabaseError(DatabaseErrorKind::UniqueViolation, _) => {} - _ => return Err(anyhow!(error)), - } - } - - Ok(Address::Hash(entry.hash()?)) -} - #[derive(Debug)] pub struct ConnectionOptions { pub enable_foreign_keys: bool, @@ -242,78 +64,260 @@ impl diesel::r2d2::CustomizeConnection } } -pub type DbPool = r2d2::Pool>; +type DbPool = r2d2::Pool>; pub struct OpenResult { - pub pool: DbPool, + pub db: UpEndDatabase, pub new: bool, } +pub struct UpEndDatabase { + pool: DbPool, + pub vault_path: PathBuf, +} + pub const UPEND_SUBDIR: &str = ".upend"; pub const DATABASE_FILENAME: &str = "upend.sqlite3"; -pub fn open_upend>( - dirpath: P, - db_path: Option, - reinitialize: bool, -) -> Result { - embed_migrations!("./migrations/upend/"); +impl UpEndDatabase { + pub fn open>( + dirpath: P, + db_path: Option, + reinitialize: bool, + ) -> Result { + embed_migrations!("./migrations/upend/"); - let upend_path = db_path.unwrap_or_else(|| dirpath.as_ref().join(UPEND_SUBDIR)); + let upend_path = db_path.unwrap_or_else(|| dirpath.as_ref().join(UPEND_SUBDIR)); - if reinitialize { - trace!("Reinitializing - removing previous database..."); - let _ = fs::remove_dir_all(&upend_path); - } - let new = !upend_path.exists(); + if reinitialize { + trace!("Reinitializing - removing previous database..."); + let _ = fs::remove_dir_all(&upend_path); + } + let new = !upend_path.exists(); - if new { - trace!("Creating UpEnd subdirectory..."); - fs::create_dir(&upend_path)?; + if new { + trace!("Creating UpEnd subdirectory..."); + fs::create_dir(&upend_path)?; + } + + trace!("Creating pool."); + let manager = ConnectionManager::::new( + upend_path.join(DATABASE_FILENAME).to_str().unwrap(), + ); + let pool = r2d2::Pool::builder() + .connection_customizer(Box::new(ConnectionOptions { + enable_foreign_keys: true, + busy_timeout: Some(Duration::from_secs(30)), + })) + .build(manager)?; + + let db = UpEndDatabase { + pool, + vault_path: PathBuf::from(dirpath.as_ref()), + }; + let connection = db.connection().unwrap(); + + let enable_wal_mode = true; + connection.execute(if enable_wal_mode { + trace!("Enabling WAL journal mode & truncating WAL log..."); + "PRAGMA journal_mode = WAL;PRAGMA wal_checkpoint(TRUNCATE);" + } else { + trace!("Enabling TRUNCATE journal mode"); + "PRAGMA journal_mode = TRUNCATE;" + })?; + + trace!("Pool created, running migrations..."); + + embedded_migrations::run_with_output( + &db.pool.get()?, + &mut LoggerSink { + ..Default::default() + }, + )?; + + trace!("Initializing types..."); + connection.insert_entry(Entry::try_from(&*TYPE_INVARIANT)?)?; + upend_insert_addr!(&connection, TYPE_ADDR, IS_OF_TYPE_ATTR, TYPE_ADDR); + upend_insert_val!(&connection, TYPE_ADDR, TYPE_HAS_ATTR, TYPE_HAS_ATTR); + + initialize_hier(&connection)?; + + Ok(OpenResult { db, new }) } - trace!("Creating pool."); - let manager = ConnectionManager::::new( - upend_path.join(DATABASE_FILENAME).to_str().unwrap(), - ); - let pool = r2d2::Pool::builder() - .connection_customizer(Box::new(ConnectionOptions { - enable_foreign_keys: true, - busy_timeout: Some(Duration::from_secs(30)), - })) - .build(manager)?; - - let enable_wal_mode = true; - pool.get().unwrap().execute(if enable_wal_mode { - trace!("Enabling WAL journal mode & truncating WAL log..."); - "PRAGMA journal_mode = WAL;PRAGMA wal_checkpoint(TRUNCATE);" - } else { - trace!("Enabling TRUNCATE journal mode"); - "PRAGMA journal_mode = TRUNCATE;" - })?; - - trace!("Pool created, running migrations..."); - - embedded_migrations::run_with_output( - &pool.get()?, - &mut LoggerSink { - ..Default::default() - }, - )?; - - trace!("Initializing types..."); - - initialize_types(&pool)?; - initialize_hier(&pool)?; - - Ok(OpenResult { pool, new }) + pub fn connection(self: &Self) -> Result { + Ok(UpEndConnection(self.pool.get()?)) + } } -fn initialize_types(pool: &DbPool) -> Result<()> { - insert_entry(&pool.get()?, Entry::try_from(&*TYPE_INVARIANT)?)?; - upend_insert_addr!(&pool.get()?, TYPE_ADDR, IS_OF_TYPE_ATTR, TYPE_ADDR); - upend_insert_val!(&pool.get()?, TYPE_ADDR, TYPE_HAS_ATTR, TYPE_HAS_ATTR); - Ok(()) +pub struct UpEndConnection(PooledConnection>); + +impl UpEndConnection { + pub fn execute>(self: &Self, query: S) -> Result { + self.0.execute(query.as_ref()) + } + + pub fn transaction(&self, f: F) -> Result + where + F: FnOnce() -> Result, + E: From, + { + self.0.transaction(f) + } + + pub fn insert_file(self: &Self, file: models::NewFile) -> Result { + use crate::database::inner::schema::files; + + debug!( + "Inserting {} ({})...", + &file.path, + Address::Hash(Hash((&file.hash).clone())) + ); + + Ok(diesel::insert_into(files::table) + .values(file) + .execute(&self.0)?) + } + + pub fn retrieve_file(self: &Self, obj_hash: Hash) -> Result> { + use crate::database::inner::schema::files::dsl::*; + + let matches = files + .filter(valid.eq(true)) + .filter(hash.eq(obj_hash.0)) + .load::(&self.0)?; + + Ok(matches) + } + + pub fn retrieve_all_files(self: &Self) -> Result> { + use crate::database::inner::schema::files::dsl::*; + let matches = files.load::(&self.0)?; + Ok(matches) + } + + pub fn get_latest_files(self: &Self, count: i64) -> Result> { + use crate::database::inner::schema::files::dsl::*; + + let matches = files + .order_by(added.desc()) + .limit(count) + .load::(&self.0)?; + + Ok(matches) + } + + pub fn file_update_mtime( + self: &Self, + file_id: i32, + m_time: Option, + ) -> Result { + use crate::database::inner::schema::files::dsl::*; + + debug!("Setting file ID {}'s mtime = {:?}", file_id, m_time); + + Ok(diesel::update(files.filter(id.eq(file_id))) + .set(mtime.eq(m_time)) + .execute(&self.0)?) + } + + pub fn file_set_valid(self: &Self, file_id: i32, is_valid: bool) -> Result { + use crate::database::inner::schema::files::dsl::*; + + debug!("Setting file ID {} to valid = {}", file_id, is_valid); + + Ok(diesel::update(files.filter(id.eq(file_id))) + .set(valid.eq(is_valid)) + .execute(&self.0)?) + } + + pub fn retrieve_object(self: &Self, object_address: Address) -> Result> { + use crate::database::inner::schema::data::dsl::*; + + let primary = data + .filter(entity.eq(object_address.encode()?)) + .or_filter(value.eq(EntryValue::Address(object_address).to_string()?)) + .load::(&self.0)?; + + let entries = primary + .iter() + .map(Entry::try_from) + .collect::>>()?; + + let secondary = data + .filter( + entity.eq_any( + entries + .iter() + .map(|e| e.address()) + .filter_map(Result::ok) + .map(|addr| addr.encode()) + .collect::>>>()?, + ), + ) + .load::(&self.0)?; + + let secondary_entries = secondary + .iter() + .map(Entry::try_from) + .collect::>>()?; + + Ok([entries, secondary_entries].concat()) + } + + pub fn remove_object(self: &Self, object_address: Address) -> Result { + use crate::database::inner::schema::data::dsl::*; + + debug!("Deleting {}!", object_address); + + let matches = data + .filter(identity.eq(object_address.encode()?)) + .or_filter(entity.eq(object_address.encode()?)) + .or_filter(value.eq(EntryValue::Address(object_address).to_string()?)); + + Ok(diesel::delete(matches).execute(&self.0)?) + } + + pub fn query(self: &Self, query: Query) -> Result> { + use crate::database::inner::schema::data::dsl::*; + + trace!("Querying: {:?}", query); + + let db_query = data.filter(query.to_sqlite_predicates()?); + + trace!("DB query: {}", debug_query(&db_query)); + + let matches = db_query.load::(&self.0)?; + + let entries = matches + .iter() + .map(Entry::try_from) + .filter_map(Result::ok) + .collect(); + + Ok(entries) + } + + pub fn insert_entry(self: &Self, entry: Entry) -> Result
{ + debug!("Inserting: {}", entry); + + let insert_entry = models::Entry::try_from(&entry)?; + + let entry = Entry::try_from(&insert_entry)?; + + let result = diesel::insert_into(data::table) + .values(insert_entry) + .execute(&self.0); + + if let Some(error) = result.err() { + match error { + Error::DatabaseError(DatabaseErrorKind::UniqueViolation, _) => {} + _ => return Err(anyhow!(error)), + } + } + + Ok(Address::Hash(entry.hash()?)) + } } #[cfg(test)] @@ -322,7 +326,8 @@ mod test { use tempdir::TempDir; #[test] - fn test_open() -> Result<(), anyhow::Error> { - open_upend(TempDir::new("upend-test").unwrap(), None, true).map(|_| ()) + fn test_open() { + let result = UpEndDatabase::open(TempDir::new("upend-test").unwrap(), None, false); + assert!(result.is_ok()); } } diff --git a/src/filesystem.rs b/src/filesystem.rs index ab9a614..a0bbc6e 100644 --- a/src/filesystem.rs +++ b/src/filesystem.rs @@ -1,3 +1,4 @@ +use std::borrow::Borrow; use std::convert::TryFrom; use std::path::{Component, Path, PathBuf}; use std::sync::{Arc, Mutex, RwLock}; @@ -9,17 +10,13 @@ use crate::database::constants::{ HIER_HAS_ATTR, IS_OF_TYPE_ATTR, TYPE_ADDR, TYPE_BASE_ATTR, TYPE_HAS_ATTR, }; use crate::database::entry::{Entry, EntryValue, InvariantEntry}; -use crate::database::hierarchies::{resolve_path_cached, ResolveCache, UNode, UHierPath}; +use crate::database::hierarchies::{resolve_path_cached, ResolveCache, UHierPath, UNode}; use crate::database::inner::models; -use crate::database::{ - file_set_valid, file_update_mtime, insert_entry, insert_file, retrieve_all_files, DbPool, - UPEND_SUBDIR, -}; +use crate::database::{UpEndConnection, UpEndDatabase, UPEND_SUBDIR}; use crate::util::hash::{Hash, Hashable}; use crate::util::jobs::{Job, JobContainer, JobId, State}; use anyhow::{Error, Result}; use chrono::prelude::*; -use diesel::Connection; use log::{debug, error, info, warn}; use lru::LruCache; use rayon::prelude::*; @@ -40,22 +37,18 @@ lazy_static! { static ref BLOB_TYPE_ADDR: Address = BLOB_TYPE_INVARIANT.entity().unwrap(); } -fn initialize_types(pool: &DbPool) -> Result<()> { +fn initialize_types(connection: &UpEndConnection) -> Result<()> { // BLOB_TYPE - insert_entry(&pool.get()?, Entry::try_from(&*BLOB_TYPE_INVARIANT)?)?; - upend_insert_addr!(&pool.get()?, BLOB_TYPE_ADDR, IS_OF_TYPE_ATTR, TYPE_ADDR); - upend_insert_val!(&pool.get()?, BLOB_TYPE_ADDR, TYPE_HAS_ATTR, FILE_MTIME_KEY); - upend_insert_val!(&pool.get()?, BLOB_TYPE_ADDR, TYPE_HAS_ATTR, FILE_SIZE_KEY); - upend_insert_val!(&pool.get()?, BLOB_TYPE_ADDR, TYPE_HAS_ATTR, FILE_MIME_KEY); + connection.insert_entry(Entry::try_from(&*BLOB_TYPE_INVARIANT)?)?; + upend_insert_addr!(&connection, BLOB_TYPE_ADDR, IS_OF_TYPE_ATTR, TYPE_ADDR); + upend_insert_val!(&connection, BLOB_TYPE_ADDR, TYPE_HAS_ATTR, FILE_MTIME_KEY); + upend_insert_val!(&connection, BLOB_TYPE_ADDR, TYPE_HAS_ATTR, FILE_SIZE_KEY); + upend_insert_val!(&connection, BLOB_TYPE_ADDR, TYPE_HAS_ATTR, FILE_MIME_KEY); Ok(()) } -pub async fn rescan_vault( - pool: DbPool, - directory: PathBuf, - job_container: Arc>, -) { +pub async fn rescan_vault(db: Arc, job_container: Arc>) { let job_id = job_container .write() .unwrap() @@ -64,8 +57,7 @@ pub async fn rescan_vault( let job_container_rescan = job_container.clone(); let result = - actix_web::web::block(move || _rescan_vault(pool, directory, job_container_rescan, job_id)) - .await; + actix_web::web::block(move || _rescan_vault(db, job_container_rescan, job_id)).await; if result.is_err() { let err = result.err().unwrap(); @@ -78,18 +70,12 @@ pub async fn rescan_vault( .unwrap(); } } -struct PragmaSynchronousGuard<'a>(&'a DbPool); +struct PragmaSynchronousGuard<'a>(&'a UpEndConnection); impl Drop for PragmaSynchronousGuard<'_> { fn drop(&mut self) { debug!("Re-enabling synchronous mode."); - let connection = self.0.get(); - let res: Result<_, String> = match connection { - Ok(connection) => connection - .execute("PRAGMA synchronous = NORMAL;") - .map_err(|err| format!("{}", err)), - Err(err) => Err(format!("{}", err)), - }; + let res = self.0.execute("PRAGMA synchronous = NORMAL;"); if let Err(err) = res { error!( "Error setting synchronous mode back to NORMAL! Data loss possible! {}", @@ -109,28 +95,30 @@ enum UpdatePathOutcome { Failed(PathBuf, Error), } -fn _rescan_vault>( - pool: DbPool, - directory: T, +fn _rescan_vault>( + db: D, job_container: Arc>, job_id: JobId, ) -> Result> { let start = Instant::now(); info!("Vault rescan started."); + let db = db.borrow(); + let connection = db.connection()?; + // Initialize types, etc... debug!("Initializing DB types."); - initialize_types(&pool)?; + initialize_types(&connection)?; // Disable syncing in SQLite for the duration of the import debug!("Disabling SQLite synchronous mode"); - pool.get()?.execute("PRAGMA synchronous = OFF;")?; - let _guard = PragmaSynchronousGuard(&pool); + connection.execute("PRAGMA synchronous = OFF;")?; + let _guard = PragmaSynchronousGuard(&connection); // Walk through the vault, find all paths debug!("Traversing vault directory"); - let absolute_dir_path = fs::canonicalize(&directory)?; - let path_entries: Vec = WalkDir::new(&directory) + let absolute_dir_path = fs::canonicalize(&db.vault_path)?; + let path_entries: Vec = WalkDir::new(&db.vault_path) .follow_links(true) .into_iter() .filter_map(|e| e.ok()) @@ -140,8 +128,7 @@ fn _rescan_vault>( .collect(); // Prepare for processing - let rw_pool = Arc::new(RwLock::new(pool.clone())); - let existing_files = Arc::new(RwLock::new(retrieve_all_files(&pool.get()?)?)); + let existing_files = Arc::new(RwLock::new(connection.retrieve_all_files()?)); // Actual processing let count = RwLock::new(0_usize); @@ -151,7 +138,7 @@ fn _rescan_vault>( .into_par_iter() .map(|path| { let result = _process_directory_entry( - &rw_pool, + db.connection().unwrap(), &resolve_cache, path.clone(), &absolute_dir_path, @@ -178,10 +165,9 @@ fn _rescan_vault>( let existing_files = existing_files.read().unwrap(); - let connection = pool.get()?; let cleanup_results = existing_files.iter().filter(|f| f.valid).map(|file| { let trans_result = connection.transaction::<_, Error, _>(|| { - file_set_valid(&connection, file.id, false)?; + connection.file_set_valid(file.id, false)?; // remove_object(&connection, )? Ok(()) }); @@ -225,8 +211,8 @@ fn _rescan_vault>( drop(_guard); info!( - "Finished updating {} ({} created, {} deleted, {} left unchanged). Took {}s.", - directory.as_ref().display(), + "Finished updating {:?} ({} created, {} deleted, {} left unchanged). Took {}s.", + db.vault_path, created, deleted, unchanged, @@ -237,7 +223,7 @@ fn _rescan_vault>( } fn _process_directory_entry>( - db_pool: &Arc>, + connection: UpEndConnection, resolve_cache: &Arc>, path: PathBuf, directory_path: &P, @@ -246,7 +232,6 @@ fn _process_directory_entry>( debug!("Processing: {:?}", path); // Prepare the data - let connection = &db_pool.write().unwrap().get()?; let existing_files = Arc::clone(existing_files); let normalized_path = path.strip_prefix(&directory_path)?; @@ -288,11 +273,11 @@ fn _process_directory_entry>( if same_mtime || same_hash { if mtime != existing_file.mtime { - file_update_mtime(connection, existing_file.id, mtime)?; + connection.file_update_mtime(existing_file.id, mtime)?; } if !existing_file.valid { - file_set_valid(connection, existing_file.id, true)?; + connection.file_set_valid(existing_file.id, true)?; } let mut existing_files_write = existing_files.write().unwrap(); @@ -357,22 +342,22 @@ fn _process_directory_entry>( })) .collect(), ); - let resolved_path = resolve_path_cached(connection, &upath, true, resolve_cache)?; + let resolved_path = resolve_path_cached(&connection, &upath, true, resolve_cache)?; let parent_dir = resolved_path.last().unwrap(); connection.transaction::<_, Error, _>(|| { - insert_file(connection, new_file)?; + connection.insert_file(new_file)?; - insert_entry(connection, type_entry)?; - insert_entry(connection, size_entry)?; - insert_entry(connection, mime_entry)?; + connection.insert_entry(type_entry)?; + connection.insert_entry(size_entry)?; + connection.insert_entry(mime_entry)?; let dir_has_entry = Entry { entity: parent_dir.clone(), attribute: HIER_HAS_ATTR.to_string(), value: EntryValue::Address(Address::Hash(file_hash.clone())), }; - let dir_has_entry_addr = insert_entry(connection, dir_has_entry)?; + let dir_has_entry_addr = connection.insert_entry(dir_has_entry)?; let name_entry = Entry { entity: dir_has_entry_addr, @@ -381,7 +366,7 @@ fn _process_directory_entry>( filename.as_os_str().to_string_lossy().to_string(), )), }; - insert_entry(connection, name_entry)?; + connection.insert_entry(name_entry)?; info!("Added: {:?}", path); Ok(UpdatePathOutcome::Added(path.clone())) @@ -390,7 +375,7 @@ fn _process_directory_entry>( #[cfg(test)] mod test { - use crate::database::open_upend; + use crate::database::UpEndDatabase; use crate::util; use super::*; @@ -416,7 +401,7 @@ mod test { // Initialize database - let open_result = open_upend(&temp_dir, None, true).unwrap(); + let open_result = UpEndDatabase::open(&temp_dir, None, true).unwrap(); let job_container = Arc::new(RwLock::new(util::jobs::JobContainer::default())); let job_id = job_container .write() @@ -425,12 +410,7 @@ mod test { .unwrap(); // Initial scan - let rescan_result = _rescan_vault( - open_result.pool.clone(), - temp_dir.as_ref().to_path_buf(), - job_container.clone(), - job_id, - ); + let rescan_result = _rescan_vault(&open_result.db, job_container.clone(), job_id); assert!(rescan_result.is_ok()); let rescan_result = rescan_result.unwrap(); @@ -441,12 +421,7 @@ mod test { // Modification-less rescan - let rescan_result = _rescan_vault( - open_result.pool.clone(), - temp_dir.as_ref().to_path_buf(), - job_container.clone(), - job_id, - ); + let rescan_result = _rescan_vault(&open_result.db, job_container.clone(), job_id); assert!(rescan_result.is_ok()); let rescan_result = rescan_result.unwrap(); @@ -459,12 +434,7 @@ mod test { std::fs::remove_file(temp_dir.path().join("hello-world.txt")).unwrap(); - let rescan_result = _rescan_vault( - open_result.pool, - temp_dir.as_ref().to_path_buf(), - job_container, - job_id, - ); + let rescan_result = _rescan_vault(&open_result.db, job_container, job_id); assert!(rescan_result.is_ok()); let rescan_result = rescan_result.unwrap(); diff --git a/src/main.rs b/src/main.rs index c3d5837..de1dedb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,8 @@ use clap::{App as ClapApp, Arg}; use log::{info, warn}; use std::sync::{Arc, RwLock}; +use crate::database::UpEndDatabase; + mod addressing; mod database; mod filesystem; @@ -85,14 +87,14 @@ fn main() -> Result<()> { let vault_path = PathBuf::from(matches.value_of("DIRECTORY").unwrap()); - let open_result = database::open_upend( + let open_result = UpEndDatabase::open( &vault_path, matches.value_of("DB_PATH").map(PathBuf::from), matches.is_present("REINITIALIZE"), ) .expect("failed to open database!"); - let db_pool = open_result.pool; + let upend = Arc::new(open_result.db); let bind: SocketAddr = matches .value_of("BIND") @@ -102,6 +104,7 @@ fn main() -> Result<()> { info!("Starting server at: {}", &bind); let state = routes::State { + upend: upend.clone(), vault_name: Some( matches .value_of("VAULT_NAME") @@ -115,8 +118,6 @@ fn main() -> Result<()> { .into_owned() }), ), - directory: vault_path.clone(), - db_pool: db_pool.clone(), job_container: job_container.clone(), }; @@ -154,7 +155,7 @@ fn main() -> Result<()> { if !matches.is_present("NO_INITIAL_UPDATE") { info!("Running initial update..."); - actix::spawn(filesystem::rescan_vault(db_pool, vault_path, job_container)); + actix::spawn(filesystem::rescan_vault(upend, job_container)); } #[cfg(feature = "desktop")] diff --git a/src/routes.rs b/src/routes.rs index df13a6c..0379b21 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -2,9 +2,7 @@ use crate::addressing::{Address, Addressable}; use crate::database::entry::{Entry, InEntry}; use crate::database::hierarchies::{list_roots, resolve_path, UHierPath}; use crate::database::lang::Query; -use crate::database::{ - get_latest_files, insert_entry, query, remove_object, retrieve_file, retrieve_object, DbPool, -}; +use crate::database::UpEndDatabase; use crate::util::hash::{decode, encode}; use crate::util::jobs::JobContainer; use actix_files::NamedFile; @@ -18,7 +16,6 @@ use serde::Deserialize; use serde_json::json; use std::collections::HashMap; use std::convert::TryFrom; -use std::path::PathBuf; use std::sync::{Arc, RwLock}; #[cfg(feature = "desktop")] @@ -28,9 +25,8 @@ const VERSION: &str = env!("CARGO_PKG_VERSION"); #[derive(Clone)] pub struct State { + pub upend: Arc, pub vault_name: Option, - pub directory: PathBuf, - pub db_pool: DbPool, pub job_container: Arc>, } @@ -48,10 +44,12 @@ pub async fn get_raw( let address = Address::decode(&decode(hash.into_inner()).map_err(ErrorInternalServerError)?) .map_err(ErrorInternalServerError)?; if let Address::Hash(hash) = address { - let connection = state.db_pool.get().map_err(ErrorInternalServerError)?; - let files = retrieve_file(&connection, hash).map_err(ErrorInternalServerError)?; + let connection = state.upend.connection().map_err(ErrorInternalServerError)?; + let files = connection + .retrieve_file(hash) + .map_err(ErrorInternalServerError)?; if let Some(file) = files.get(0) { - let file_path = state.directory.join(&file.path); + let file_path = state.upend.vault_path.join(&file.path); if !query.native.is_some() { Ok(Either::A(NamedFile::open(file_path)?)) @@ -73,7 +71,7 @@ pub async fn get_raw( http::header::ACCESS_CONTROL_EXPOSE_HEADERS, http::header::WARNING.to_string(), ); - + file_path .parent() .ok_or_else(|| { @@ -111,10 +109,12 @@ pub async fn get_query( state: web::Data, web::Query(info): web::Query, ) -> Result { - let connection = state.db_pool.get().map_err(ErrorInternalServerError)?; + let connection = state.upend.connection().map_err(ErrorInternalServerError)?; let in_query: Query = info.query.as_str().parse().map_err(ErrorBadRequest)?; - let entries = query(&connection, in_query).map_err(ErrorInternalServerError)?; + let entries = connection + .query(in_query) + .map_err(ErrorInternalServerError)?; let mut result: HashMap = HashMap::new(); for entry in entries { result.insert( @@ -152,13 +152,13 @@ pub async fn get_object( state: web::Data, address_str: web::Path, ) -> Result { - let connection = state.db_pool.get().map_err(ErrorInternalServerError)?; - let result: Vec = retrieve_object( - &connection, - Address::decode(&decode(address_str.into_inner()).map_err(ErrorBadRequest)?) - .map_err(ErrorBadRequest)?, - ) - .map_err(ErrorInternalServerError)?; + let connection = state.upend.connection().map_err(ErrorInternalServerError)?; + let result: Vec = connection + .retrieve_object( + Address::decode(&decode(address_str.into_inner()).map_err(ErrorBadRequest)?) + .map_err(ErrorBadRequest)?, + ) + .map_err(ErrorInternalServerError)?; debug!("{:?}", result); Ok(HttpResponse::Ok().json(result.as_hash().map_err(ErrorInternalServerError)?)) @@ -171,7 +171,7 @@ pub async fn put_object( state: web::Data, mut payload: web::Payload, ) -> Result { - let connection = state.db_pool.get().map_err(ErrorInternalServerError)?; + let connection = state.upend.connection().map_err(ErrorInternalServerError)?; let mut body = web::BytesMut::new(); while let Some(chunk) = payload.next().await { @@ -186,8 +186,9 @@ pub async fn put_object( let in_entry = serde_json::from_slice::(&body).map_err(ErrorBadRequest)?; let entry = Entry::try_from(in_entry).map_err(ErrorInternalServerError)?; - let result_address = - insert_entry(&connection, entry.clone()).map_err(ErrorInternalServerError)?; + let result_address = connection + .insert_entry(entry.clone()) + .map_err(ErrorInternalServerError)?; Ok(HttpResponse::Ok().json( [( @@ -205,13 +206,13 @@ pub async fn delete_object( state: web::Data, address_str: web::Path, ) -> Result { - let connection = state.db_pool.get().map_err(ErrorInternalServerError)?; - let _ = remove_object( - &connection, - Address::decode(&decode(address_str.into_inner()).map_err(ErrorBadRequest)?) - .map_err(ErrorInternalServerError)?, - ) - .map_err(ErrorInternalServerError)?; + let connection = state.upend.connection().map_err(ErrorInternalServerError)?; + let _ = connection + .remove_object( + Address::decode(&decode(address_str.into_inner()).map_err(ErrorBadRequest)?) + .map_err(ErrorInternalServerError)?, + ) + .map_err(ErrorInternalServerError)?; Ok(HttpResponse::Ok().finish()) } @@ -221,7 +222,7 @@ pub async fn list_hier( state: web::Data, path: web::Path, ) -> Result { - let connection = state.db_pool.get().map_err(ErrorInternalServerError)?; + let connection = state.upend.connection().map_err(ErrorInternalServerError)?; if path.is_empty() { Ok(HttpResponse::MovedPermanently() .header("Location", "/api/hier_roots") @@ -243,12 +244,12 @@ pub async fn list_hier( #[get("/api/hier_roots")] pub async fn list_hier_roots(state: web::Data) -> Result { - let connection = state.db_pool.get().map_err(ErrorInternalServerError)?; + let connection = state.upend.connection().map_err(ErrorInternalServerError)?; let result = list_roots(&connection) .map_err(ErrorInternalServerError)? .into_iter() - .map(|root| retrieve_object(&connection, root)) + .map(|root| connection.retrieve_object(root)) .collect::>>>() .map_err(ErrorInternalServerError)? .concat(); @@ -258,11 +259,8 @@ pub async fn list_hier_roots(state: web::Data) -> Result) -> Result { - let _pool = state.db_pool.clone(); - let _directory = state.directory.clone(); actix::spawn(crate::filesystem::rescan_vault( - _pool, - _directory, + state.upend.clone(), state.job_container.clone(), )); Ok(HttpResponse::Ok().finish()) @@ -277,8 +275,10 @@ pub async fn get_file( .map_err(ErrorInternalServerError)?; if let Address::Hash(hash) = address { - let connection = state.db_pool.get().map_err(ErrorInternalServerError)?; - let response = retrieve_file(&connection, hash).map_err(ErrorInternalServerError)?; + let connection = state.upend.connection().map_err(ErrorInternalServerError)?; + let response = connection + .retrieve_file(hash) + .map_err(ErrorInternalServerError)?; Ok(HttpResponse::Ok().json(response)) } else { @@ -288,8 +288,10 @@ pub async fn get_file( #[get("/api/files/latest")] pub async fn latest_files(state: web::Data) -> Result { - let connection = state.db_pool.get().map_err(ErrorInternalServerError)?; - let files = get_latest_files(&connection, 100).map_err(ErrorInternalServerError)?; + let connection = state.upend.connection().map_err(ErrorInternalServerError)?; + let files = connection + .get_latest_files(100) + .map_err(ErrorInternalServerError)?; Ok(HttpResponse::Ok().json(&files)) } @@ -303,7 +305,7 @@ pub async fn get_jobs(state: web::Data) -> Result { pub async fn get_info(state: web::Data) -> Result { Ok(HttpResponse::Ok().json(json!({ "name": state.vault_name, - "location": state.directory, + "location": state.upend.vault_path, "version": VERSION }))) }