1414
1515#include " ast/transform/MaterializeAggregationQueries.h"
1616#include " AggregateOp.h"
17+ #include " FunctorOps.h"
1718#include " ast/Aggregator.h"
1819#include " ast/Argument.h"
1920#include " ast/Atom.h"
@@ -277,6 +278,10 @@ bool MaterializeAggregationQueriesTransformer::materializeAggregationQueries(
277278
278279 for (auto && cl : program.getClauses ()) {
279280 auto & clause = *cl;
281+
282+ // the mapping from the Arguments in the original clause to their type(s)
283+ std::optional<std::map<const Argument*, analysis::TypeSet>> clauseArgTypes;
284+
280285 visit (clause, [&](Aggregator& agg) {
281286 if (!needsMaterializedRelation (agg)) {
282287 return ;
@@ -285,6 +290,12 @@ bool MaterializeAggregationQueriesTransformer::materializeAggregationQueries(
285290 if (innerAggregates.find (&agg) != innerAggregates.end ()) {
286291 return ;
287292 }
293+
294+ // compute types before the clause gets modified
295+ if (!clauseArgTypes) {
296+ clauseArgTypes = analysis::TypeAnalysis::analyseTypes (translationUnit, clause);
297+ }
298+
288299 // begin materialisation process
289300 auto aggregateBodyRelationName = analysis::findUniqueRelationName (program, " __agg_subclause" );
290301 // quickly copy in all the literals from the aggregate body
@@ -306,17 +317,28 @@ bool MaterializeAggregationQueriesTransformer::materializeAggregationQueries(
306317 for (const auto & variableName : headArguments) {
307318 aggClauseHead->addArgument (mk<Variable>(variableName));
308319 }
309- // add them to the relation as well (need to do a bit of type analysis to make this work)
320+
321+ // add them to the relation as well
310322 auto aggRel = mk<Relation>(QualifiedName::fromString (aggregateBodyRelationName));
311- std::map<const Argument*, analysis::TypeSet> argTypes =
312- analysis::TypeAnalysis::analyseTypes (translationUnit, *aggClause);
313323
314324 for (const auto & cur : aggClauseHead->getArguments ()) {
315- // cur will point us to a particular argument
316- // that is found in the aggClause
317- auto const curArgType = argTypes[cur];
318- assert (!curArgType.empty () && " unexpected empty typeset" );
319- aggRel->addAttribute (mk<Attribute>(toString (*cur), curArgType.begin ()->getName ()));
325+ // Find type of argument variable in original clause
326+ auto it = std::find_if (clauseArgTypes->cbegin (), clauseArgTypes->cend (),
327+ [&](const std::pair<const Argument*, analysis::TypeSet>& pair) -> bool {
328+ if (const Variable* var = as<Variable>(pair.first )) {
329+ // use type from first variable matching the name
330+ return (var->getName () == toString (*cur));
331+ }
332+ return false ;
333+ });
334+ assert (it != clauseArgTypes->cend () && " unexpected unknown argument" );
335+
336+ if (it != clauseArgTypes->cend ()) {
337+ auto const curArgType = it->second ;
338+ assert (!curArgType.empty () && " unexpected empty typeset" );
339+ assert (curArgType.size () == 1 && " expected fully resolved type" );
340+ aggRel->addAttribute (mk<Attribute>(toString (*cur), curArgType.begin ()->getName ()));
341+ }
320342 }
321343
322344 // Set up the aggregate body atom that will represent the materialised relation we just created
@@ -388,6 +410,19 @@ bool MaterializeAggregationQueriesTransformer::needsMaterializedRelation(const A
388410 return true ;
389411 }
390412
413+ // If we have a multi-result intrinsic functor within this aggregate => materialize
414+ bool seenMultiresultIntrinsicFunctor = false ;
415+ visit (agg, [&](const IntrinsicFunctor& intFunc) {
416+ auto candidates = functorBuiltIn (intFunc.getBaseFunctionOp ());
417+ seenMultiresultIntrinsicFunctor |= std::any_of (candidates.cbegin (), candidates.cend (),
418+ [](const std::reference_wrapper<const IntrinsicFunctorInfo>& info) -> bool {
419+ return isFunctorMultiResult (info.get ().op );
420+ });
421+ });
422+ if (seenMultiresultIntrinsicFunctor) {
423+ return true ;
424+ }
425+
391426 // If the same variable occurs several times => materialize
392427 bool duplicates = false ;
393428 std::set<std::string> vars;
0 commit comments