From 6fdc3e2f4848b0cfd503f398d1d0ce1f6934c8aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Ml=C3=A1dek?= Date: Sat, 16 Apr 2022 00:55:09 +0200 Subject: [PATCH] add `join` queries to the language (fixes #3) --- src/database/engine.rs | 285 ++++++++++++++++++++++++++---------- src/database/hierarchies.rs | 10 +- src/database/lang.rs | 83 +++++++---- src/database/mod.rs | 12 ++ 4 files changed, 279 insertions(+), 111 deletions(-) diff --git a/src/database/engine.rs b/src/database/engine.rs index 6ac5c4b..3a43dc2 100644 --- a/src/database/engine.rs +++ b/src/database/engine.rs @@ -1,7 +1,10 @@ +use std::collections::HashMap; +use std::iter::zip; + use super::entry::EntryValue; use super::inner::models::Entry; use super::inner::schema::data; -use super::lang::{Query, QueryComponent, QueryPart, QueryQualifier}; +use super::lang::{PatternQuery, Query, QueryComponent, QueryPart, QueryQualifier}; use crate::database::inner::models; use crate::diesel::IntoSql; use crate::diesel::RunQueryDsl; @@ -11,12 +14,11 @@ use diesel::expression::grouped::Grouped; use diesel::expression::operators::{And, Not, Or}; use diesel::sql_types::Bool; use diesel::sqlite::Sqlite; -use diesel::{debug_query, BoxableExpression, QueryDsl}; use diesel::{ r2d2::{ConnectionManager, PooledConnection}, SqliteConnection, }; -use log::trace; +use diesel::{BoxableExpression, QueryDsl}; #[derive(Debug, Clone)] pub struct QueryExecutionError(String); @@ -32,20 +34,134 @@ impl std::error::Error for QueryExecutionError {} pub fn execute( connection: &PooledConnection>, query: Query, -) -> Result> { +) -> Result, QueryExecutionError> { use crate::database::inner::schema::data::dsl::*; - let db_query = data.filter(to_sqlite_predicates(query)?); - trace!("DB query: {}", debug_query(&db_query)); - db_query.load::(connection).map_err(anyhow::Error::from) + + if let Some(predicates) = to_sqlite_predicates(query.clone())? { + let db_query = data.filter(predicates); + db_query + .load::(connection) + .map_err(|e| QueryExecutionError(e.to_string())) + } else { + match query { + Query::SingleQuery(_) => Err(QueryExecutionError( + "Forced manual evaluation of an atomic query, this should never happen.".into(), + )), + Query::MultiQuery(mq) => match mq.qualifier { + QueryQualifier::Not => Err(QueryExecutionError( + "Stopped manual evaluation at NOT sub-query due to performance limits. Please \ + rework your query." + .into(), + )), + _ => { + let subquery_results = mq + .queries + .iter() + .map(|q| execute(connection, *q.clone())) + .collect::>, QueryExecutionError>>()?; + match mq.qualifier { + QueryQualifier::Not => unreachable!(), + QueryQualifier::And => Ok(subquery_results + .into_iter() + .reduce(|acc, cur| { + acc.into_iter() + .filter(|e| { + cur.iter().map(|e| &e.identity).any(|x| x == &e.identity) + }) + .collect() + }) + .unwrap()), // TODO + QueryQualifier::Or => Ok(subquery_results.into_iter().flatten().collect()), + QueryQualifier::Join => { + let pattern_queries = mq + .queries + .into_iter() + .map(|q| match *q { + Query::SingleQuery(QueryPart::Matches(pq)) => Some(pq), + _ => None, + }) + .collect::>>(); + + if let Some(pattern_queries) = pattern_queries { + let entries = zip(pattern_queries, subquery_results) + .into_iter() + .map(|(query, results)| { + results + .into_iter() + .map(|e| EntryWithVars::new(&query, e)) + .collect::>() + }); + + let joined = entries + .reduce(|acc, cur| { + acc.into_iter() + .filter(|tested_entry| { + tested_entry.vars.iter().any(|(k1, v1)| { + cur.iter().any(|other_entry| { + other_entry + .vars + .iter() + .any(|(k2, v2)| k1 == k2 && v1 == v2) + }) + }) + }) + .collect() + }) + .unwrap(); // TODO + + Ok(joined.into_iter().map(|ev| ev.entry).collect()) + } else { + Err(QueryExecutionError( + "Cannot join on non-atomic queries.".into(), + )) + } + } + } + } + }, + } + } } -type Predicate = dyn BoxableExpression; +struct EntryWithVars { + entry: Entry, + vars: HashMap, +} -fn to_sqlite_predicates(query: Query) -> Result, QueryExecutionError> { +impl EntryWithVars { + pub fn new(query: &PatternQuery, entry: Entry) -> Self { + let mut vars = HashMap::new(); + + if let QueryComponent::Variable(Some(var_name)) = &query.entity { + vars.insert( + var_name.clone(), + crate::util::hash::b58_encode(&entry.entity), + ); + } + + if let QueryComponent::Variable(Some(var_name)) = &query.attribute { + vars.insert(var_name.clone(), entry.attribute.clone()); + } + + if let QueryComponent::Variable(Some(var_name)) = &query.value { + if let Some(value_str) = &entry.value_str { + vars.insert(var_name.clone(), value_str.clone()); + } + } + + EntryWithVars { entry, vars } + } +} + +type SqlPredicate = dyn BoxableExpression; + +type SqlResult = Option>; + +fn to_sqlite_predicates(query: Query) -> Result { match query { Query::SingleQuery(qp) => match qp { QueryPart::Matches(eq) => { - let mut subqueries: Vec> = vec![]; + let mut subqueries: Vec> = vec![]; match &eq.entity { QueryComponent::Exact(q_entity) => { @@ -63,7 +179,7 @@ fn to_sqlite_predicates(query: Query) -> Result, QueryExecutionEr QueryComponent::Contains(q_entity) => subqueries.push(Box::new( data::entity_searchable.like(format!("%{}%", q_entity)), )), - QueryComponent::Any => {} + QueryComponent::Variable(_) => {} }; match &eq.attribute { @@ -75,7 +191,7 @@ fn to_sqlite_predicates(query: Query) -> Result, QueryExecutionEr )), QueryComponent::Contains(q_attribute) => subqueries .push(Box::new(data::attribute.like(format!("%{}%", q_attribute)))), - QueryComponent::Any => {} + QueryComponent::Variable(_) => {} }; match &eq.value { @@ -95,94 +211,111 @@ fn to_sqlite_predicates(query: Query) -> Result, QueryExecutionEr })?; match first { - EntryValue::Number(_) => subqueries.push(Box::new( - data::value_num.eq_any( - q_values - .iter() - .map(|v| { - if let EntryValue::Number(n) = v { - Ok(*n) - } else { - Err(QueryExecutionError(format!("IN queries must not combine numeric and string values! ({v} is not a number)"))) - } + EntryValue::Number(_) => subqueries.push(Box::new( + data::value_num.eq_any( + q_values + .iter() + .map(|v| { + if let EntryValue::Number(n) = v { + Ok(*n) + } else { + Err(QueryExecutionError(format!( + "IN queries must not combine numeric and \ + string values! ({v} is not a number)" + ))) + } + }) + .collect::, QueryExecutionError>>()?, + ), + )), + _ => subqueries.push(Box::new( + data::value_str.eq_any( + q_values + .iter() + .map(|v| { + if let EntryValue::Number(_) = v { + Err(QueryExecutionError(format!( + "IN queries must not combine numeric and \ + string values! (Found {v})" + ))) + } else { + v.to_string().map_err(|e| { + QueryExecutionError(format!( + "failed producing sql: {e}" + )) }) - .collect::, QueryExecutionError>>()?, - ), - )), - _ => subqueries.push(Box::new( - data::value_str.eq_any( - q_values - .iter() - .map(|v| { - if let EntryValue::Number(_) = v { - Err(QueryExecutionError(format!("IN queries must not combine numeric and string values! (Found {v})"))) - } else { - v.to_string().map_err(|e| QueryExecutionError(format!("failed producing sql: {e}"))) - } - }) - .collect::, QueryExecutionError>>()?, - ), - )), - } + } + }) + .collect::, QueryExecutionError>>()?, + ), + )), + } } QueryComponent::Contains(q_value) => { subqueries.push(Box::new(data::value_str.like(format!("S%{}%", q_value)))) } - QueryComponent::Any => {} + QueryComponent::Variable(_) => {} }; match subqueries.len() { - 0 => Ok(Box::new(true.into_sql::())), - 1 => Ok(subqueries.remove(0)), + 0 => Ok(Some(Box::new(true.into_sql::()))), + 1 => Ok(Some(subqueries.remove(0))), _ => { - let mut result: Box, Box>> = + let mut result: Box, Box>> = Box::new(And::new(subqueries.remove(0), subqueries.remove(0))); while !subqueries.is_empty() { result = Box::new(And::new(result, subqueries.remove(0))); } - Ok(Box::new(result)) + Ok(Some(Box::new(result))) } } } QueryPart::Type(_) => unimplemented!("Type queries are not yet implemented."), }, Query::MultiQuery(mq) => { - let subqueries: Result>, QueryExecutionError> = mq + let mq_result = mq .queries .into_iter() .map(|sq| to_sqlite_predicates(*sq)) - .collect(); - let mut subqueries: Vec> = subqueries?; - match subqueries.len() { - 0 => Ok(Box::new(true.into_sql::())), - 1 => { - if let QueryQualifier::Not = mq.qualifier { - Ok(Box::new(Not::new(subqueries.remove(0)))) - } else { - Ok(subqueries.remove(0)) + .collect::, QueryExecutionError>>()?; + + let mq_result: Option>> = mq_result.into_iter().collect(); + + if let Some(mut subqueries) = mq_result { + match subqueries.len() { + 0 => Ok(Some(Box::new(true.into_sql::()))), + 1 => { + if let QueryQualifier::Not = mq.qualifier { + Ok(Some(Box::new(Not::new(subqueries.remove(0))))) + } else { + Ok(Some(subqueries.remove(0))) + } } + _ => match mq.qualifier { + QueryQualifier::Join => Ok(None), + QueryQualifier::And => { + let mut result: Box, Box>> = + Box::new(And::new(subqueries.remove(0), subqueries.remove(0))); + while !subqueries.is_empty() { + result = Box::new(And::new(result, subqueries.remove(0))); + } + Ok(Some(Box::new(Grouped(result)))) + } + QueryQualifier::Or => { + let mut result = + Box::new(Or::new(subqueries.remove(0), subqueries.remove(0))); + while !subqueries.is_empty() { + result = Box::new(Or::new(result, subqueries.remove(0))); + } + Ok(Some(Box::new(Grouped(result)))) + } + QueryQualifier::Not => { + Err(QueryExecutionError("NOT only takes one subquery.".into())) + } + }, } - _ => match mq.qualifier { - QueryQualifier::And => { - let mut result: Box, Box>> = - Box::new(And::new(subqueries.remove(0), subqueries.remove(0))); - while !subqueries.is_empty() { - result = Box::new(And::new(result, subqueries.remove(0))); - } - Ok(Box::new(Grouped(result))) - } - QueryQualifier::Or => { - let mut result = - Box::new(Or::new(subqueries.remove(0), subqueries.remove(0))); - while !subqueries.is_empty() { - result = Box::new(Or::new(result, subqueries.remove(0))); - } - Ok(Box::new(Grouped(result))) - } - QueryQualifier::Not => { - Err(QueryExecutionError("NOT only takes one subquery.".into())) - } - }, + } else { + Ok(None) } } } diff --git a/src/database/hierarchies.rs b/src/database/hierarchies.rs index 896e8b9..433ac90 100644 --- a/src/database/hierarchies.rs +++ b/src/database/hierarchies.rs @@ -96,7 +96,7 @@ impl PointerEntries for Vec { pub fn list_roots(connection: &UpEndConnection) -> Result> { let all_directories: Vec = connection.query(Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, + entity: QueryComponent::Variable(None), attribute: QueryComponent::Exact(IS_OF_TYPE_ATTR.into()), value: QueryComponent::Exact(HIER_ADDR.clone().into()), })))?; @@ -104,9 +104,9 @@ pub fn list_roots(connection: &UpEndConnection) -> Result> { // TODO: this is horrible let directories_with_parents: Vec
= connection .query(Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, + entity: QueryComponent::Variable(None), attribute: QueryComponent::Exact(HIER_HAS_ATTR.into()), - value: QueryComponent::Any, + value: QueryComponent::Variable(None), })))? .extract_pointers() .into_iter() @@ -142,7 +142,7 @@ pub fn fetch_or_create_dir( let matching_directories = connection .query(Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, + entity: QueryComponent::Variable(None), attribute: QueryComponent::Exact(LABEL_ATTR.into()), value: QueryComponent::Exact(directory.as_ref().clone().into()), })))? @@ -154,7 +154,7 @@ pub fn fetch_or_create_dir( .query(Query::SingleQuery(QueryPart::Matches(PatternQuery { entity: QueryComponent::Exact(parent), attribute: QueryComponent::Exact(HIER_HAS_ATTR.into()), - value: QueryComponent::Any, + value: QueryComponent::Variable(None), })))? .extract_pointers() .into_iter() diff --git a/src/database/lang.rs b/src/database/lang.rs index 226baf1..88d1985 100644 --- a/src/database/lang.rs +++ b/src/database/lang.rs @@ -22,7 +22,7 @@ where Exact(T), In(Vec), Contains(String), - Any, + Variable(Option), } #[derive(Debug, Clone, PartialEq)] @@ -99,6 +99,7 @@ pub enum QueryQualifier { And, Or, Not, + Join, } #[derive(Debug, Clone, PartialEq)] @@ -182,7 +183,14 @@ impl TryFrom<&lexpr::Value> for Query { ))) } } - lexpr::Value::Symbol(symbol) if symbol.as_ref() == "?" => Ok(QueryComponent::Any), + lexpr::Value::Symbol(symbol) if symbol.starts_with('?') => { + let var_name = symbol.strip_prefix('?').unwrap(); + Ok(QueryComponent::Variable(if var_name.is_empty() { + None + } else { + Some(var_name.into()) + })) + } _ => Ok(QueryComponent::Exact(T::try_from(value.clone())?)), } } @@ -227,7 +235,7 @@ impl TryFrom<&lexpr::Value> for Query { )) } } - "and" | "or" => { + "and" | "or" | "join" => { let (cons_vec, _) = value.clone().into_vec(); let sub_expressions = &cons_vec[1..]; let values = sub_expressions @@ -239,6 +247,7 @@ impl TryFrom<&lexpr::Value> for Query { Ok(Query::MultiQuery(MultiQuery { qualifier: match symbol.borrow() { "and" => QueryQualifier::And, + "join" => QueryQualifier::Join, _ => QueryQualifier::Or, }, queries, @@ -295,11 +304,10 @@ impl FromStr for Query { } } - #[cfg(test)] mod test { use super::*; - use anyhow::{Result}; + use anyhow::Result; #[test] fn test_matches() -> Result<()> { @@ -307,9 +315,9 @@ mod test { assert_eq!( query, Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, - attribute: QueryComponent::Any, - value: QueryComponent::Any + entity: QueryComponent::Variable(None), + attribute: QueryComponent::Variable(None), + value: QueryComponent::Variable(None) })) ); @@ -319,8 +327,8 @@ mod test { query, Query::SingleQuery(QueryPart::Matches(PatternQuery { entity: QueryComponent::Exact(address), - attribute: QueryComponent::Any, - value: QueryComponent::Any + attribute: QueryComponent::Variable(None), + value: QueryComponent::Variable(None) })) ); @@ -328,9 +336,9 @@ mod test { assert_eq!( query, Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, + entity: QueryComponent::Variable(None), attribute: QueryComponent::Exact("FOO".into()), - value: QueryComponent::Any + value: QueryComponent::Variable(None) })) ); @@ -339,8 +347,8 @@ mod test { assert_eq!( query, Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, - attribute: QueryComponent::Any, + entity: QueryComponent::Variable(None), + attribute: QueryComponent::Variable(None), value: QueryComponent::Exact(value) })) ); @@ -348,15 +356,30 @@ mod test { Ok(()) } + #[test] + fn test_joins() -> Result<()> { + let query = "(matches ?a ?b ?)".parse::()?; + assert_eq!( + query, + Query::SingleQuery(QueryPart::Matches(PatternQuery { + entity: QueryComponent::Variable(Some("a".into())), + attribute: QueryComponent::Variable(Some("b".into())), + value: QueryComponent::Variable(None) + })) + ); + + Ok(()) + } + #[test] fn test_in_parse() -> Result<()> { let query = r#"(matches ? (in "FOO" "BAR") ?)"#.parse::()?; assert_eq!( query, Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, + entity: QueryComponent::Variable(None), attribute: QueryComponent::In(vec!("FOO".into(), "BAR".into())), - value: QueryComponent::Any + value: QueryComponent::Variable(None) })) ); @@ -365,8 +388,8 @@ mod test { assert_eq!( query, Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, - attribute: QueryComponent::Any, + entity: QueryComponent::Variable(None), + attribute: QueryComponent::Variable(None), value: QueryComponent::In(values) })) ); @@ -376,8 +399,8 @@ mod test { assert_eq!( query, Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, - attribute: QueryComponent::Any, + entity: QueryComponent::Variable(None), + attribute: QueryComponent::Variable(None), value: QueryComponent::In(values) })) ); @@ -389,8 +412,8 @@ mod test { assert_eq!( query, Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, - attribute: QueryComponent::Any, + entity: QueryComponent::Variable(None), + attribute: QueryComponent::Variable(None), value: QueryComponent::In(values) })) ); @@ -401,8 +424,8 @@ mod test { assert_eq!( query, Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, - attribute: QueryComponent::Any, + entity: QueryComponent::Variable(None), + attribute: QueryComponent::Variable(None), value: QueryComponent::In(values) })) ); @@ -417,8 +440,8 @@ mod test { query, Query::SingleQuery(QueryPart::Matches(PatternQuery { entity: QueryComponent::Contains("foo".to_string()), - attribute: QueryComponent::Any, - value: QueryComponent::Any + attribute: QueryComponent::Variable(None), + value: QueryComponent::Variable(None) })) ); @@ -426,9 +449,9 @@ mod test { assert_eq!( query, Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, + entity: QueryComponent::Variable(None), attribute: QueryComponent::Contains("foo".to_string()), - value: QueryComponent::Any, + value: QueryComponent::Variable(None), })) ); @@ -436,8 +459,8 @@ mod test { assert_eq!( query, Query::SingleQuery(QueryPart::Matches(PatternQuery { - entity: QueryComponent::Any, - attribute: QueryComponent::Any, + entity: QueryComponent::Variable(None), + attribute: QueryComponent::Variable(None), value: QueryComponent::Contains("foo".to_string()) })) ); diff --git a/src/database/mod.rs b/src/database/mod.rs index aa98c14..d3f77c8 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -510,5 +510,17 @@ mod test { .unwrap(); let result = connection.query(query).unwrap(); assert_eq!(result.len(), 1); + + let query = format!( + r#"(join + (matches ?a "FLAVOUR" ?) + (matches ?a "{LABEL_ATTR}" "FOOBAR") + )"# + ) + .parse() + .unwrap(); + let result = connection.query(query).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].value, "STRANGE".into()); } }