diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 946a1f898fc..b855d472eaa 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -34,6 +34,7 @@ set(SOUFFLE_SOURCES ast/FunctionalConstraint.cpp ast/IntrinsicFunctor.cpp ast/IterationCounter.cpp + ast/Lattice.cpp ast/Negation.cpp ast/NilConstant.cpp ast/Node.cpp @@ -87,6 +88,7 @@ set(SOUFFLE_SOURCES ast/transform/GroundedTermsChecker.cpp ast/transform/GroundWitnesses.cpp ast/transform/InlineRelations.cpp + ast/transform/InsertLatticeOperations.cpp ast/transform/MagicSet.cpp ast/transform/MaterializeAggregationQueries.cpp ast/transform/MaterializeSingletonAggregation.cpp diff --git a/src/MainDriver.cpp b/src/MainDriver.cpp index 60c7c996f29..aa438f41b42 100644 --- a/src/MainDriver.cpp +++ b/src/MainDriver.cpp @@ -37,6 +37,7 @@ #include "ast/transform/IOAttributes.h" #include "ast/transform/IODefaults.h" #include "ast/transform/InlineRelations.h" +#include "ast/transform/InsertLatticeOperations.h" #include "ast/transform/MagicSet.h" #include "ast/transform/MaterializeAggregationQueries.h" #include "ast/transform/MaterializeSingletonAggregation.h" @@ -479,6 +480,7 @@ Own astTransformationPipeline(Global& glb) // Main pipeline auto pipeline = mk(mk(), mk(), + mk(), mk(), mk(), mk(), diff --git a/src/RelationTag.h b/src/RelationTag.h index 5a850972533..c6d51ca00a3 100644 --- a/src/RelationTag.h +++ b/src/RelationTag.h @@ -58,7 +58,6 @@ enum class RelationRepresentation { BTREE, // use btree data-structure BTREE_DELETE, // use btree_delete data-structure EQREL, // use union data-structure - PROVENANCE, // use custom btree data-structure with provenance extras INFO, // info relation for provenance }; @@ -168,7 +167,6 @@ inline std::ostream& operator<<(std::ostream& os, RelationRepresentation represe case RelationRepresentation::BTREE_DELETE: return os << "btree_delete"; case RelationRepresentation::BRIE: return os << "brie"; case RelationRepresentation::EQREL: return os << "eqrel"; - case RelationRepresentation::PROVENANCE: return os << "provenance"; case RelationRepresentation::INFO: return os << "info"; case RelationRepresentation::DEFAULT: return os; } diff --git a/src/ast/Attribute.cpp b/src/ast/Attribute.cpp index a06fdab5d8e..f30afa45563 100644 --- a/src/ast/Attribute.cpp +++ b/src/ast/Attribute.cpp @@ -14,7 +14,10 @@ namespace souffle::ast { Attribute::Attribute(std::string n, QualifiedName t, SrcLocation loc) - : Node(std::move(loc)), name(std::move(n)), typeName(std::move(t)) {} + : Node(std::move(loc)), name(std::move(n)), typeName(std::move(t)), isLattice(false) {} + +Attribute::Attribute(std::string n, QualifiedName t, bool isLattice, SrcLocation loc) + : Node(std::move(loc)), name(std::move(n)), typeName(std::move(t)), isLattice(isLattice) {} void Attribute::setTypeName(QualifiedName name) { typeName = std::move(name); @@ -22,15 +25,18 @@ void Attribute::setTypeName(QualifiedName name) { void Attribute::print(std::ostream& os) const { os << name << ":" << typeName; + if (isLattice) { + os << "<>"; + } } bool Attribute::equal(const Node& node) const { const auto& other = asAssert(node); - return name == other.name && typeName == other.typeName; + return name == other.name && typeName == other.typeName && isLattice == other.isLattice; } Attribute* Attribute::cloning() const { - return new Attribute(name, typeName, getSrcLoc()); + return new Attribute(name, typeName, isLattice, getSrcLoc()); } } // namespace souffle::ast diff --git a/src/ast/Attribute.h b/src/ast/Attribute.h index ecea916c63f..9466af4805c 100644 --- a/src/ast/Attribute.h +++ b/src/ast/Attribute.h @@ -35,6 +35,7 @@ namespace souffle::ast { class Attribute : public Node { public: Attribute(std::string n, QualifiedName t, SrcLocation loc = {}); + Attribute(std::string n, QualifiedName t, bool isLattice, SrcLocation loc = {}); /** Return attribute name */ const std::string& getName() const { @@ -49,6 +50,10 @@ class Attribute : public Node { /** Set type name */ void setTypeName(QualifiedName name); + bool getIsLattice() const { + return isLattice; + } + protected: void print(std::ostream& os) const override; @@ -63,6 +68,9 @@ class Attribute : public Node { /** Type name */ QualifiedName typeName; + + /** Is lattice element */ + bool isLattice; }; } // namespace souffle::ast diff --git a/src/ast/Component.cpp b/src/ast/Component.cpp index 97129a06f36..f5356fdcff9 100644 --- a/src/ast/Component.cpp +++ b/src/ast/Component.cpp @@ -39,6 +39,15 @@ std::vector Component::getTypes() const { return toPtrVector(types); } +void Component::addLattice(Own t) { + assert(t != nullptr); + lattices.push_back(std::move(t)); +} + +std::vector Component::getLattices() const { + return toPtrVector(lattices); +} + void Component::copyBaseComponents(const Component& other) { baseComponents = clone(other.baseComponents); } diff --git a/src/ast/Component.h b/src/ast/Component.h index 99befb86e92..a6debc78e24 100644 --- a/src/ast/Component.h +++ b/src/ast/Component.h @@ -20,6 +20,7 @@ #include "ast/ComponentInit.h" #include "ast/ComponentType.h" #include "ast/Directive.h" +#include "ast/Lattice.h" #include "ast/Node.h" #include "ast/Relation.h" #include "ast/Type.h" @@ -64,6 +65,11 @@ class Component : public Node { /** Get types */ std::vector getTypes() const; + /** Add lattice */ + void addLattice(Own lat); + + std::vector getLattices() const; + /** Copy base components */ void copyBaseComponents(const Component& other); @@ -129,6 +135,9 @@ class Component : public Node { /** Types declarations */ VecOwn types; + /** Types declarations */ + VecOwn lattices; + /** Relations */ VecOwn relations; diff --git a/src/ast/Lattice.cpp b/src/ast/Lattice.cpp new file mode 100644 index 00000000000..75c83b59ef6 --- /dev/null +++ b/src/ast/Lattice.cpp @@ -0,0 +1,107 @@ +/* + * Souffle - A Datalog Compiler + * Copyright (c) 2023, The Souffle Developers. All rights reserved + * Licensed under the Universal Permissive License v 1.0 as shown at: + * - https://opensource.org/licenses/UPL + * - /licenses/SOUFFLE-UPL.txt + */ + +#include "ast/Lattice.h" +#include "souffle/utility/MiscUtil.h" +#include "souffle/utility/StreamUtil.h" + +#include + +namespace souffle::ast { + +std::optional latticeOperatorFromString(const std::string& str) { + if (str == "Bottom") return Bottom; + if (str == "Top") return Top; + if (str == "Lub") return Lub; + if (str == "Glb") return Glb; + if (str == "Leq") return Leq; + return std::nullopt; +} + +std::string latticeOperatorToString(const LatticeOperator op) { + switch (op) { + case Bottom: return "Bottom"; + case Top: return "Top"; + case Lub: return "Lub"; + case Glb: return "Glb"; + case Leq: return "Leq"; + default: assert(false && "unknown lattice operator"); + } + return ""; +} + +Lattice::Lattice(QualifiedName name, std::map> ops, SrcLocation loc) + : Node(std::move(loc)), name(std::move(name)), operators(std::move(ops)) {} + +void Lattice::setQualifiedName(QualifiedName name) { + this->name = std::move(name); +} + +const std::map Lattice::getOperators() const { + std::map ops; + for (const auto& [op, arg] : operators) { + ops.emplace(std::make_pair(op, arg.get())); + } + return ops; +} + +bool Lattice::hasGlb() const { + return operators.count(Glb) > 0; +} + +bool Lattice::hasLub() const { + return operators.count(Lub) > 0; +} + +bool Lattice::hasBottom() const { + return operators.count(Bottom) > 0; +} + +bool Lattice::hasTop() const { + return operators.count(Top) > 0; +} + +const ast::Argument* Lattice::getLub() const { + return operators.at(Lub).get(); +} + +const ast::Argument* Lattice::getGlb() const { + return operators.at(Glb).get(); +} + +const ast::Argument* Lattice::getBottom() const { + return operators.at(Bottom).get(); +} + +const ast::Argument* Lattice::getTop() const { + return operators.at(Top).get(); +} + +void Lattice::print(std::ostream& os) const { + os << ".lattice " << getQualifiedName() << " {\n "; + bool first = true; + for (const auto& [op, arg] : operators) { + if (!first) { + os << ",\n "; + } + os << latticeOperatorToString(op) << " -> " << *arg; + first = false; + } + os << "\n}"; +} + +bool Lattice::equal(const Node& node) const { + const auto& other = asAssert(node); + return getQualifiedName() == other.getQualifiedName() && equal_targets(operators, other.operators); +} + +Lattice* Lattice::cloning() const { + return new Lattice(getQualifiedName(), clone(operators), getSrcLoc()); +} + +} // namespace souffle::ast diff --git a/src/ast/Lattice.h b/src/ast/Lattice.h new file mode 100644 index 00000000000..71dcb1e970f --- /dev/null +++ b/src/ast/Lattice.h @@ -0,0 +1,77 @@ +/* + * Souffle - A Datalog Compiler + * Copyright (c) 2023, The Souffle Developers. All rights reserved + * Licensed under the Universal Permissive License v 1.0 as shown at: + * - https://opensource.org/licenses/UPL + * - /licenses/SOUFFLE-UPL.txt + */ + +/************************************************************************ + * + * @file Lattice.h + * + * Defines the Lattice class + * + ***********************************************************************/ + +#pragma once + +#include "ast/Argument.h" +#include "ast/Node.h" +#include "ast/QualifiedName.h" +#include "parser/SrcLocation.h" + +#include +#include + +namespace souffle::ast { + +enum LatticeOperator { Bottom = 0, Top, Lub, Glb, Leq }; + +std::optional latticeOperatorFromString(const std::string& str); + +/** + * @class Lattice + * @brief An class to define Lattice attributes for a type + */ +class Lattice : public Node { +public: + Lattice(QualifiedName name, std::map> operators, + SrcLocation loc = {}); + + /** Return type name */ + const QualifiedName& getQualifiedName() const { + return name; + } + + /** Set type name */ + void setQualifiedName(QualifiedName name); + + const std::map getOperators() const; + + bool hasGlb() const; + bool hasLub() const; + bool hasBottom() const; + bool hasTop() const; + + const ast::Argument* getLub() const; + const ast::Argument* getGlb() const; + const ast::Argument* getBottom() const; + const ast::Argument* getTop() const; + +protected: + void print(std::ostream& os) const override; + +private: + bool equal(const Node& node) const override; + + Lattice* cloning() const override; + +private: + /** type name */ + QualifiedName name; + + const std::map> operators; +}; + +} // namespace souffle::ast diff --git a/src/ast/Program.cpp b/src/ast/Program.cpp index d67db8fe53f..623c5d2ad90 100644 --- a/src/ast/Program.cpp +++ b/src/ast/Program.cpp @@ -118,6 +118,10 @@ std::vector Program::getTypes() const { return toPtrVector(types); } +std::vector Program::getLattices() const { + return toPtrVector(lattices); +} + std::vector Program::getRelations() const { return toPtrVector(relations, &RelationInfo::decls); } @@ -231,6 +235,15 @@ void Program::addType(Own type) { types.push_back(std::move(type)); } +void Program::addLattice(Own lattice) { + assert(lattice != nullptr); + [[maybe_unused]] auto* existingLattice = getIf(getLattices(), [&](const Lattice* current) { + return current->getQualifiedName() == lattice->getQualifiedName(); + }); + assert(existingLattice == nullptr && "Redefinition of lattice!"); + lattices.push_back(std::move(lattice)); +} + void Program::addPragma(Own pragma) { assert(pragma && "NULL pragma"); pragmas.push_back(std::move(pragma)); @@ -259,6 +272,7 @@ void Program::apply(const NodeMapper& map) { mapAll(instantiations, map); mapAll(functors, map); mapAll(types, map); + mapAll(lattices, map); mapAll(relations, &RelationInfo::decls, map); mapAll(relations, &RelationInfo::clauses, map); mapAll(relations, &RelationInfo::directives, map); @@ -271,6 +285,7 @@ Node::NodeVec Program::getChildren() const { append(res, makePtrRange(instantiations)); append(res, makePtrRange(functors)); append(res, makePtrRange(types)); + append(res, makePtrRange(lattices)); append(res, relations, &RelationInfo::decls); append(res, relations, &RelationInfo::clauses); append(res, relations, &RelationInfo::directives); @@ -286,6 +301,7 @@ void Program::print(std::ostream& os) const { show(components); show(instantiations); show(types); + show(lattices); show(functors); show(getRelations()); show(getClauses(), "\n\n"); @@ -300,6 +316,7 @@ bool Program::equal(const Node& node) const { equal_targets(instantiations, other.instantiations) && equal_targets(functors, other.functors) && equal_targets(types, other.types) && + equal_targets(lattices, other.lattices) && equal_targets_map(relations, other.relations, [](auto& a, auto& b) { return equal_targets(a.decls , b.decls ) && equal_targets(a.clauses , b.clauses ) && @@ -324,6 +341,7 @@ Program* Program::cloning() const { res->components = clone(components); res->instantiations = clone(instantiations); res->types = clone(types); + res->lattices = clone(lattices); res->functors = clone(functors); res->relations = clone(relations); return res; diff --git a/src/ast/Program.h b/src/ast/Program.h index c3475ef458d..b6544e9b978 100644 --- a/src/ast/Program.h +++ b/src/ast/Program.h @@ -109,6 +109,9 @@ class Program : public Node { /** Return types */ std::vector getTypes() const; + /** Return lattices */ + std::vector getLattices() const; + /** Return relations */ std::vector getRelations() const; @@ -183,6 +186,9 @@ class Program : public Node { /** Add a type declaration */ void addType(Own type); + /** Add a lattice declaration */ + void addLattice(Own lattice); + /** * Remove a clause by identity. The clause must be owned by the program. * It is not expected that there are useful cases where some are not owned, and it is often @@ -240,6 +246,9 @@ class Program : public Node { /** Program types */ VecOwn types; + /** Program lattices */ + VecOwn lattices; + /** Program relation declartions, clauses, and directives */ RelationInfoMap relations; diff --git a/src/ast/Relation.h b/src/ast/Relation.h index 4fb1aade201..64b8f774bed 100644 --- a/src/ast/Relation.h +++ b/src/ast/Relation.h @@ -54,6 +54,15 @@ class Relation : public Node { return attributes.size(); } + /** Return the arity of this relation */ + std::size_t getAuxiliaryArity() const { + std::size_t arity = 0; + for (const auto& a : attributes) { + arity += a->getIsLattice() ? 1 : 0; + } + return arity; + } + /** Set relation attributes */ void setAttributes(VecOwn attrs); diff --git a/src/ast/analysis/Ground.cpp b/src/ast/analysis/Ground.cpp index 52d7c94778e..bfb7ba184e7 100644 --- a/src/ast/analysis/Ground.cpp +++ b/src/ast/analysis/Ground.cpp @@ -154,8 +154,24 @@ BoolDisjunctConstraint imply(const std::vector& vars, const Boo struct GroundednessAnalysis : public ConstraintAnalysis { Program& program; std::set ignore; - - GroundednessAnalysis(const TranslationUnit& tu) : program(tu.getProgram()) {} + std::map> latticeAttributes; + bool isLatticeTransformerPass; + + GroundednessAnalysis(const TranslationUnit& tu, bool isLatticeTransformerPass) + : program(tu.getProgram()), isLatticeTransformerPass(isLatticeTransformerPass) { + if (isLatticeTransformerPass) { + for (const Relation* rel : program.getRelations()) { + const auto attributes = rel->getAttributes(); + const auto& name = rel->getQualifiedName(); + for (std::size_t i = 0; i < attributes.size(); i++) { + if (attributes[i]->getIsLattice()) { + const auto type = attributes[i]->getTypeName(); + latticeAttributes[name].insert(i); + } + } + } + } + } // atoms are producing grounded variables void visit_(type_identity, const Atom& cur) override { @@ -164,9 +180,14 @@ struct GroundednessAnalysis : public ConstraintAnalysis { return; } - // all arguments are grounded - for (const auto& arg : cur.getArguments()) { - addConstraint(isTrue(getVar(arg))); + // all arguments are grounded except lattice arguments + const auto& name = cur.getQualifiedName(); + const auto& args = cur.getArguments(); + for (std::size_t i = 0; i < cur.getArity(); i++) { + if (!isLatticeTransformerPass || !latticeAttributes.count(name) || + !latticeAttributes[name].count(i)) { + addConstraint(isTrue(getVar(args[i]))); + } } } @@ -265,9 +286,10 @@ struct GroundednessAnalysis : public ConstraintAnalysis { /*** * computes for variables in the clause whether they are grounded */ -std::map getGroundedTerms(const TranslationUnit& tu, const Clause& clause) { +std::map getGroundedTerms( + const TranslationUnit& tu, const Clause& clause, bool isLatticeTransformerPass) { // run analysis on given clause - return GroundednessAnalysis(tu).analyse(clause); + return GroundednessAnalysis(tu, isLatticeTransformerPass).analyse(clause); } } // namespace souffle::ast::analysis diff --git a/src/ast/analysis/Ground.h b/src/ast/analysis/Ground.h index f6ca7333d8d..8ac8689e35d 100644 --- a/src/ast/analysis/Ground.h +++ b/src/ast/analysis/Ground.h @@ -32,6 +32,7 @@ namespace souffle::ast::analysis { * @return a map mapping each contained argument to a boolean indicating * whether the argument represents a grounded value or not */ -std::map getGroundedTerms(const TranslationUnit& tu, const Clause& clause); +std::map getGroundedTerms( + const TranslationUnit& tu, const Clause& clause, bool isLatticeTransformerPass = false); } // namespace souffle::ast::analysis diff --git a/src/ast/transform/ComponentInstantiation.cpp b/src/ast/transform/ComponentInstantiation.cpp index bf1e46b629c..69567ae0b29 100644 --- a/src/ast/transform/ComponentInstantiation.cpp +++ b/src/ast/transform/ComponentInstantiation.cpp @@ -23,6 +23,7 @@ #include "ast/ComponentInit.h" #include "ast/ComponentType.h" #include "ast/Directive.h" +#include "ast/Lattice.h" #include "ast/Node.h" #include "ast/Program.h" #include "ast/QualifiedName.h" @@ -58,6 +59,7 @@ static const unsigned int MAX_INSTANTIATION_DEPTH = 100; */ struct ComponentContent { VecOwn types; + VecOwn lattices; VecOwn relations; VecOwn directives; VecOwn clauses; @@ -77,6 +79,22 @@ struct ComponentContent { types.push_back(std::move(type)); } + void add(Own& lattice, ErrorReport& report) { + // add to result content (check existence first) + auto foundItem = + std::find_if(lattices.begin(), lattices.end(), [&](const Own& element) { + return (element->getQualifiedName() == lattice->getQualifiedName()); + }); + if (foundItem != lattices.end()) { + Diagnostic err(Diagnostic::Type::ERROR, + DiagnosticMessage("Redefinition of lattice " + toString(lattice->getQualifiedName()), + lattice->getSrcLoc()), + {DiagnosticMessage("Previous definition", (*foundItem)->getSrcLoc())}); + report.addDiagnostic(err); + } + lattices.push_back(std::move(lattice)); + } + void add(Own& rel, ErrorReport& report) { // add to result content (check existence first) auto foundItem = std::find_if(relations.begin(), relations.end(), [&](const Own& element) { @@ -164,6 +182,11 @@ void collectContent(Program& program, const Component& component, const TypeBind res.add(type, report); } + // process lattices + for (auto& lattice : content.lattices) { + res.add(lattice, report); + } + // process relations for (auto& rel : content.relations) { res.add(rel, report); @@ -256,6 +279,17 @@ void collectContent(Program& program, const Component& component, const TypeBind res.add(type, report); } + for (const auto& cur : component.getLattices()) { + // create a clone + Own lattice(clone(cur)); + + auto&& newName = binding.find(lattice->getQualifiedName()); + if (!newName.empty()) { + lattice->setQualifiedName(newName); + } + res.add(lattice, report); + } + // and the local relations // (replacing formal parameters with actual parameters) for (const auto& cur : component.getRelations()) { @@ -355,6 +389,11 @@ ComponentContent getInstantiatedContent(Program& program, const ComponentInit& c res.add(type, report); } + // add types + for (auto& lattice : nestedContent.lattices) { + res.add(lattice, report); + } + // add relations for (auto& rel : nestedContent.relations) { res.add(rel, report); @@ -389,6 +428,11 @@ ComponentContent getInstantiatedContent(Program& program, const ComponentInit& c }); } + for (const auto& cur : res.lattices) { + auto newName = componentInit.getInstanceName() + cur->getQualifiedName(); + cur->setQualifiedName(newName); + } + // update relation names std::map relationNameMapping; for (const auto& cur : res.relations) { @@ -552,6 +596,9 @@ bool ComponentInstantiationTransformer::transform(TranslationUnit& translationUn for (auto& type : content.types) { program.addType(std::move(type)); } + for (auto& lattice : content.lattices) { + program.addLattice(std::move(lattice)); + } for (auto& rel : content.relations) { program.addRelation(std::move(rel)); } diff --git a/src/ast/transform/InsertLatticeOperations.cpp b/src/ast/transform/InsertLatticeOperations.cpp new file mode 100644 index 00000000000..e9e4e405bdf --- /dev/null +++ b/src/ast/transform/InsertLatticeOperations.cpp @@ -0,0 +1,347 @@ +/* + * Souffle - A Datalog Compiler + * Copyright (c) 2023 The Souffle Developers. All rights reserved + * Licensed under the Universal Permissive License v 1.0 as shown at: + * - https://opensource.org/licenses/UPL + * - /licenses/SOUFFLE-UPL.txt + */ + +/************************************************************************ + * + * @file InsertLatticeOperations.cpp + * + * Implements AST transformation related to the support of lattices + * + ***********************************************************************/ + +#include "ast/transform/InsertLatticeOperations.h" +#include "ast/Atom.h" +#include "ast/BinaryConstraint.h" +#include "ast/Clause.h" +#include "ast/Constraint.h" +#include "ast/Negation.h" +#include "ast/QualifiedName.h" +#include "ast/UnnamedVariable.h" +#include "ast/UserDefinedFunctor.h" +#include "ast/Variable.h" +#include "ast/analysis/Ground.h" +#include "ast/utility/Visitor.h" +#include + +namespace souffle::ast::transform { + +/* + * If a non-negated body atom has a lattice argument that is grounded by other atoms and constraints, + * e.g. `R(a, b, $B(x))`, then: + * - Given that `$B(x)` is a lattice argument: + * `R(a, b, $B(x))` holds <=> the current value `lat` in `R(a, b, lat)` is greater or equal to `$B(x)` in + * the lattice Consequently, we must translate the atom `R(a, ,b $B(x))` into `R(a, b, lat), GLB(lat, $B(x)) = + * $B(x)` or `R(a, b, lat), LEQ($B(x), lat) != 0`. + */ +struct ReplaceGroundedLatticeArguments : public NodeMapper { + ReplaceGroundedLatticeArguments(std::map& toReplace, + VecOwn& newConstraints) + : toReplace(toReplace), newConstraints(newConstraints){}; + + Own operator()(Own node) const override { + if (auto arg = as(node)) { + if (toReplace.count(arg) > 0) { + Point p = arg->getSrcLoc().start; + std::string lat = ""; + // add Glb constraint + assert(toReplace.count(arg) > 0); + const ast::Lattice* lattice = toReplace.at(arg); + assert(lattice->hasGlb()); + auto glbName = as(lattice->getGlb())->getName(); + VecOwn args; + args.push_back(mk(lat, arg->getSrcLoc())); + args.push_back(clone(arg)); + auto glb = mk(glbName, std::move(args), arg->getSrcLoc()); + newConstraints.push_back(mk( + BinaryConstraintOp::EQ, std::move(glb), clone(arg), arg->getSrcLoc())); + return mk(lat, arg->getSrcLoc()); + } else { + return node; + } + } + node->apply(*this); + return node; + } + +private: + std::map toReplace; + VecOwn& newConstraints; +}; + +struct varInfo { + const ast::Lattice* lattice; + // lattice variables in body atoms + std::set variables; + // set of newly introduced variable names + std::set glbs; +}; + +/* + * If a non-negated body atom has a lattice argument `x` that is not grounded, + * e.g. `R(a, b, x)`: + * Then, we gather all occurences of `x` that appear as lattice arguments, + * replace these occurences with distinct variable names `x_1`, `x_2`, ... `x_n` + * and we compute the GLB of all these occurences and store in variable `x`, + * and finally add the constraint `x != BOTTOM`. + * If the variable `x` is used in the atom head, it can remain unchanged. + * If some literals have constraints on `x`, e.g. `@myfunctor(x) > 0`: + * This means we would need to find the greatest lower bound of `x_1`, `x_2`, ... `x_n` + * verifying the constraints, which is not possible => the clause cannot be translated. + */ +struct ReplaceUngroundedLatticeArguments : public NodeMapper { + ReplaceUngroundedLatticeArguments(varInfo& infos) : infos(infos){}; + + Own operator()(Own node) const override { + if (auto arg = as(node)) { + if (infos.variables.count(arg) > 0) { + std::string varName; + varName = arg->getName() + "_"; + infos.glbs.insert(varName); + return mk(varName, arg->getSrcLoc()); + } + } + node->apply(*this); + return node; + } + +private: + varInfo& infos; +}; + +/* + * If a negated body atom has lattice argument `arg1, ..., argN`, + * e.g. `!R(a, b, arg1, arg2)`, then + * - the arguments `arg1`, ... `argN` must be bound by other atoms + * - We must translate the negated atom `!R(a, b, arg1, arg2)` + * into the disjunction : + * ( + * !R(a, b, _, _) ; + * R(a, b, lat1, _), GLB(lat1, arg1) = BOTTOM; + * R(a, b, _, lat2), GLB(lat2, arg2) = BOTTOM + * ) + */ + +RuleBody LatticeTransformer::translateNegatedAtom(ast::Atom& atom) { + auto args = atom.getArguments(); + auto range = latticeAttributes.equal_range(atom.getQualifiedName()); + + std::map latticeIndexes; + for (auto it = range.first; it != range.second; it++) { + latticeIndexes.insert(it->second); + } + + std::size_t arity = atom.getArity(); + VecOwn negatedArgs; + for (std::size_t i = 0; i < arity; i++) { + if (latticeIndexes.count(i)) { + negatedArgs.push_back(mk()); + } else { + negatedArgs.push_back(clone(args[i])); + } + } + auto negated = mk(atom.getQualifiedName(), std::move(negatedArgs), atom.getSrcLoc()); + RuleBody body = RuleBody::atom(std::move(negated)).negated(); + + // create one disjunct per argument + for (const auto& [index, type] : latticeIndexes) { + const ast::Lattice* lattice = lattices.at(type); + VecOwn arguments; + const auto& arg = args[index]; + if (isA(arg)) { + continue; + } + const auto& sloc = arg->getSrcLoc(); + std::string lat = + ""; + for (std::size_t i = 0; i < arity; i++) { + if (latticeIndexes.count(i)) { + if (i == index) { + arguments.push_back(mk(lat)); + } else { + arguments.push_back(mk()); + } + } else { + arguments.push_back(clone(args[i])); + } + } + VecOwn glbArgs; + glbArgs.push_back(mk(lat, sloc)); + glbArgs.push_back(clone(arg)); + auto glbName = as(lattice->getGlb())->getName(); + auto glb = mk(glbName, std::move(glbArgs), sloc); + auto constraint = mk( + BinaryConstraintOp::EQ, std::move(glb), clone(lattice->getBottom()), sloc); + auto conjunct = + RuleBody::atom(mk(atom.getQualifiedName(), std::move(arguments), atom.getSrcLoc())); + conjunct.conjunct(RuleBody::constraint(std::move(constraint))); + body.disjunct(std::move(conjunct)); + } + + return body; +} + +bool LatticeTransformer::translateClause( + TranslationUnit& translationUnit, ErrorReport& report, ast::Clause* clause) { + bool changed; + + // Set of atoms that are negated in the clause body and contain lattice arguments + std::set negatedLatticeAtoms; + // set of lattice arguments that are grounded, and their corresponding lattice + std::map groundedLatticeArguments; + // stores information on ungrounded lattice variable names + std::map ungroundedLatticeArguments; + + // Compute grounded/ungrounded arguments + auto isGrounded = analysis::getGroundedTerms(translationUnit, *clause, true); + + // Identify which body atoms are negated + std::set negated; + visit(clause->getBodyLiterals(), [&](const Negation& negation) { negated.insert(negation.getAtom()); }); + + // Identify lattice arguments in body atoms + visit(clause->getBodyLiterals(), [&](const Atom& atom) { + auto range = latticeAttributes.equal_range(atom.getQualifiedName()); + auto args = atom.getArguments(); + for (auto it = range.first; it != range.second; it++) { + const ast::Argument* arg = args[it->second.first]; + assert(lattices.count(it->second.second) > 0); + const ast::Lattice* lattice = lattices.at(it->second.second); + if (isA(arg)) { + // nothing to do + } else if (negated.count(&atom)) { + negatedLatticeAtoms.insert(&atom); + } else if (isGrounded[arg]) { + groundedLatticeArguments.insert(std::make_pair(arg, lattice)); + } else if (const auto* var = as(arg)) { + ungroundedLatticeArguments[var->getName()].variables.insert(var); + ungroundedLatticeArguments[var->getName()].lattice = lattice; + } else { + report.addError("Lattice argument is not grounded", arg->getSrcLoc()); + } + } + }); + + VecOwn constraints; + + // Update grounded lattice arguments + ReplaceGroundedLatticeArguments update(groundedLatticeArguments, constraints); + for (auto* literal : clause->getBodyLiterals()) { + literal->apply(update); + } + + // Update ungrounded lattice arguments + for (auto& it : ungroundedLatticeArguments) { + auto& infos = it.second; + auto& name = it.first; + if (infos.variables.size() > 1) { + changed = true; + + // First, we must make sure that the ungrounded variables are only used as lattice arguments + visit(clause->getBodyLiterals(), [&](const Variable& variable) { + if (variable.getName() == name && !infos.variables.count(&variable)) { + report.addError("Ungrounded lattice variable cannot be used in other literals", + variable.getSrcLoc()); + } + }); + + ReplaceUngroundedLatticeArguments update(infos); + for (auto* literal : clause->getBodyLiterals()) { + literal->apply(update); + } + // create new constraints + Own glb; + const auto& lattice = infos.lattice; + assert(lattice->hasGlb()); + auto glbName = as(lattice->getGlb())->getName(); + for (const std::string& name : infos.glbs) { + if (!glb) { + // first element + glb = mk(name, clause->getSrcLoc()); + } else { + VecOwn args; + args.push_back(mk(name, clause->getSrcLoc())); + args.push_back(std::move(glb)); + glb = mk(glbName, std::move(args), clause->getSrcLoc()); + } + } + // x = GLB(...) + constraints.push_back(mk(BinaryConstraintOp::EQ, + mk(name, clause->getSrcLoc()), std::move(glb), clause->getSrcLoc())); + // x != BOTTOM + constraints.push_back( + mk(BinaryConstraintOp::NE, mk(name, clause->getSrcLoc()), + clone(lattice->getBottom()), clause->getSrcLoc())); + } + } + clause->addToBody(std::move(constraints)); + + // updated negated atoms containing lattice arguments + if (!negatedLatticeAtoms.empty()) { + RuleBody body = RuleBody::getTrue(); + for (const auto* lit : clause->getBodyLiterals()) { + if (const auto& neg = as(lit)) { + body.conjunct(translateNegatedAtom(*neg->getAtom())); + } else if (const auto& atom = as(lit)) { + body.conjunct(RuleBody::atom(clone(atom))); + } else if (const auto& cst = as(lit)) { + body.conjunct(RuleBody::constraint(clone(cst))); + } else { + assert(false && "unreachable"); + } + } + VecOwn clauses = body.toClauseBodies(); + for (auto& c : clauses) { + c->setHead(clone(clause->getHead())); + } + translationUnit.getProgram().addClauses(std::move(clauses)); + translationUnit.getProgram().removeClause(*clause); + } + return changed; +} + +bool LatticeTransformer::transform(TranslationUnit& translationUnit) { + bool changed = false; + Program& program = translationUnit.getProgram(); + ErrorReport& report = translationUnit.getErrorReport(); + + if (program.getLattices().empty()) { + return changed; + } + + // populates map type name -> lattice + for (const ast::Lattice* lattice : program.getLattices()) { + // We ignore lattices lacking a definition for Glb and Bottom + if (lattice->hasGlb() && lattice->hasBottom()) { + lattices.emplace(lattice->getQualifiedName(), lattice); + } + } + + for (const Relation* rel : program.getRelations()) { + const auto attributes = rel->getAttributes(); + for (std::size_t i = 0; i < attributes.size(); i++) { + if (attributes[i]->getIsLattice()) { + const auto type = attributes[i]->getTypeName(); + if (lattices.count(type)) { + latticeAttributes.emplace( + std::make_pair(rel->getQualifiedName(), std::make_pair(i, type))); + } + } + } + } + if (latticeAttributes.empty()) { + return false; + } + + for (Clause* clause : program.getClauses()) { + changed |= translateClause(translationUnit, report, clause); + } + + return changed; +} + +} // namespace souffle::ast::transform diff --git a/src/ast/transform/InsertLatticeOperations.h b/src/ast/transform/InsertLatticeOperations.h new file mode 100644 index 00000000000..aecb19a9ba0 --- /dev/null +++ b/src/ast/transform/InsertLatticeOperations.h @@ -0,0 +1,52 @@ +/* + * Souffle - A Datalog Compiler + * Copyright (c) 2023 The Souffle Developers. All rights reserved + * Licensed under the Universal Permissive License v 1.0 as shown at: + * - https://opensource.org/licenses/UPL + * - /licenses/SOUFFLE-UPL.txt + */ + +/************************************************************************ + * + * @file InsertLatticeOperations.h + * + * AST transformation related to the support of lattices + * + ***********************************************************************/ + +#pragma once + +#include "ast/QualifiedName.h" +#include "ast/TranslationUnit.h" +#include "ast/transform/Transformer.h" +#include "parser/ParserUtils.h" +#include +#include + +namespace souffle::ast::transform { + +/** + * Transformation pass to insert GLBs and checks != Bottom on lattice type attributes + */ +class LatticeTransformer : public Transformer { +public: + std::string getName() const override { + return "InlineRelationsTransformer"; + } + +private: + LatticeTransformer* cloning() const override { + return new LatticeTransformer(); + } + + bool translateClause(TranslationUnit& translationUnit, ErrorReport& report, ast::Clause* clause); + + bool transform(TranslationUnit& translationUnit) override; + + RuleBody translateNegatedAtom(ast::Atom& atom); + + std::multimap> latticeAttributes; + std::map lattices; +}; + +} // namespace souffle::ast::transform diff --git a/src/ast/transform/SemanticChecker.cpp b/src/ast/transform/SemanticChecker.cpp index 0e076159130..8d8c7ee9ba2 100644 --- a/src/ast/transform/SemanticChecker.cpp +++ b/src/ast/transform/SemanticChecker.cpp @@ -38,6 +38,7 @@ #include "ast/IntrinsicAggregator.h" #include "ast/IntrinsicFunctor.h" #include "ast/IterationCounter.h" +#include "ast/Lattice.h" #include "ast/Literal.h" #include "ast/Negation.h" #include "ast/NilConstant.h" @@ -124,6 +125,7 @@ struct SemanticCheckerImpl { void checkRelationFunctionalDependencies(const Relation& relation); void checkRelation(const Relation& relation); void checkType(ast::Attribute const& attr, std::string const& name = {}); + void checkLatticeDeclaration(const Lattice& lattice); void checkFunctorDeclaration(const FunctorDeclaration& decl); void checkNamespaces(); @@ -173,6 +175,9 @@ SemanticCheckerImpl::SemanticCheckerImpl(TranslationUnit& tu) : tu(tu) { for (auto* rel : program.getRelations()) { checkRelation(*rel); } + for (auto* lattice : program.getLattices()) { + checkLatticeDeclaration(*lattice); + } for (auto* clause : program.getClauses()) { checkClause(*clause); } @@ -244,10 +249,10 @@ void SemanticCheckerImpl::checkAtom(const Clause& parent, const Atom& atom) { report.addError("Undefined relation " + toString(atom.getQualifiedName()), atom.getSrcLoc()); return; } - - if (r->getArity() != atom.getArity()) { + std::size_t arity = r->getArity(); + if (arity != atom.getArity()) { report.addError("Mismatching arity of relation " + toString(atom.getQualifiedName()) + " (expected " + - toString(r->getArity()) + ", got " + toString(atom.getArity()) + ")", + toString(arity) + ", got " + toString(atom.getArity()) + ")", atom.getSrcLoc()); } @@ -658,11 +663,43 @@ void SemanticCheckerImpl::checkFunctorDeclaration(const FunctorDeclaration& decl } } +void SemanticCheckerImpl::checkLatticeDeclaration(const Lattice& lattice) { + const auto& name = lattice.getQualifiedName(); + auto* existingType = getIf( + program.getTypes(), [&](const ast::Type* type) { return type->getQualifiedName() == name; }); + if (!existingType) { + report.addError(tfm::format("Undefined type %s", name), lattice.getSrcLoc()); + } + if (lattice.hasLub()) { + if (!isA(lattice.getLub())) { + report.addError( + tfm::format("Lattice operator Lub must be a user-defined functor"), lattice.getSrcLoc()); + } + } else { + report.addError(tfm::format("Lattice %s<> does not define Lub", name), lattice.getSrcLoc()); + } + if (lattice.hasGlb()) { + if (!isA(lattice.getGlb())) { + report.addError( + tfm::format("Lattice operator Glb must be a user-defined functor"), lattice.getSrcLoc()); + } + } else { + report.addWarning(WarnType::LatticeMissingOperator, + tfm::format("Lattice %s<> does not define Glb", name), lattice.getSrcLoc()); + } + if (!lattice.hasBottom()) { + report.addWarning(WarnType::LatticeMissingOperator, + tfm::format("Lattice %s<> does not define Bottom", name), lattice.getSrcLoc()); + } +} + void SemanticCheckerImpl::checkRelationDeclaration(const Relation& relation) { const auto& attributes = relation.getAttributes(); - assert(attributes.size() == relation.getArity() && "mismatching attribute size and arity"); + const std::size_t arity = relation.getArity(); + std::size_t firstAuxiliary = arity - relation.getAuxiliaryArity(); - for (std::size_t i = 0; i < relation.getArity(); i++) { + assert(attributes.size() == arity && "mismatching attribute size and arity"); + for (std::size_t i = 0; i < arity; i++) { Attribute* attr = attributes[i]; checkType(*attr); @@ -672,6 +709,25 @@ void SemanticCheckerImpl::checkRelationDeclaration(const Relation& relation) { report.addError(tfm::format("Doubly defined attribute name %s", *attr), attr->getSrcLoc()); } } + + /* check that lattice elements are always the last */ + if (i < firstAuxiliary && attr->getIsLattice()) { + report.addError( + tfm::format( + "Lattice attribute %s should be placed after all non-lattice attributes", *attr), + attr->getSrcLoc()); + } + + /* check that lattice attributes have a correct lattice definition */ + if (attr->getIsLattice()) { + const auto& typeName = attr->getTypeName(); + auto* existingType = getIf(program.getLattices(), + [&](const ast::Lattice* lattice) { return lattice->getQualifiedName() == typeName; }); + if (!existingType) { + report.addError( + tfm::format("Missing lattice definition for type %s", typeName), attr->getSrcLoc()); + } + } } } @@ -728,7 +784,21 @@ void SemanticCheckerImpl::checkRelation(const Relation& relation) { } if (relation.getRepresentation() == RelationRepresentation::BTREE_DELETE && relation.getArity() == 0) { report.addError("Subsumptive relation \"" + toString(relation.getQualifiedName()) + - "\" must not be a nullary relation", + "\" must not be a nullary relation", + relation.getSrcLoc()); + } + + if (hasSubsumptiveRule && relation.getAuxiliaryArity()) { + report.addError("Relation \"" + toString(relation.getQualifiedName()) + + "\" must not have both subsumptive rules and lattice arguments", + relation.getSrcLoc()); + } + + if (relation.getAuxiliaryArity() && + (relation.getRepresentation() != RelationRepresentation::BTREE && + relation.getRepresentation() != RelationRepresentation::DEFAULT)) { + report.addError( + "Relation \"" + toString(relation.getQualifiedName()) + "\" must have a btree representation", relation.getSrcLoc()); } diff --git a/src/ast/utility/Visitor.h b/src/ast/utility/Visitor.h index 1b07687c275..beb08643e6d 100644 --- a/src/ast/utility/Visitor.h +++ b/src/ast/utility/Visitor.h @@ -39,6 +39,7 @@ #include "ast/IntrinsicAggregator.h" #include "ast/IntrinsicFunctor.h" #include "ast/IterationCounter.h" +#include "ast/Lattice.h" #include "ast/Literal.h" #include "ast/Negation.h" #include "ast/NilConstant.h" @@ -185,6 +186,7 @@ struct Visitor : souffle::detail::VisitorBase { SOUFFLE_VISITOR_LINK(Relation, Node); SOUFFLE_VISITOR_LINK(Pragma, Node); SOUFFLE_VISITOR_LINK(FunctorDeclaration, Node); + SOUFFLE_VISITOR_LINK(ast::Lattice, Node); }; } // namespace souffle::ast diff --git a/src/ast2ram/ClauseTranslator.h b/src/ast2ram/ClauseTranslator.h index 93021f514f6..d157804276d 100644 --- a/src/ast2ram/ClauseTranslator.h +++ b/src/ast2ram/ClauseTranslator.h @@ -39,6 +39,8 @@ class TranslatorContext; enum TranslationMode { DEFAULT, + Auxiliary, + // Subsumptive clauses // // R(x0) <= R(x1) :- body. diff --git a/src/ast2ram/provenance/UnitTranslator.cpp b/src/ast2ram/provenance/UnitTranslator.cpp index 6c105d882d7..1691422c10d 100644 --- a/src/ast2ram/provenance/UnitTranslator.cpp +++ b/src/ast2ram/provenance/UnitTranslator.cpp @@ -24,12 +24,14 @@ #include "ast2ram/utility/TranslatorContext.h" #include "ast2ram/utility/Utils.h" #include "ast2ram/utility/ValueIndex.h" +#include "ram/AbstractOperator.h" #include "ram/Call.h" #include "ram/DebugInfo.h" #include "ram/ExistenceCheck.h" #include "ram/Expression.h" #include "ram/Filter.h" #include "ram/Insert.h" +#include "ram/IntrinsicOperator.h" #include "ram/LogRelationTimer.h" #include "ram/MergeExtend.h" #include "ram/Negation.h" @@ -69,18 +71,12 @@ Own UnitTranslator::generateProgram(const ast::TranslationUnit& t Own UnitTranslator::createRamRelation( const ast::Relation* baseRelation, std::string ramRelationName) const { - auto arity = baseRelation->getArity(); + auto relation = seminaive::UnitTranslator::createRamRelation(baseRelation, ramRelationName); - // All relations in a provenance program should have a provenance data structure - auto representation = RelationRepresentation::PROVENANCE; - - // Add in base relation information - std::vector attributeNames; - std::vector attributeTypeQualifiers; - for (const auto& attribute : baseRelation->getAttributes()) { - attributeNames.push_back(attribute->getName()); - attributeTypeQualifiers.push_back(context->getAttributeTypeQualifier(attribute->getTypeName())); - } + std::size_t arity = relation->getArity(); + std::size_t auxiliaryArity = relation->getAuxiliaryArity(); + std::vector attributeNames = relation->getAttributeNames(); + std::vector attributeTypeQualifiers = relation->getAttributeTypes(); // Add in provenance information attributeNames.push_back("@rule_number"); @@ -89,8 +85,8 @@ Own UnitTranslator::createRamRelation( attributeNames.push_back("@level_number"); attributeTypeQualifiers.push_back("i:number"); - return mk( - ramRelationName, arity + 2, 2, attributeNames, attributeTypeQualifiers, representation); + return mk(ramRelationName, arity + 2, auxiliaryArity + 2, attributeNames, + attributeTypeQualifiers, relation->getRepresentation()); } std::string UnitTranslator::getInfoRelationName(const ast::Clause* clause) const { diff --git a/src/ast2ram/seminaive/UnitTranslator.cpp b/src/ast2ram/seminaive/UnitTranslator.cpp index 07b4fd9104c..164278f077f 100644 --- a/src/ast2ram/seminaive/UnitTranslator.cpp +++ b/src/ast2ram/seminaive/UnitTranslator.cpp @@ -20,12 +20,14 @@ #include "ast/Relation.h" #include "ast/SubsumptiveClause.h" #include "ast/TranslationUnit.h" +#include "ast/UserDefinedFunctor.h" #include "ast/analysis/TopologicallySortedSCCGraph.h" #include "ast/utility/Utils.h" #include "ast/utility/Visitor.h" #include "ast2ram/ClauseTranslator.h" #include "ast2ram/utility/TranslatorContext.h" #include "ast2ram/utility/Utils.h" +#include "ram/Aggregate.h" #include "ram/Assign.h" #include "ram/Call.h" #include "ram/Clear.h" @@ -60,7 +62,10 @@ #include "ram/Swap.h" #include "ram/TranslationUnit.h" #include "ram/TupleElement.h" +#include "ram/UndefValue.h" #include "ram/UnsignedConstant.h" +#include "ram/UserDefinedAggregator.h" +#include "ram/UserDefinedOperator.h" #include "ram/Variable.h" #include "ram/utility/Utils.h" #include "reports/DebugReport.h" @@ -113,7 +118,8 @@ Own UnitTranslator::generateNonRecursiveRelation(const ast::Rela } // Translate clause - Own rule = context->translateNonRecursiveClause(*clause); + TranslationMode mode = rel.getAuxiliaryArity() > 0 ? Auxiliary : DEFAULT; + Own rule = context->translateNonRecursiveClause(*clause, mode); // Add logging if (glb->config().has("profile")) { @@ -176,8 +182,16 @@ Own UnitTranslator::generateStratum(std::size_t scc) const { const auto* rel = *sccRelations.begin(); appendStmt(current, generateNonRecursiveRelation(*rel)); + // lub auxiliary arities using the @lub relation + if (rel->getAuxiliaryArity() > 0) { + std::string mainRelation = getConcreteRelationName(rel->getQualifiedName()); + std::string newRelation = getNewRelationName(rel->getQualifiedName()); + appendStmt(current, generateStratumLubSequence(*rel, newRelation, mainRelation)); + appendStmt(current, mk(newRelation)); + } + // issue delete sequence for non-recursive subsumptions - appendStmt(current, generateNonRecursiveDelete(sccRelations)); + appendStmt(current, generateNonRecursiveDelete(*rel)); } // Get all non-recursive relation statements @@ -404,50 +418,47 @@ VecOwn UnitTranslator::generateClauseVersions( return clauseVersions; } -Own UnitTranslator::generateNonRecursiveDelete(const ast::RelationSet& scc) const { +Own UnitTranslator::generateNonRecursiveDelete(const ast::Relation& rel) const { VecOwn code; // Generate code for non-recursive subsumption - for (const ast::Relation* rel : scc) { - if (!context->hasSubsumptiveClause(rel->getQualifiedName())) { + if (!context->hasSubsumptiveClause(rel.getQualifiedName())) { + return mk(std::move(code)); + } + + std::string mainRelation = getConcreteRelationName(rel.getQualifiedName()); + std::string deleteRelation = getDeleteRelationName(rel.getQualifiedName()); + + // Compute subsumptive deletions for non-recursive rules + for (auto clause : context->getProgram()->getClauses(rel)) { + if (!isA(clause)) { continue; } - std::string mainRelation = getConcreteRelationName(rel->getQualifiedName()); - std::string deleteRelation = getDeleteRelationName(rel->getQualifiedName()); - - // Compute subsumptive deletions for non-recursive rules - for (auto clause : context->getProgram()->getClauses(*rel)) { - if (!isA(clause)) { - continue; - } + // Translate subsumptive clause + Own rule = context->translateNonRecursiveClause(*clause, SubsumeDeleteCurrentCurrent); - // Translate subsumptive clause - Own rule = - context->translateNonRecursiveClause(*clause, SubsumeDeleteCurrentCurrent); - - // Add logging for subsumptive clause - if (glb->config().has("profile")) { - const std::string& relationName = toString(rel->getQualifiedName()); - const auto& srcLocation = clause->getSrcLoc(); - const std::string clauseText = stringify(toString(*clause)); - const std::string logTimerStatement = - LogStatement::tNonrecursiveRule(relationName, srcLocation, clauseText); - rule = mk(std::move(rule), logTimerStatement, mainRelation); - } + // Add logging for subsumptive clause + if (glb->config().has("profile")) { + const std::string& relationName = toString(rel.getQualifiedName()); + const auto& srcLocation = clause->getSrcLoc(); + const std::string clauseText = stringify(toString(*clause)); + const std::string logTimerStatement = + LogStatement::tNonrecursiveRule(relationName, srcLocation, clauseText); + rule = mk(std::move(rule), logTimerStatement, mainRelation); + } - // Add debug info for subsumptive clause - std::ostringstream ds; - ds << toString(*clause) << "\nin file "; - ds << clause->getSrcLoc(); - rule = mk(std::move(rule), ds.str()); + // Add debug info for subsumptive clause + std::ostringstream ds; + ds << toString(*clause) << "\nin file "; + ds << clause->getSrcLoc(); + rule = mk(std::move(rule), ds.str()); - // Add subsumptive rule to result - appendStmt(code, std::move(rule)); - } - appendStmt(code, mk(generateEraseTuples(rel, mainRelation, deleteRelation), - mk(deleteRelation))); + // Add subsumptive rule to result + appendStmt(code, std::move(rule)); } + appendStmt(code, mk(generateEraseTuples(&rel, mainRelation, deleteRelation), + mk(deleteRelation))); return mk(std::move(code)); } @@ -459,11 +470,16 @@ Own UnitTranslator::generateStratumPreamble(const ast::RelationS std::string deltaRelation = getDeltaRelationName(rel->getQualifiedName()); std::string mainRelation = getConcreteRelationName(rel->getQualifiedName()); appendStmt(preamble, generateNonRecursiveRelation(*rel)); + // lub tuples using the @lub relation + if (rel->getAuxiliaryArity() > 0) { + std::string newRelation = getNewRelationName(rel->getQualifiedName()); + appendStmt(preamble, generateStratumLubSequence(*rel, newRelation, mainRelation)); + appendStmt(preamble, mk(newRelation)); + } + // Generate non recursive delete sequences for subsumptive rules + appendStmt(preamble, generateNonRecursiveDelete(*rel)); } - // Generate non recursive delete sequences for subsumptive rules - appendStmt(preamble, generateNonRecursiveDelete(scc)); - // Generate code for priming relation for (const ast::Relation* rel : scc) { std::string deltaRelation = getDeltaRelationName(rel->getQualifiedName()); @@ -503,7 +519,11 @@ Own UnitTranslator::generateStratumTableUpdates(const ast::Relat // swap new and and delta relation and clear new relation afterwards (if not a subsumptive relation) Own updateRelTable; - if (!context->hasSubsumptiveClause(rel->getQualifiedName())) { + if (rel->getAuxiliaryArity() > 0) { + updateRelTable = mk(mk(deltaRelation), + generateStratumLubSequence(*rel, newRelation, deltaRelation), + generateMergeRelations(rel, mainRelation, deltaRelation)); + } else if (!context->hasSubsumptiveClause(rel->getQualifiedName())) { updateRelTable = mk(generateMergeRelations(rel, mainRelation, newRelation), mk(deltaRelation, newRelation), mk(newRelation)); } else { @@ -563,6 +583,120 @@ Own UnitTranslator::generateStratumLoopBody(const ast::RelationS return mk(std::move(loopBody)); } +/// assuming the @new() relation is populated with new tuples, generate RAM code +/// to populate the @delta() relation with the lubbed elements from @new() +Own UnitTranslator::generateStratumLubSequence( + const ast::Relation& rel, const std::string& fromName, const std::string& toName) const { + VecOwn stmts; + assert(rel.getAuxiliaryArity() > 0); + + auto attributes = rel.getAttributes(); + std::string name = getConcreteRelationName(rel.getQualifiedName()); + std::string lubName = getLubRelationName(rel.getQualifiedName()); + + const std::size_t arity = rel.getArity(); + const std::size_t auxiliaryArity = rel.getAuxiliaryArity(); + + // Step 1 : populate @lub() from @new() + VecOwn values; + + // index of the first auxiliary element of the relation + std::size_t firstAuxiliary = arity - auxiliaryArity; + + for (std::size_t i = 0; i < arity; i++) { + if (i >= firstAuxiliary) { + values.push_back(mk(i - firstAuxiliary + 1, 0)); + } else { + values.push_back(mk(0, i)); + } + } + Own op = mk(lubName, std::move(values)); + + for (std::size_t i = arity; i >= firstAuxiliary + 1; i--) { + const auto type = attributes[i - 1]->getTypeName(); + std::size_t level = i - firstAuxiliary; + auto aggregator = context->getLatticeTypeLubAggregator(type, mk(0, i - 1)); + Own condition = mk( + BinaryConstraintOp::NE, mk(level, i - 1), mk(0, i - 1)); + for (std::size_t j = 0; j < attributes.size(); j++) { + if (attributes[j]->getIsLattice()) break; + condition = mk(std::move(condition), + mk(BinaryConstraintOp::EQ, mk(level, j), + mk(0, j))); + } + op = mk(std::move(op), std::move(aggregator), fromName, + mk(level, i - 1), std::move(condition), level); + } + + op = mk(fromName, 0, std::move(op)); + appendStmt(stmts, mk(std::move(op))); + + // clear @new() now that we no longer need it + appendStmt(stmts, mk(fromName)); + + // Step 2 : populate @delta() from @lub() for tuples that have to be lubbed with @concrete + + Own condition; + for (std::size_t i = 0; i < arity; i++) { + if (i < firstAuxiliary) { + values.push_back(mk(0, i)); + } else { + assert(attributes[i]->getIsLattice()); + const auto type = attributes[i]->getTypeName(); + VecOwn args; + args.push_back(mk(0, i)); + args.push_back(mk(1, i)); + auto lub = context->getLatticeTypeLubFunctor(type, std::move(args)); + auto cst = mk(BinaryConstraintOp::EQ, mk(1, i), clone(lub)); + if (condition) { + condition = mk(std::move(condition), std::move(cst)); + } else { + condition = std::move(cst); + } + values.push_back(std::move(lub)); + } + } + op = mk(toName, std::move(values)); + op = mk(mk(std::move(condition)), std::move(op)); + + for (std::size_t i = 0; i < arity - auxiliaryArity; i++) { + auto cst = mk( + BinaryConstraintOp::EQ, mk(0, i), mk(1, i)); + if (condition) { + condition = mk(std::move(condition), std::move(cst)); + } else { + condition = std::move(cst); + } + } + if (condition) { + op = mk(std::move(condition), std::move(op)); + } + op = mk(name, 1, std::move(op)); + op = mk(lubName, 0, std::move(op)); + appendStmt(stmts, mk(std::move(op))); + + // Step 3 : populate @delta() from @lub() for tuples that have nothing to lub in @concrete + for (std::size_t i = 0; i < arity; i++) { + values.push_back(mk(0, i)); + } + op = mk(toName, std::move(values)); + for (std::size_t i = 0; i < arity; i++) { + if (i < firstAuxiliary) { + values.push_back(mk(0, i)); + } else { + values.push_back(mk()); + } + } + op = mk(mk(mk(name, std::move(values))), std::move(op)); + op = mk(lubName, 0, std::move(op)); + + appendStmt(stmts, mk(std::move(op))); + + appendStmt(stmts, mk(lubName)); + + return mk(std::move(stmts)); +} + Own UnitTranslator::generateStratumExitSequence(const ast::RelationSet& scc) const { // Helper function to add a new term to a conjunctive condition auto addCondition = [&](Own& cond, Own term) { @@ -688,6 +822,10 @@ Own UnitTranslator::generateStoreRelation(const ast::Relation* r Own UnitTranslator::createRamRelation( const ast::Relation* baseRelation, std::string ramRelationName) const { auto arity = baseRelation->getArity(); + + bool mergeAuxiliary = (ramRelationName != getNewRelationName(baseRelation->getQualifiedName())); + + auto auxArity = mergeAuxiliary ? baseRelation->getAuxiliaryArity() : 0; auto representation = baseRelation->getRepresentation(); if (representation == RelationRepresentation::BTREE_DELETE && ramRelationName[0] == '@') { representation = RelationRepresentation::DEFAULT; @@ -701,7 +839,7 @@ Own UnitTranslator::createRamRelation( } return mk( - ramRelationName, arity, 0, attributeNames, attributeTypeQualifiers, representation); + ramRelationName, arity, auxArity, attributeNames, attributeTypeQualifiers, representation); } VecOwn UnitTranslator::createRamRelations(const std::vector& sccOrdering) const { @@ -713,16 +851,24 @@ VecOwn UnitTranslator::createRamRelations(const std::vectorgetQualifiedName()); ramRelations.push_back(createRamRelation(rel, mainName)); + if (rel->getAuxiliaryArity() > 0) { + // Add lub relation + std::string lubName = getLubRelationName(rel->getQualifiedName()); + ramRelations.push_back(createRamRelation(rel, lubName)); + } + + if (isRecursive || rel->getAuxiliaryArity() > 0) { + // Add new relation + std::string newName = getNewRelationName(rel->getQualifiedName()); + ramRelations.push_back(createRamRelation(rel, newName)); + } + // Recursive relations also require @delta and @new variants, with the same signature if (isRecursive) { // Add delta relation std::string deltaName = getDeltaRelationName(rel->getQualifiedName()); ramRelations.push_back(createRamRelation(rel, deltaName)); - // Add new relation - std::string newName = getNewRelationName(rel->getQualifiedName()); - ramRelations.push_back(createRamRelation(rel, newName)); - // Add auxiliary relation for subsumption if (context->hasSubsumptiveClause(rel->getQualifiedName())) { // Add reject relation diff --git a/src/ast2ram/seminaive/UnitTranslator.h b/src/ast2ram/seminaive/UnitTranslator.h index 5c1400b2170..4499a1caa53 100644 --- a/src/ast2ram/seminaive/UnitTranslator.h +++ b/src/ast2ram/seminaive/UnitTranslator.h @@ -81,11 +81,13 @@ class UnitTranslator : public ast2ram::UnitTranslator { /** Low-level stratum translation */ Own generateStratum(std::size_t scc) const; Own generateStratumPreamble(const ast::RelationSet& scc) const; - Own generateNonRecursiveDelete(const ast::RelationSet& scc) const; + Own generateNonRecursiveDelete(const ast::Relation& rel) const; Own generateStratumPostamble(const ast::RelationSet& scc) const; Own generateStratumLoopBody(const ast::RelationSet& scc) const; Own generateStratumTableUpdates(const ast::RelationSet& scc) const; Own generateStratumExitSequence(const ast::RelationSet& scc) const; + Own generateStratumLubSequence( + const ast::Relation& rel, const std::string& fromName, const std::string& toName) const; /** Other helper generations */ virtual Own generateClearExpiredRelations(const ast::RelationSet& expiredRelations) const; diff --git a/src/ast2ram/utility/TranslatorContext.cpp b/src/ast2ram/utility/TranslatorContext.cpp index b7f677f3ef5..a9589307173 100644 --- a/src/ast2ram/utility/TranslatorContext.cpp +++ b/src/ast2ram/utility/TranslatorContext.cpp @@ -14,9 +14,12 @@ #include "ast2ram/utility/TranslatorContext.h" #include "Global.h" +#include "ast/Aggregator.h" #include "ast/Atom.h" #include "ast/BranchInit.h" #include "ast/Directive.h" +#include "ast/Functor.h" +#include "ast/IntrinsicFunctor.h" #include "ast/QualifiedName.h" #include "ast/SubsumptiveClause.h" #include "ast/TranslationUnit.h" @@ -39,11 +42,18 @@ #include "ast2ram/provenance/TranslationStrategy.h" #include "ast2ram/seminaive/TranslationStrategy.h" #include "ast2ram/utility/SipsMetric.h" +#include "ram/AbstractOperator.h" #include "ram/Condition.h" #include "ram/Expression.h" +#include "ram/IntrinsicAggregator.h" +#include "ram/IntrinsicOperator.h" #include "ram/Statement.h" +#include "ram/UndefValue.h" +#include "ram/UserDefinedAggregator.h" +#include "ram/UserDefinedOperator.h" #include "souffle/utility/FunctionalUtil.h" #include "souffle/utility/StringUtil.h" +#include #include namespace souffle::ast2ram { @@ -96,6 +106,11 @@ TranslatorContext::TranslatorContext(const ast::TranslationUnit& tu) { deltaRel[program->getRelation(delta.value())] = rel; } } + + // populates map type name -> lattice + for (const ast::Lattice* lattice : program->getLattices()) { + lattices.emplace(lattice->getQualifiedName(), lattice); + } } TranslatorContext::~TranslatorContext() = default; @@ -113,6 +128,40 @@ std::string TranslatorContext::getAttributeTypeQualifier(const ast::QualifiedNam return getTypeQualifier(typeEnv->getType(name)); } +Own TranslatorContext::getLatticeTypeLubFunctor( + const ast::QualifiedName& typeName, VecOwn args) const { + const ast::Lattice* lattice = lattices.at(typeName); + if (const auto* lub = as(lattice->getLub())) { + const auto typeAttributes = getFunctorParamTypeAtributes(*lub); + const auto returnAttribute = getFunctorReturnTypeAttribute(*lub); + bool stateful = isStatefulFunctor(*lub); + return mk( + lub->getName(), typeAttributes, returnAttribute, stateful, std::move(args)); + } else if (const auto* lub = as(lattice->getLub())) { + assert(false && lub && "intrinsic functors not yet supported in lattice"); + // return mk(getOverloadedFunctorOp(lub->getBaseFunctionOp()), + // std::move(args)); + } + assert(false); + return {}; +} + +Own TranslatorContext::getLatticeTypeLubAggregator( + const ast::QualifiedName& typeName, Own init) const { + const ast::Lattice* lattice = lattices.at(typeName); + if (const auto* lub = as(lattice->getLub())) { + const auto typeAttributes = getFunctorParamTypeAtributes(*lub); + const auto returnAttribute = getFunctorReturnTypeAttribute(*lub); + bool stateful = isStatefulFunctor(*lub); + return mk( + lub->getName(), std::move(init), typeAttributes, returnAttribute, stateful); + } else if (const auto* lub = as(lattice->getLub())) { + assert(false && lub && "intrinsic aggregators not yet supported in lattice"); + } + assert(false); + return {}; +} + std::size_t TranslatorContext::getNumberOfSCCs() const { return sccGraph->getNumberOfSCCs(); } diff --git a/src/ast2ram/utility/TranslatorContext.h b/src/ast2ram/utility/TranslatorContext.h index 0447068b7fd..0df1092586e 100644 --- a/src/ast2ram/utility/TranslatorContext.h +++ b/src/ast2ram/utility/TranslatorContext.h @@ -23,6 +23,7 @@ #include "ast/analysis/ProfileUse.h" #include "ast/analysis/typesystem/Type.h" #include "ast2ram/ClauseTranslator.h" +#include "ram/Aggregator.h" #include "souffle/BinaryConstraintOps.h" #include "souffle/TypeAttribute.h" #include "souffle/utility/ContainerUtil.h" @@ -96,6 +97,11 @@ class TranslatorContext { bool hasSizeLimit(const ast::Relation* relation) const; std::size_t getSizeLimit(const ast::Relation* relation) const; + Own getLatticeTypeLubAggregator( + const ast::QualifiedName& typeName, Own init) const; + Own getLatticeTypeLubFunctor( + const ast::QualifiedName& typeName, VecOwn args) const; + /** Associates a relation with its delta_debug relation if present */ const ast::Relation* getDeltaDebugRelation(const ast::Relation* rel) const; @@ -172,6 +178,7 @@ class TranslatorContext { Own sipsMetric; Own translationStrategy; std::map deltaRel; + std::map lattices; }; } // namespace souffle::ast2ram diff --git a/src/ast2ram/utility/Utils.cpp b/src/ast2ram/utility/Utils.cpp index 7ce90a3786d..d74ddd3a3b2 100644 --- a/src/ast2ram/utility/Utils.cpp +++ b/src/ast2ram/utility/Utils.cpp @@ -23,6 +23,7 @@ #include "ast/UnnamedVariable.h" #include "ast/Variable.h" #include "ast/utility/Utils.h" +#include "ast2ram/ClauseTranslator.h" #include "ast2ram/utility/Location.h" #include "ram/Clear.h" #include "ram/Condition.h" @@ -77,6 +78,9 @@ std::string getAtomName(const ast::Clause& clause, const ast::Atom* atom, } if (!isRecursive) { + if (mode == Auxiliary && clause.getHead() == atom) { + return getNewRelationName(atom->getQualifiedName()); + } return getConcreteRelationName(atom->getQualifiedName()); } if (clause.getHead() == atom) { @@ -100,6 +104,10 @@ std::string getNewRelationName(const ast::QualifiedName& name) { return getConcreteRelationName(name, "@new_"); } +std::string getLubRelationName(const ast::QualifiedName& name) { + return getConcreteRelationName(name, "@lub_"); +} + std::string getRejectRelationName(const ast::QualifiedName& name) { return getConcreteRelationName(name, "@reject_"); } diff --git a/src/ast2ram/utility/Utils.h b/src/ast2ram/utility/Utils.h index 65d1f5543f0..330488a40b2 100644 --- a/src/ast2ram/utility/Utils.h +++ b/src/ast2ram/utility/Utils.h @@ -54,6 +54,9 @@ std::string getDeltaRelationName(const ast::QualifiedName& name); /** Get the corresponding RAM 'new' relation name for the relation */ std::string getNewRelationName(const ast::QualifiedName& name); +/** Get the corresponding RAM 'lub' relation name for the relation */ +std::string getLubRelationName(const ast::QualifiedName& name); + /** Get the corresponding RAM 'reject' relation name for the relation */ std::string getRejectRelationName(const ast::QualifiedName& name); diff --git a/src/include/souffle/datastructure/BTree.h b/src/include/souffle/datastructure/BTree.h index 47d2bc8226d..72b4fc90105 100644 --- a/src/include/souffle/datastructure/BTree.h +++ b/src/include/souffle/datastructure/BTree.h @@ -91,8 +91,8 @@ class btree { /* -------------- updater utilities ------------- */ mutable Updater upd; - void update(Key& old_k, const Key& new_k) { - upd.update(old_k, new_k); + bool update(Key& old_k, const Key& new_k) { + return upd.update(old_k, new_k); } /* -------------- the node type ----------------- */ @@ -1225,14 +1225,14 @@ class btree { } // update provenance information - if (typeid(Comparator) != typeid(WeakComparator) && less(k, *pos)) { + if (typeid(Comparator) != typeid(WeakComparator)) { if (!cur->lock.try_upgrade_to_write(cur_lease)) { // start again return insert(k, hints); } - update(*pos, k); + bool updated = update(*pos, k); cur->lock.end_write(); - return true; + return updated; } // we found the element => no check of lock necessary @@ -1280,14 +1280,14 @@ class btree { } // update provenance information - if (typeid(Comparator) != typeid(WeakComparator) && less(k, *(pos - 1))) { + if (typeid(Comparator) != typeid(WeakComparator)) { if (!cur->lock.try_upgrade_to_write(cur_lease)) { // start again return insert(k, hints); } - update(*(pos - 1), k); + bool updated = update(*(pos - 1), k); cur->lock.end_write(); - return true; + return updated; } // we found the element => done @@ -1432,9 +1432,8 @@ class btree { // early exit for sets if (isSet && pos != b && weak_equal(*pos, k)) { // update provenance information - if (typeid(Comparator) != typeid(WeakComparator) && less(k, *pos)) { - update(*pos, k); - return true; + if (typeid(Comparator) != typeid(WeakComparator)) { + return update(*pos, k); } return false; @@ -1458,9 +1457,8 @@ class btree { // early exit for sets if (isSet && pos != a && weak_equal(*(pos - 1), k)) { // update provenance information - if (typeid(Comparator) != typeid(WeakComparator) && less(k, *(pos - 1))) { - update(*(pos - 1), k); - return true; + if (typeid(Comparator) != typeid(WeakComparator)) { + return update(*(pos - 1), k); } return false; diff --git a/src/include/souffle/datastructure/BTreeDelete.h b/src/include/souffle/datastructure/BTreeDelete.h index 745a3c12216..fc747ff1c46 100644 --- a/src/include/souffle/datastructure/BTreeDelete.h +++ b/src/include/souffle/datastructure/BTreeDelete.h @@ -92,8 +92,8 @@ class btree_delete { /* -------------- updater utilities ------------- */ mutable Updater upd; - void update(Key& old_k, const Key& new_k) { - upd.update(old_k, new_k); + bool update(Key& old_k, const Key& new_k) { + return upd.update(old_k, new_k); } /* -------------- the node type ----------------- */ @@ -1278,14 +1278,14 @@ class btree_delete { } // update provenance information - if (typeid(Comparator) != typeid(WeakComparator) && less(k, *pos)) { + if (typeid(Comparator) != typeid(WeakComparator)) { if (!cur->lock.try_upgrade_to_write(cur_lease)) { // start again return insert(k, hints); } - update(*pos, k); + bool updated = update(*pos, k); cur->lock.end_write(); - return true; + return updated; } // we found the element => no check of lock necessary @@ -1333,14 +1333,14 @@ class btree_delete { } // update provenance information - if (typeid(Comparator) != typeid(WeakComparator) && less(k, *(pos - 1))) { + if (typeid(Comparator) != typeid(WeakComparator)) { if (!cur->lock.try_upgrade_to_write(cur_lease)) { // start again return insert(k, hints); } - update(*(pos - 1), k); + bool updated = update(*(pos - 1), k); cur->lock.end_write(); - return true; + return updated; } // we found the element => done @@ -1485,9 +1485,8 @@ class btree_delete { // early exit for sets if (isSet && pos != b && weak_equal(*pos, k)) { // update provenance information - if (typeid(Comparator) != typeid(WeakComparator) && less(k, *pos)) { - update(*pos, k); - return true; + if (typeid(Comparator) != typeid(WeakComparator)) { + return update(*pos, k); } return false; @@ -1511,9 +1510,8 @@ class btree_delete { // early exit for sets if (isSet && pos != a && weak_equal(*(pos - 1), k)) { // update provenance information - if (typeid(Comparator) != typeid(WeakComparator) && less(k, *(pos - 1))) { - update(*(pos - 1), k); - return true; + if (typeid(Comparator) != typeid(WeakComparator)) { + return update(*(pos - 1), k); } return false; diff --git a/src/include/souffle/datastructure/BTreeUtil.h b/src/include/souffle/datastructure/BTreeUtil.h index bdf624fa5ca..164006005db 100644 --- a/src/include/souffle/datastructure/BTreeUtil.h +++ b/src/include/souffle/datastructure/BTreeUtil.h @@ -217,7 +217,9 @@ struct default_strategy> : public linear {}; */ template struct updater { - void update(T& /* old_t */, const T& /* new_t */) {} + bool update(T& /* old_t */, const T& /* new_t */) { + return false; + } }; } // end of namespace detail diff --git a/src/interpreter/BTreeDeleteIndex.cpp b/src/interpreter/BTreeDeleteIndex.cpp index 1d51bbc172e..7275e3728e5 100644 --- a/src/interpreter/BTreeDeleteIndex.cpp +++ b/src/interpreter/BTreeDeleteIndex.cpp @@ -21,18 +21,15 @@ namespace souffle::interpreter { -#define CREATE_BTREE_DELETE_REL(Structure, Arity, ...) \ - case (Arity): { \ - return mk>(id.getAuxiliaryArity(), id.getName(), indexSelection); \ +#define CREATE_BTREE_DELETE_REL(Structure, Arity, AuxiliaryArity, ...) \ + if (id.getArity() == Arity && id.getAuxiliaryArity() == AuxiliaryArity) { \ + return mk>(id.getName(), indexSelection); \ } Own createBTreeDeleteRelation( const ram::Relation& id, const ram::analysis::IndexCluster& indexSelection) { - switch (id.getArity()) { - FOR_EACH_BTREE_DELETE(CREATE_BTREE_DELETE_REL); - - default: fatal("Requested arity not yet supported. Feel free to add it."); - } + FOR_EACH_BTREE_DELETE(CREATE_BTREE_DELETE_REL); + fatal("Requested arity not yet supported. Feel free to add it."); } } // namespace souffle::interpreter diff --git a/src/interpreter/BTreeIndex.cpp b/src/interpreter/BTreeIndex.cpp index 97127bf35f4..5a42aa6c491 100644 --- a/src/interpreter/BTreeIndex.cpp +++ b/src/interpreter/BTreeIndex.cpp @@ -21,19 +21,15 @@ namespace souffle::interpreter { -#define CREATE_BTREE_REL(Structure, Arity, ...) \ - case (Arity): { \ - return mk>( \ - id.getAuxiliaryArity(), id.getName(), indexSelection); \ +#define CREATE_BTREE_REL(Structure, Arity, AuxiliaryArity, ...) \ + if (id.getArity() == Arity && id.getAuxiliaryArity() == AuxiliaryArity) { \ + return mk>(id.getName(), indexSelection); \ } Own createBTreeRelation( const ram::Relation& id, const ram::analysis::IndexCluster& indexSelection) { - switch (id.getArity()) { - FOR_EACH_BTREE(CREATE_BTREE_REL); - - default: fatal("Requested arity not yet supported. Feel free to add it."); - } + FOR_EACH_BTREE(CREATE_BTREE_REL); + fatal("Requested arity not yet supported. Feel free to add it."); } } // namespace souffle::interpreter diff --git a/src/interpreter/Engine.cpp b/src/interpreter/Engine.cpp index 521f7f9fc8e..27b57854640 100644 --- a/src/interpreter/Engine.cpp +++ b/src/interpreter/Engine.cpp @@ -359,13 +359,13 @@ void Engine::createRelation(const ram::Relation& id, const std::size_t idx) { } RelationHandle res; - - if (id.getRepresentation() == RelationRepresentation::EQREL) { + bool hasProvenance = id.getArity() > 0 && id.getAttributeNames().back() == "@level_number"; + if (hasProvenance) { + res = createProvenanceRelation(id, isa.getIndexSelection(id.getName())); + } else if (id.getRepresentation() == RelationRepresentation::EQREL) { res = createEqrelRelation(id, isa.getIndexSelection(id.getName())); } else if (id.getRepresentation() == RelationRepresentation::BTREE_DELETE) { res = createBTreeDeleteRelation(id, isa.getIndexSelection(id.getName())); - } else if (id.getRepresentation() == RelationRepresentation::PROVENANCE) { - res = createProvenanceRelation(id, isa.getIndexSelection(id.getName())); } else { res = createBTreeRelation(id, isa.getIndexSelection(id.getName())); } @@ -527,9 +527,9 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { // Overload CASE based on number of arguments. // CASE(Kind) -> BASE_CASE(Kind) -// CASE(Kind, Structure, Arity) -> EXTEND_CASE(Kind, Structure, Arity) -#define GET_MACRO(_1, _2, _3, NAME, ...) NAME -#define CASE(...) GET_MACRO(__VA_ARGS__, EXTEND_CASE, _Dummy, BASE_CASE)(__VA_ARGS__) +// CASE(Kind, Structure, Arity, AuxiliaryArity) -> EXTEND_CASE(Kind, Structure, Arity, AuxiliaryArity) +#define GET_MACRO(_1, _2, _3, _4, NAME, ...) NAME +#define CASE(...) GET_MACRO(__VA_ARGS__, EXTEND_CASE, _Dummy, _Dummy2, BASE_CASE)(__VA_ARGS__) #define BASE_CASE(Kind) \ case (I_##Kind): { \ @@ -537,12 +537,12 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { [[maybe_unused]] const auto& shadow = *static_cast(node); \ [[maybe_unused]] const auto& cur = *static_cast(node->getShadow()); // EXTEND_CASE also defer the relation type -#define EXTEND_CASE(Kind, Structure, Arity) \ - case (I_##Kind##_##Structure##_##Arity): { \ +#define EXTEND_CASE(Kind, Structure, Arity, AuxiliaryArity) \ + case (I_##Kind##_##Structure##_##Arity##_##AuxiliaryArity): { \ return [&]() -> RamDomain { \ [[maybe_unused]] const auto& shadow = *static_cast(node); \ [[maybe_unused]] const auto& cur = *static_cast(node->getShadow());\ - using RelType = Relation; + using RelType = Relation; #define ESAC(Kind) \ } \ (); \ @@ -986,8 +986,8 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { return !execute(shadow.getChild(), ctxt); ESAC(Negation) -#define EMPTINESS_CHECK(Structure, Arity, ...) \ - CASE(EmptinessCheck, Structure, Arity) \ +#define EMPTINESS_CHECK(Structure, Arity, AuxiliaryArity, ...) \ + CASE(EmptinessCheck, Structure, Arity, AuxiliaryArity) \ const auto& rel = *static_cast(shadow.getRelation()); \ return rel.empty(); \ ESAC(EmptinessCheck) @@ -995,8 +995,8 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { FOR_EACH(EMPTINESS_CHECK) #undef EMPTINESS_CHECK -#define RELATION_SIZE(Structure, Arity, ...) \ - CASE(RelationSize, Structure, Arity) \ +#define RELATION_SIZE(Structure, Arity, AuxiliaryArity, ...) \ + CASE(RelationSize, Structure, Arity, AuxiliaryArity) \ const auto& rel = *static_cast(shadow.getRelation()); \ return rel.size(); \ ESAC(RelationSize) @@ -1004,17 +1004,17 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { FOR_EACH(RELATION_SIZE) #undef RELATION_SIZE -#define EXISTENCE_CHECK(Structure, Arity, ...) \ - CASE(ExistenceCheck, Structure, Arity) \ - return evalExistenceCheck(shadow, ctxt); \ +#define EXISTENCE_CHECK(Structure, Arity, AuxiliaryArity, ...) \ + CASE(ExistenceCheck, Structure, Arity, AuxiliaryArity) \ + return evalExistenceCheck(shadow, ctxt); \ ESAC(ExistenceCheck) FOR_EACH(EXISTENCE_CHECK) #undef EXISTENCE_CHECK -#define PROVENANCE_EXISTENCE_CHECK(Structure, Arity, ...) \ - CASE(ProvenanceExistenceCheck, Structure, Arity) \ - return evalProvenanceExistenceCheck(shadow, ctxt); \ +#define PROVENANCE_EXISTENCE_CHECK(Structure, Arity, AuxiliaryArity, ...) \ + CASE(ProvenanceExistenceCheck, Structure, Arity, AuxiliaryArity) \ + return evalProvenanceExistenceCheck(shadow, ctxt); \ ESAC(ProvenanceExistenceCheck) FOR_EACH_PROVENANCE(PROVENANCE_EXISTENCE_CHECK) @@ -1135,8 +1135,8 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { return result; ESAC(TupleOperation) -#define SCAN(Structure, Arity, ...) \ - CASE(Scan, Structure, Arity) \ +#define SCAN(Structure, Arity, AuxiliaryArity, ...) \ + CASE(Scan, Structure, Arity, AuxiliaryArity) \ const auto& rel = *static_cast(shadow.getRelation()); \ return evalScan(rel, cur, shadow, ctxt); \ ESAC(Scan) @@ -1144,24 +1144,24 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { FOR_EACH(SCAN) #undef SCAN -#define PARALLEL_SCAN(Structure, Arity, ...) \ - CASE(ParallelScan, Structure, Arity) \ +#define PARALLEL_SCAN(Structure, Arity, AuxiliaryArity, ...) \ + CASE(ParallelScan, Structure, Arity, AuxiliaryArity) \ const auto& rel = *static_cast(shadow.getRelation()); \ return evalParallelScan(rel, cur, shadow, ctxt); \ ESAC(ParallelScan) FOR_EACH(PARALLEL_SCAN) #undef PARALLEL_SCAN -#define INDEX_SCAN(Structure, Arity, ...) \ - CASE(IndexScan, Structure, Arity) \ +#define INDEX_SCAN(Structure, Arity, AuxiliaryArity, ...) \ + CASE(IndexScan, Structure, Arity, AuxiliaryArity) \ return evalIndexScan(cur, shadow, ctxt); \ ESAC(IndexScan) FOR_EACH(INDEX_SCAN) #undef INDEX_SCAN -#define PARALLEL_INDEX_SCAN(Structure, Arity, ...) \ - CASE(ParallelIndexScan, Structure, Arity) \ +#define PARALLEL_INDEX_SCAN(Structure, Arity, AuxiliaryArity, ...) \ + CASE(ParallelIndexScan, Structure, Arity, AuxiliaryArity) \ const auto& rel = *static_cast(shadow.getRelation()); \ return evalParallelIndexScan(rel, cur, shadow, ctxt); \ ESAC(ParallelIndexScan) @@ -1169,8 +1169,8 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { FOR_EACH(PARALLEL_INDEX_SCAN) #undef PARALLEL_INDEX_SCAN -#define IFEXISTS(Structure, Arity, ...) \ - CASE(IfExists, Structure, Arity) \ +#define IFEXISTS(Structure, Arity, AuxiliaryArity, ...) \ + CASE(IfExists, Structure, Arity, AuxiliaryArity) \ const auto& rel = *static_cast(shadow.getRelation()); \ return evalIfExists(rel, cur, shadow, ctxt); \ ESAC(IfExists) @@ -1178,8 +1178,8 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { FOR_EACH(IFEXISTS) #undef IFEXISTS -#define PARALLEL_IFEXISTS(Structure, Arity, ...) \ - CASE(ParallelIfExists, Structure, Arity) \ +#define PARALLEL_IFEXISTS(Structure, Arity, AuxiliaryArity, ...) \ + CASE(ParallelIfExists, Structure, Arity, AuxiliaryArity) \ const auto& rel = *static_cast(shadow.getRelation()); \ return evalParallelIfExists(rel, cur, shadow, ctxt); \ ESAC(ParallelIfExists) @@ -1187,16 +1187,16 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { FOR_EACH(PARALLEL_IFEXISTS) #undef PARALLEL_IFEXISTS -#define INDEX_IFEXISTS(Structure, Arity, ...) \ - CASE(IndexIfExists, Structure, Arity) \ +#define INDEX_IFEXISTS(Structure, Arity, AuxiliaryArity, ...) \ + CASE(IndexIfExists, Structure, Arity, AuxiliaryArity) \ return evalIndexIfExists(cur, shadow, ctxt); \ ESAC(IndexIfExists) FOR_EACH(INDEX_IFEXISTS) #undef INDEX_IFEXISTS -#define PARALLEL_INDEX_IFEXISTS(Structure, Arity, ...) \ - CASE(ParallelIndexIfExists, Structure, Arity) \ +#define PARALLEL_INDEX_IFEXISTS(Structure, Arity, AuxiliaryArity, ...) \ + CASE(ParallelIndexIfExists, Structure, Arity, AuxiliaryArity) \ const auto& rel = *static_cast(shadow.getRelation()); \ return evalParallelIndexIfExists(rel, cur, shadow, ctxt); \ ESAC(ParallelIndexIfExists) @@ -1223,8 +1223,8 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { return execute(shadow.getNestedOperation(), ctxt); ESAC(UnpackRecord) -#define PARALLEL_AGGREGATE(Structure, Arity, ...) \ - CASE(ParallelAggregate, Structure, Arity) \ +#define PARALLEL_AGGREGATE(Structure, Arity, AuxiliaryArity, ...) \ + CASE(ParallelAggregate, Structure, Arity, AuxiliaryArity) \ const auto& rel = *static_cast(shadow.getRelation()); \ return evalParallelAggregate(rel, cur, shadow, ctxt); \ ESAC(ParallelAggregate) @@ -1232,8 +1232,8 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { FOR_EACH(PARALLEL_AGGREGATE) #undef PARALLEL_AGGREGATE -#define AGGREGATE(Structure, Arity, ...) \ - CASE(Aggregate, Structure, Arity) \ +#define AGGREGATE(Structure, Arity, AuxiliaryArity, ...) \ + CASE(Aggregate, Structure, Arity, AuxiliaryArity) \ const auto& rel = *static_cast(shadow.getRelation()); \ return evalAggregate(cur, shadow, rel.scan(), ctxt); \ ESAC(Aggregate) @@ -1241,16 +1241,16 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { FOR_EACH(AGGREGATE) #undef AGGREGATE -#define PARALLEL_INDEX_AGGREGATE(Structure, Arity, ...) \ - CASE(ParallelIndexAggregate, Structure, Arity) \ - return evalParallelIndexAggregate(cur, shadow, ctxt); \ +#define PARALLEL_INDEX_AGGREGATE(Structure, Arity, AuxiliaryArity, ...) \ + CASE(ParallelIndexAggregate, Structure, Arity, AuxiliaryArity) \ + return evalParallelIndexAggregate(cur, shadow, ctxt); \ ESAC(ParallelIndexAggregate) FOR_EACH(PARALLEL_INDEX_AGGREGATE) #undef PARALLEL_INDEX_AGGREGATE -#define INDEX_AGGREGATE(Structure, Arity, ...) \ - CASE(IndexAggregate, Structure, Arity) \ +#define INDEX_AGGREGATE(Structure, Arity, AuxiliaryArity, ...) \ + CASE(IndexAggregate, Structure, Arity, AuxiliaryArity) \ return evalIndexAggregate(cur, shadow, ctxt); \ ESAC(IndexAggregate) @@ -1283,8 +1283,8 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { return result; ESAC(Filter) -#define GUARDED_INSERT(Structure, Arity, ...) \ - CASE(GuardedInsert, Structure, Arity) \ +#define GUARDED_INSERT(Structure, Arity, AuxiliaryArity, ...) \ + CASE(GuardedInsert, Structure, Arity, AuxiliaryArity) \ auto& rel = *static_cast(shadow.getRelation()); \ return evalGuardedInsert(rel, shadow, ctxt); \ ESAC(GuardedInsert) @@ -1292,8 +1292,8 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { FOR_EACH(GUARDED_INSERT) #undef GUARDED_INSERT -#define INSERT(Structure, Arity, ...) \ - CASE(Insert, Structure, Arity) \ +#define INSERT(Structure, Arity, AuxiliaryArity, ...) \ + CASE(Insert, Structure, Arity, AuxiliaryArity) \ auto& rel = *static_cast(shadow.getRelation()); \ return evalInsert(rel, shadow, ctxt); \ ESAC(Insert) @@ -1301,11 +1301,11 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { FOR_EACH(INSERT) #undef INSERT -#define ERASE(Structure, Arity, ...) \ - CASE(Erase, Structure, Arity) \ - void(static_cast(shadow.getRelation())); \ - auto& rel = *static_cast*>(shadow.getRelation()); \ - return evalErase(rel, shadow, ctxt); \ +#define ERASE(Structure, Arity, AuxiliaryArity, ...) \ + CASE(Erase, Structure, Arity, AuxiliaryArity) \ + void(static_cast(shadow.getRelation())); \ + auto& rel = *static_cast*>(shadow.getRelation()); \ + return evalErase(rel, shadow, ctxt); \ ESAC(Erase) FOR_EACH_BTREE_DELETE(ERASE) @@ -1369,8 +1369,8 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { return execute(shadow.getChild(), ctxt); ESAC(DebugInfo) -#define CLEAR(Structure, Arity, ...) \ - CASE(Clear, Structure, Arity) \ +#define CLEAR(Structure, Arity, AuxiliaryArity, ...) \ + CASE(Clear, Structure, Arity, AuxiliaryArity) \ auto& rel = *static_cast(shadow.getRelation()); \ rel.__purge(); \ return true; \ @@ -1379,8 +1379,8 @@ RamDomain Engine::execute(const Node* node, Context& ctxt) { FOR_EACH(CLEAR) #undef CLEAR -#define ESTIMATEJOINSIZE(Structure, Arity, ...) \ - CASE(EstimateJoinSize, Structure, Arity) \ +#define ESTIMATEJOINSIZE(Structure, Arity, AuxiliaryArity, ...) \ + CASE(EstimateJoinSize, Structure, Arity, AuxiliaryArity) \ const auto& rel = *static_cast(shadow.getRelation()); \ return evalEstimateJoinSize(rel, cur, shadow, ctxt); \ ESAC(EstimateJoinSize) @@ -1945,7 +1945,7 @@ bool runNested(const ram::Aggregator& aggregator) { default: return false; } } else if (isA(aggregator)) { - return false; + return true; } return false; } diff --git a/src/interpreter/EqrelIndex.cpp b/src/interpreter/EqrelIndex.cpp index 7a368c16e9a..b6aeef17e54 100644 --- a/src/interpreter/EqrelIndex.cpp +++ b/src/interpreter/EqrelIndex.cpp @@ -23,7 +23,8 @@ namespace souffle::interpreter { Own createEqrelRelation( const ram::Relation& id, const ram::analysis::IndexCluster& indexSelection) { assert(id.getArity() == 2 && "Eqivalence relation must have arity size 2."); - return mk(id.getAuxiliaryArity(), id.getName(), indexSelection); + assert(id.getAuxiliaryArity() == 0 && "Equivalence relation must have auxiliary arity size 0."); + return mk(id.getName(), indexSelection); } } // namespace souffle::interpreter diff --git a/src/interpreter/Index.h b/src/interpreter/Index.h index 3fc693de8b4..b2a35fc42a3 100644 --- a/src/interpreter/Index.h +++ b/src/interpreter/Index.h @@ -142,11 +142,13 @@ struct ViewWrapper { /** * An index is an abstraction of a data structure */ -template typename Structure> +template typename Structure> class Index { public: static constexpr std::size_t Arity = _Arity; - using Data = Structure; + static constexpr std::size_t AuxiliaryArity = _AuxiliaryArity; + using Data = Structure; using Tuple = typename souffle::Tuple; using iterator = typename Data::iterator; using Hints = typename Data::operation_hints; @@ -239,7 +241,7 @@ class Index { /** * Inserts all elements of the given index. */ - void insert(const Index& src) { + void insert(const Index& src) { for (const auto& tuple : src) { this->insert(tuple); } @@ -316,8 +318,8 @@ class Index { * A partial specialize template for nullary indexes. * No complex data structure is required. */ -template