@@ -591,6 +591,28 @@ BuilderExpr BuilderExpr::singleValue() const {
591
591
return {builder_, agg, name, true };
592
592
}
593
593
594
+ BuilderExpr BuilderExpr::topK (int count) const {
595
+ if (!expr_->type ()->isNumber () && !expr_->type ()->isDateTime ()) {
596
+ throw InvalidQueryError () << " Unsupported type for topK aggregate: "
597
+ << expr_->type ()->toString ();
598
+ }
599
+ if (count == 0 ) {
600
+ throw InvalidQueryError ()
601
+ << " Non-zero integer argument is expected for topK aggregate. Provided: "
602
+ << expr_->type ()->toString ();
603
+ }
604
+ auto arr_type = ctx ().arrayVarLen (expr_->type (), 4 , false );
605
+ auto agg = makeExpr<AggExpr>(
606
+ arr_type, AggType::kTopK , expr_, false , builder_->cst (count).expr ());
607
+ auto top_str = (count > 0 ? " top_" : " bottom_" ) + std::to_string (std::abs (count));
608
+ auto name = name_.empty () ? top_str : name_ + " _" + top_str;
609
+ return {builder_, agg, name, true };
610
+ }
611
+
612
+ BuilderExpr BuilderExpr::bottomK (int count) const {
613
+ return topK (-count);
614
+ }
615
+
594
616
BuilderExpr BuilderExpr::stdDev () const {
595
617
if (!expr_->type ()->isNumber ()) {
596
618
throw InvalidQueryError () << " Non-numeric type " << expr_->type ()->toString ()
@@ -658,7 +680,7 @@ BuilderExpr BuilderExpr::lastValue() const {
658
680
return {builder_, expr, name, true };
659
681
}
660
682
661
- BuilderExpr BuilderExpr::agg (const std::string& agg_str, const BuilderExpr& arg) const {
683
+ BuilderExpr BuilderExpr::agg (const std::string& agg_str, BuilderExpr arg) const {
662
684
static const std::unordered_map<std::string, AggType> agg_names = {
663
685
{" count" , AggType::kCount },
664
686
{" count_dist" , AggType::kCount },
@@ -678,6 +700,10 @@ BuilderExpr BuilderExpr::agg(const std::string& agg_str, const BuilderExpr& arg)
678
700
{" sample" , AggType::kSample },
679
701
{" single_value" , AggType::kSingleValue },
680
702
{" single value" , AggType::kSingleValue },
703
+ {" topk" , AggType::kTopK },
704
+ {" top_k" , AggType::kTopK },
705
+ {" bottomk" , AggType::kTopK },
706
+ {" bottom_k" , AggType::kTopK },
681
707
{" stddev" , AggType::kStdDevSamp },
682
708
{" stddev_samp" , AggType::kStdDevSamp },
683
709
{" stddev samp" , AggType::kStdDevSamp },
@@ -693,6 +719,14 @@ BuilderExpr BuilderExpr::agg(const std::string& agg_str, const BuilderExpr& arg)
693
719
if (kind == AggType::kApproxQuantile && !arg.expr ()) {
694
720
throw InvalidQueryError (" Missing argument for approximate quantile aggregate." );
695
721
}
722
+ if (kind == AggType::kTopK ) {
723
+ if (!arg.expr ()) {
724
+ throw InvalidQueryError (" Missing argument for topK aggregate." );
725
+ }
726
+ if (agg_str_lower == " bottomk" || agg_str_lower == " bottom_k" ) {
727
+ arg = arg.uminus ();
728
+ }
729
+ }
696
730
if (kind == AggType::kCorr && !arg.expr ()) {
697
731
throw InvalidQueryError (" Missing argument for corr aggregate." );
698
732
}
@@ -709,6 +743,10 @@ BuilderExpr BuilderExpr::agg(const std::string& agg_str, double val) const {
709
743
return agg (agg_str, arg);
710
744
}
711
745
746
+ BuilderExpr BuilderExpr::agg (const std::string& agg_str, int val) const {
747
+ return agg (agg_str, builder_->cst (val));
748
+ }
749
+
712
750
BuilderExpr BuilderExpr::agg (AggType agg_kind, const BuilderExpr& arg) const {
713
751
return agg (agg_kind, false , arg);
714
752
}
@@ -717,14 +755,19 @@ BuilderExpr BuilderExpr::agg(AggType agg_kind, double val) const {
717
755
return agg (agg_kind, false , val);
718
756
}
719
757
758
+ BuilderExpr BuilderExpr::agg (AggType agg_kind, int val) const {
759
+ return agg (agg_kind, false , builder_->cst (val));
760
+ }
761
+
720
762
BuilderExpr BuilderExpr::agg (AggType agg_kind,
721
763
bool is_distinct,
722
764
const BuilderExpr& arg) const {
723
765
if (is_distinct && agg_kind != AggType::kCount ) {
724
766
throw InvalidQueryError () << " Distinct property cannot be set to true for "
725
767
<< agg_kind << " aggregate." ;
726
768
}
727
- if (arg.expr () && agg_kind != AggType::kApproxQuantile && agg_kind != AggType::kCorr ) {
769
+ if (arg.expr () && agg_kind != AggType::kApproxQuantile && agg_kind != AggType::kCorr &&
770
+ agg_kind != AggType::kTopK ) {
728
771
throw InvalidQueryError () << " Aggregate argument is supported for approximate "
729
772
" quantile and corr only but provided for "
730
773
<< agg_kind;
@@ -736,6 +779,13 @@ BuilderExpr BuilderExpr::agg(AggType agg_kind,
736
779
<< arg.expr ()->toString ();
737
780
}
738
781
}
782
+ if (agg_kind == AggType::kTopK ) {
783
+ if (!arg.expr ()->is <Constant>() || !arg.type ()->isInteger ()) {
784
+ throw InvalidQueryError ()
785
+ << " Expected integer constant argumnt for topK. Provided: "
786
+ << arg.expr ()->toString ();
787
+ }
788
+ }
739
789
740
790
switch (agg_kind) {
741
791
case AggType::kAvg :
@@ -756,6 +806,8 @@ BuilderExpr BuilderExpr::agg(AggType agg_kind,
756
806
return sample ();
757
807
case AggType::kSingleValue :
758
808
return singleValue ();
809
+ case AggType::kTopK :
810
+ return topK (arg.expr ()->as <Constant>()->intVal ());
759
811
case AggType::kStdDevSamp :
760
812
return stdDev ();
761
813
case AggType::kCorr :
@@ -2053,21 +2105,28 @@ BuilderExpr BuilderNode::parseAggString(const std::string& agg_str) const {
2053
2105
auto val_str = boost::trim_copy (
2054
2106
col_name.substr (comma_pos + 1 , col_name.size () - comma_pos - 1 ));
2055
2107
char * end = nullptr ;
2056
- auto val = std::strtod (val_str.c_str (), &end);
2057
- // Require value string to be fully interpreted to avoid silent errors like
2058
- // 1..1 interpreted as 1.
2059
- if (val == HUGE_VAL || end == val_str.c_str () ||
2060
- end != (val_str.c_str () + val_str.size ())) {
2061
- // If value is not decimal then assume it is a column name (for corr aggregate).
2062
- auto ref = getRefByName (node_.get (), val_str, true );
2063
- if (!ref) {
2064
- throw InvalidQueryError ()
2065
- << " Cannot parse aggregate parameter (decimal or column name expected): "
2066
- << val_str;
2067
- }
2068
- arg = BuilderExpr (builder_, ref, val_str);
2108
+
2109
+ // First, try to parse arg as an integer.
2110
+ auto long_val = std::strtol (val_str.c_str (), &end, 10 );
2111
+ if (end == (val_str.c_str () + val_str.size ())) {
2112
+ arg = builder_->cst (long_val);
2069
2113
} else {
2070
- arg = builder_->cst (val);
2114
+ // Now try to parse as a double. Require value string to be fully interpreted
2115
+ // to avoid silent errors like 1..1 interpreted as 1.
2116
+ auto val = std::strtod (val_str.c_str (), &end);
2117
+ if (val == HUGE_VAL || end == val_str.c_str () ||
2118
+ end != (val_str.c_str () + val_str.size ())) {
2119
+ // If value is not decimal then assume it is a column name (for corr aggregate).
2120
+ auto ref = getRefByName (node_.get (), val_str, true );
2121
+ if (!ref) {
2122
+ throw InvalidQueryError ()
2123
+ << " Cannot parse aggregate parameter (decimal or column name expected): "
2124
+ << val_str;
2125
+ }
2126
+ arg = BuilderExpr (builder_, ref, val_str);
2127
+ } else {
2128
+ arg = builder_->cst (val);
2129
+ }
2071
2130
}
2072
2131
col_name = boost::trim_copy (col_name.substr (0 , comma_pos));
2073
2132
}
0 commit comments