use std::convert::TryFrom; use std::sync::{Arc, Mutex}; use anyhow::{anyhow, Result}; use lru::LruCache; use tracing::trace; use uuid::Uuid; use crate::OperationContext; use upend_base::addressing::Address; use upend_base::constants::ATTR_LABEL; use upend_base::constants::{ATTR_IN, HIER_ROOT_ADDR, HIER_ROOT_INVARIANT}; use upend_base::entry::Entry; use upend_base::lang::{PatternQuery, Query, QueryComponent, QueryPart}; use super::UpEndConnection; #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct UNode(String); impl std::str::FromStr for UNode { type Err = anyhow::Error; fn from_str(string: &str) -> Result { if string.is_empty() { Err(anyhow!("UNode can not be empty.")) } else { Ok(Self(string.to_string())) } } } impl std::fmt::Display for UNode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } #[derive(Debug, Clone, PartialEq)] pub struct UHierPath(pub Vec); impl std::str::FromStr for UHierPath { type Err = anyhow::Error; fn from_str(string: &str) -> Result { if string.is_empty() { Ok(UHierPath(vec![])) } else { let result: Result> = string .trim_end_matches('/') .split('/') .map(UNode::from_str) .collect(); Ok(UHierPath(result?)) } } } impl std::fmt::Display for UHierPath { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "{}", self.0 .iter() .map(|node| node.to_string()) .collect::>() .join("/") ) } } pub fn list_roots(connection: &UpEndConnection) -> Result> { Ok(connection .query(Query::SingleQuery(QueryPart::Matches(PatternQuery { entity: QueryComponent::Variable(None), attribute: QueryComponent::Exact(ATTR_IN.parse().unwrap()), value: QueryComponent::Exact((*HIER_ROOT_ADDR).clone().into()), })))? .into_iter() .map(|e| e.entity) .collect()) } lazy_static! { static ref FETCH_CREATE_LOCK: Mutex<()> = Mutex::new(()); } pub fn fetch_or_create_dir( connection: &UpEndConnection, parent: Option
, directory: UNode, create: bool, context: OperationContext, ) -> Result
{ match parent.clone() { Some(address) => trace!("FETCHING/CREATING {}/{:#}", address, directory), None => trace!("FETCHING/CREATING /{:#}", directory), } let _lock; if create { _lock = FETCH_CREATE_LOCK.lock().unwrap(); } let matching_directories = connection .query(Query::SingleQuery(QueryPart::Matches(PatternQuery { entity: QueryComponent::Variable(None), attribute: QueryComponent::Exact(ATTR_LABEL.parse().unwrap()), value: QueryComponent::Exact(directory.to_string().into()), })))? .into_iter() .map(|e: Entry| e.entity); let parent_has: Vec
= match parent.clone() { Some(parent) => connection .query(Query::SingleQuery(QueryPart::Matches(PatternQuery { entity: QueryComponent::Variable(None), attribute: QueryComponent::Exact(ATTR_IN.parse().unwrap()), value: QueryComponent::Exact(parent.into()), })))? .into_iter() .map(|e| e.entity) .collect(), None => list_roots(connection)?, }; let valid_directories: Vec
= matching_directories .filter(|a| parent_has.contains(a)) .collect(); match valid_directories.len() { 0 => { if create { let new_directory_address = Address::Uuid(Uuid::new_v4()); let directory_entry = Entry { entity: new_directory_address.clone(), attribute: ATTR_LABEL.parse().unwrap(), value: directory.to_string().into(), provenance: context.provenance.clone() + "HIER", user: context.user.clone(), timestamp: chrono::Utc::now().naive_utc(), }; connection.insert_entry(directory_entry)?; connection.insert_entry(if let Some(parent) = parent { Entry { entity: new_directory_address.clone(), attribute: ATTR_IN.parse().unwrap(), value: parent.into(), provenance: context.provenance.clone() + "HIER", user: context.user.clone(), timestamp: chrono::Utc::now().naive_utc(), } } else { Entry { entity: new_directory_address.clone(), attribute: ATTR_IN.parse().unwrap(), value: HIER_ROOT_ADDR.clone().into(), provenance: context.provenance.clone() + "HIER", user: context.user.clone(), timestamp: chrono::Utc::now().naive_utc(), } })?; Ok(new_directory_address) } else { Err(anyhow!("Node {:?} does not exist.", directory.0)) } } 1 => Ok(valid_directories[0].clone()), _ => Err(anyhow!(format!( "Invalid database state - more than one directory matches the query {:?}/{:#}!", parent, directory ))), } } pub fn resolve_path( connection: &UpEndConnection, path: &UHierPath, create: bool, context: OperationContext, ) -> Result> { let mut result: Vec
= vec![]; let mut path_stack = path.0.to_vec(); path_stack.reverse(); while !path_stack.is_empty() { let dir_address = fetch_or_create_dir( connection, result.last().cloned(), path_stack.pop().unwrap(), create, context.clone(), )?; result.push(dir_address); } Ok(result) } pub type ResolveCache = LruCache<(Option
, UNode), Address>; pub fn resolve_path_cached( connection: &UpEndConnection, path: &UHierPath, create: bool, context: OperationContext, cache: &Arc>, ) -> Result> { let mut result: Vec
= vec![]; let mut path_stack = path.0.to_vec(); path_stack.reverse(); while let Some(node) = path_stack.pop() { let parent = result.last().cloned(); let key = (parent.clone(), node.clone()); let mut cache_lock = cache.lock().unwrap(); let cached_address = cache_lock.get(&key); if let Some(address) = cached_address { result.push(address.clone()); } else { drop(cache_lock); let address = fetch_or_create_dir(connection, parent, node, create, context.clone())?; result.push(address.clone()); cache.lock().unwrap().put(key, address); } } Ok(result) } pub fn initialize_hier(connection: &UpEndConnection) -> Result<()> { connection.insert_entry(Entry::try_from(&*HIER_ROOT_INVARIANT)?)?; upend_insert_val!(connection, HIER_ROOT_ADDR, ATTR_LABEL, "Hierarchy Root")?; Ok(()) } #[cfg(test)] mod tests { use anyhow::Result; use crate::UpEndDatabase; use tempfile::TempDir; use super::*; #[test] fn test_unode_nonempty() { let node = "foobar".parse::(); assert!(node.is_ok()); let node = "".parse::(); assert!(node.is_err()); } #[test] fn test_path_codec() { let path = UHierPath(vec![ UNode("top".to_string()), UNode("foo".to_string()), UNode("bar".to_string()), UNode("baz".to_string()), ]); let str_path = path.to_string(); assert!(!str_path.is_empty()); let decoded_path: Result = str_path.parse(); assert!(decoded_path.is_ok()); assert_eq!(path, decoded_path.unwrap()); } #[test] fn test_path_validation() { let valid_path: Result = "a/b/c/d/e/f/g".parse(); assert!(valid_path.is_ok()); let invalid_path: Result = "a/b/c//d/e/f/g".parse(); assert!(invalid_path.is_err()); let invalid_path: Result = "a//b/c//d/e/f///g".parse(); assert!(invalid_path.is_err()); } #[test] fn test_path_manipulation() { // Initialize database let temp_dir = TempDir::new().unwrap(); let open_result = UpEndDatabase::open(&temp_dir, true).unwrap(); let connection = open_result.db.connection().unwrap(); let foo_result = fetch_or_create_dir( &connection, None, UNode("foo".to_string()), true, OperationContext::default(), ); assert!(foo_result.is_ok()); let foo_result = foo_result.unwrap(); let bar_result = fetch_or_create_dir( &connection, None, UNode("bar".to_string()), true, OperationContext::default(), ); assert!(bar_result.is_ok()); let bar_result = bar_result.unwrap(); let baz_result = fetch_or_create_dir( &connection, Some(bar_result.clone()), UNode("baz".to_string()), true, OperationContext::default(), ); assert!(baz_result.is_ok()); let baz_result = baz_result.unwrap(); let roots = list_roots(&connection); assert_eq!(roots.unwrap(), [foo_result, bar_result.clone()]); let resolve_result = resolve_path( &connection, &"bar/baz".parse().unwrap(), false, OperationContext::default(), ); assert!(resolve_result.is_ok()); assert_eq!( resolve_result.unwrap(), vec![bar_result.clone(), baz_result.clone()] ); let resolve_result = resolve_path( &connection, &"bar/baz/bax".parse().unwrap(), false, OperationContext::default(), ); assert!(resolve_result.is_err()); let resolve_result = resolve_path( &connection, &"bar/baz/bax".parse().unwrap(), true, OperationContext::default(), ); assert!(resolve_result.is_ok()); let bax_result = fetch_or_create_dir( &connection, Some(baz_result.clone()), UNode("bax".to_string()), false, OperationContext::default(), ); assert!(bax_result.is_ok()); let bax_result = bax_result.unwrap(); assert_eq!( resolve_result.unwrap(), vec![bar_result, baz_result, bax_result] ); } }