Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit e5cd6ca

Browse files
committed
Add TopK/BottomK aggregate.
Signed-off-by: ienkovich <[email protected]>
1 parent 590bce6 commit e5cd6ca

39 files changed

+1433
-60
lines changed

omniscidb/IR/Expr.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,7 @@ class AggExpr : public Expr {
791791
: Expr(type, true), agg_type_(a), arg_(arg), is_distinct_(d), arg1_(arg1) {
792792
if (arg1) {
793793
if (agg_type_ == AggType::kApproxCountDistinct ||
794-
agg_type_ == AggType::kApproxQuantile) {
794+
agg_type_ == AggType::kApproxQuantile || agg_type_ == AggType::kTopK) {
795795
CHECK(arg1_->is<Constant>());
796796
} else {
797797
CHECK(agg_type_ == AggType::kCorr);

omniscidb/IR/OpType.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ enum class AggType {
8787
kApproxQuantile,
8888
kSample,
8989
kSingleValue,
90+
kTopK,
9091
// Compound aggregates
9192
kStdDevSamp,
9293
kCorr,
@@ -205,6 +206,8 @@ inline std::string toString(hdk::ir::AggType agg) {
205206
return "SAMPLE";
206207
case hdk::ir::AggType::kSingleValue:
207208
return "SINGLE_VALUE";
209+
case hdk::ir::AggType::kTopK:
210+
return "TOP_K";
208211
case hdk::ir::AggType::kStdDevSamp:
209212
return "STDDEV";
210213
case hdk::ir::AggType::kCorr:

omniscidb/QueryBuilder/QueryBuilder.cpp

Lines changed: 75 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,28 @@ BuilderExpr BuilderExpr::singleValue() const {
591591
return {builder_, agg, name, true};
592592
}
593593

594+
/// Builds a TopK/BottomK aggregate over this expression.
///
/// @param count Number of values to keep. The sign selects the direction: a
///        positive value produces a "top_<count>" aggregate, a negative value
///        produces a "bottom_<-count>" aggregate (bottomK() relies on this).
/// @return A builder expression wrapping an AggExpr of kind AggType::kTopK. Its
///         result type is the array type returned by ctx().arrayVarLen() for
///         the argument type, and `count` is attached as a constant argument.
/// @throws InvalidQueryError if the argument type is neither a number nor a
///         date/time type, or if `count` is zero.
BuilderExpr BuilderExpr::topK(int count) const {
  if (!expr_->type()->isNumber() && !expr_->type()->isDateTime()) {
    throw InvalidQueryError() << "Unsupported type for topK aggregate: "
                              << expr_->type()->toString();
  }
  if (count == 0) {
    // Report the rejected count itself. The message previously printed the
    // expression type here, which says nothing about the invalid argument.
    throw InvalidQueryError()
        << "Non-zero integer argument is expected for topK aggregate. Provided: "
        << std::to_string(count);
  }
  auto arr_type = ctx().arrayVarLen(expr_->type(), 4, false);
  auto agg = makeExpr<AggExpr>(
      arr_type, AggType::kTopK, expr_, false, builder_->cst(count).expr());
  // Compute the magnitude in a wider type: std::abs(INT_MIN) is undefined
  // because the result is not representable as int.
  const long long magnitude = count > 0 ? count : -static_cast<long long>(count);
  auto top_str = (count > 0 ? "top_" : "bottom_") + std::to_string(magnitude);
  // Append the suffix to the existing name, if any, so the result column keeps
  // a reference to its source expression.
  auto name = name_.empty() ? top_str : name_ + "_" + top_str;
  return {builder_, agg, name, true};
}
611+
612+
/// Builds a BottomK aggregate over this expression.
///
/// Implemented by delegating to topK() with the count negated; topK() treats a
/// negative count as a request for the bottom values.
///
/// @param count Number of values to keep.
/// @return The expression produced by topK(-count).
BuilderExpr BuilderExpr::bottomK(int count) const {
  const int negated_count = -count;
  return topK(negated_count);
}
615+
594616
BuilderExpr BuilderExpr::stdDev() const {
595617
if (!expr_->type()->isNumber()) {
596618
throw InvalidQueryError() << "Non-numeric type " << expr_->type()->toString()
@@ -658,7 +680,7 @@ BuilderExpr BuilderExpr::lastValue() const {
658680
return {builder_, expr, name, true};
659681
}
660682

661-
BuilderExpr BuilderExpr::agg(const std::string& agg_str, const BuilderExpr& arg) const {
683+
BuilderExpr BuilderExpr::agg(const std::string& agg_str, BuilderExpr arg) const {
662684
static const std::unordered_map<std::string, AggType> agg_names = {
663685
{"count", AggType::kCount},
664686
{"count_dist", AggType::kCount},
@@ -678,6 +700,10 @@ BuilderExpr BuilderExpr::agg(const std::string& agg_str, const BuilderExpr& arg)
678700
{"sample", AggType::kSample},
679701
{"single_value", AggType::kSingleValue},
680702
{"single value", AggType::kSingleValue},
703+
{"topk", AggType::kTopK},
704+
{"top_k", AggType::kTopK},
705+
{"bottomk", AggType::kTopK},
706+
{"bottom_k", AggType::kTopK},
681707
{"stddev", AggType::kStdDevSamp},
682708
{"stddev_samp", AggType::kStdDevSamp},
683709
{"stddev samp", AggType::kStdDevSamp},
@@ -693,6 +719,14 @@ BuilderExpr BuilderExpr::agg(const std::string& agg_str, const BuilderExpr& arg)
693719
if (kind == AggType::kApproxQuantile && !arg.expr()) {
694720
throw InvalidQueryError("Missing argument for approximate quantile aggregate.");
695721
}
722+
if (kind == AggType::kTopK) {
723+
if (!arg.expr()) {
724+
throw InvalidQueryError("Missing argument for topK aggregate.");
725+
}
726+
if (agg_str_lower == "bottomk" || agg_str_lower == "bottom_k") {
727+
arg = arg.uminus();
728+
}
729+
}
696730
if (kind == AggType::kCorr && !arg.expr()) {
697731
throw InvalidQueryError("Missing argument for corr aggregate.");
698732
}
@@ -709,6 +743,10 @@ BuilderExpr BuilderExpr::agg(const std::string& agg_str, double val) const {
709743
return agg(agg_str, arg);
710744
}
711745

746+
/// Convenience overload taking the aggregate parameter as a plain integer.
///
/// The integer is wrapped via builder_->cst() and the call is forwarded to the
/// agg() overload that accepts the wrapped argument.
///
/// @param agg_str Aggregate name.
/// @param val Integer aggregate parameter.
/// @return The aggregate expression built by the forwarded call.
BuilderExpr BuilderExpr::agg(const std::string& agg_str, int val) const {
  const auto int_arg = builder_->cst(val);
  return agg(agg_str, int_arg);
}
749+
712750
BuilderExpr BuilderExpr::agg(AggType agg_kind, const BuilderExpr& arg) const {
713751
return agg(agg_kind, false, arg);
714752
}
@@ -717,14 +755,19 @@ BuilderExpr BuilderExpr::agg(AggType agg_kind, double val) const {
717755
return agg(agg_kind, false, val);
718756
}
719757

758+
/// Convenience overload taking the aggregate parameter as a plain integer.
///
/// Forwards to the three-argument agg() with the distinct flag set to false
/// and the integer wrapped via builder_->cst().
///
/// @param agg_kind Aggregate kind to build.
/// @param val Integer aggregate parameter.
/// @return The aggregate expression built by the forwarded call.
BuilderExpr BuilderExpr::agg(AggType agg_kind, int val) const {
  constexpr bool is_distinct = false;
  return agg(agg_kind, is_distinct, builder_->cst(val));
}
761+
720762
BuilderExpr BuilderExpr::agg(AggType agg_kind,
721763
bool is_distinct,
722764
const BuilderExpr& arg) const {
723765
if (is_distinct && agg_kind != AggType::kCount) {
724766
throw InvalidQueryError() << "Distinct property cannot be set to true for "
725767
<< agg_kind << " aggregate.";
726768
}
727-
if (arg.expr() && agg_kind != AggType::kApproxQuantile && agg_kind != AggType::kCorr) {
769+
if (arg.expr() && agg_kind != AggType::kApproxQuantile && agg_kind != AggType::kCorr &&
770+
agg_kind != AggType::kTopK) {
728771
throw InvalidQueryError() << "Aggregate argument is supported for approximate "
729772
"quantile and corr only but provided for "
730773
<< agg_kind;
@@ -736,6 +779,13 @@ BuilderExpr BuilderExpr::agg(AggType agg_kind,
736779
<< arg.expr()->toString();
737780
}
738781
}
782+
if (agg_kind == AggType::kTopK) {
783+
if (!arg.expr()->is<Constant>() || !arg.type()->isInteger()) {
784+
throw InvalidQueryError()
785+
<< "Expected integer constant argumnt for topK. Provided: "
786+
<< arg.expr()->toString();
787+
}
788+
}
739789

740790
switch (agg_kind) {
741791
case AggType::kAvg:
@@ -756,6 +806,8 @@ BuilderExpr BuilderExpr::agg(AggType agg_kind,
756806
return sample();
757807
case AggType::kSingleValue:
758808
return singleValue();
809+
case AggType::kTopK:
810+
return topK(arg.expr()->as<Constant>()->intVal());
759811
case AggType::kStdDevSamp:
760812
return stdDev();
761813
case AggType::kCorr:
@@ -2053,21 +2105,28 @@ BuilderExpr BuilderNode::parseAggString(const std::string& agg_str) const {
20532105
auto val_str = boost::trim_copy(
20542106
col_name.substr(comma_pos + 1, col_name.size() - comma_pos - 1));
20552107
char* end = nullptr;
2056-
auto val = std::strtod(val_str.c_str(), &end);
2057-
// Require value string to be fully interpreted to avoid silent errors like
2058-
// 1..1 interpreted as 1.
2059-
if (val == HUGE_VAL || end == val_str.c_str() ||
2060-
end != (val_str.c_str() + val_str.size())) {
2061-
// If value is not decimal then assume it is a column name (for corr aggregate).
2062-
auto ref = getRefByName(node_.get(), val_str, true);
2063-
if (!ref) {
2064-
throw InvalidQueryError()
2065-
<< "Cannot parse aggregate parameter (decimal or column name expected): "
2066-
<< val_str;
2067-
}
2068-
arg = BuilderExpr(builder_, ref, val_str);
2108+
2109+
// First, try to parse arg as an integer.
2110+
auto long_val = std::strtol(val_str.c_str(), &end, 10);
2111+
if (end == (val_str.c_str() + val_str.size())) {
2112+
arg = builder_->cst(long_val);
20692113
} else {
2070-
arg = builder_->cst(val);
2114+
// Now try to parse as a double. Require value string to be fully interpreted
2115+
// to avoid silent errors like 1..1 interpreted as 1.
2116+
auto val = std::strtod(val_str.c_str(), &end);
2117+
if (val == HUGE_VAL || end == val_str.c_str() ||
2118+
end != (val_str.c_str() + val_str.size())) {
2119+
// If value is not decimal then assume it is a column name (for corr aggregate).
2120+
auto ref = getRefByName(node_.get(), val_str, true);
2121+
if (!ref) {
2122+
throw InvalidQueryError()
2123+
<< "Cannot parse aggregate parameter (decimal or column name expected): "
2124+
<< val_str;
2125+
}
2126+
arg = BuilderExpr(builder_, ref, val_str);
2127+
} else {
2128+
arg = builder_->cst(val);
2129+
}
20712130
}
20722131
col_name = boost::trim_copy(col_name.substr(0, comma_pos));
20732132
}

omniscidb/QueryBuilder/QueryBuilder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class BuilderExpr {
7676
BuilderExpr approxQuantile(double val) const;
7777
BuilderExpr sample() const;
7878
BuilderExpr singleValue() const;
79+
BuilderExpr topK(int count) const;
80+
BuilderExpr bottomK(int count) const;
7981
BuilderExpr stdDev() const;
8082
BuilderExpr corr(const BuilderExpr& arg) const;
8183

@@ -84,10 +86,12 @@ class BuilderExpr {
8486
BuilderExpr firstValue() const;
8587
BuilderExpr lastValue() const;
8688

87-
BuilderExpr agg(const std::string& agg_str, const BuilderExpr& arg) const;
89+
BuilderExpr agg(const std::string& agg_str, BuilderExpr arg) const;
8890
BuilderExpr agg(const std::string& agg_str, double val = HUGE_VAL) const;
91+
BuilderExpr agg(const std::string& agg_str, int val) const;
8992
BuilderExpr agg(AggType agg_kind, const BuilderExpr& arg) const;
9093
BuilderExpr agg(AggType agg_kind, double val) const;
94+
BuilderExpr agg(AggType agg_kind, int val) const;
9195
BuilderExpr agg(AggType agg_kind, bool is_dinstinct, const BuilderExpr& arg) const;
9296
BuilderExpr agg(AggType agg_kind,
9397
bool is_dinstinct = false,

omniscidb/QueryEngine/Execute.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,10 @@ quantile::TDigest* RowSetMemoryOwner::nullTDigest(double const q) {
509509
.get();
510510
}
511511

512+
int8_t* RowSetMemoryOwner::topKBuffer(size_t size) {
513+
return allocate(size);
514+
}
515+
512516
bool Executor::isCPUOnly() const {
513517
CHECK(data_mgr_);
514518
return !data_mgr_->getCudaMgr();
@@ -2033,6 +2037,7 @@ void Executor::allocateShuffleBuffers(
20332037
getConfigPtr(),
20342038
query_infos,
20352039
false /*approx_quantile*/,
2040+
false /*topk_agg*/,
20362041
false /*allow_multifrag*/,
20372042
false /*keyless_hash*/,
20382043
false /*interleaved_bins_on_gpu*/,

omniscidb/QueryEngine/GroupByRuntime.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "JoinHashTable/Runtime/JoinHashImpl.h"
1818
#include "MurmurHash.h"
19+
#include "TopKAggRuntime.h"
1920

2021
extern "C" RUNTIME_EXPORT ALWAYS_INLINE DEVICE uint32_t
2122
key_hash(GENERIC_ADDR_SPACE const int64_t* key,
@@ -376,3 +377,32 @@ DEF_TRANSLATE_NULL_KEY(int32_t)
376377
DEF_TRANSLATE_NULL_KEY(int64_t)
377378

378379
#undef DEF_TRANSLATE_NULL_KEY
380+
381+
// Runtime entry points for the TopK aggregate, generated once per value type.
//
// For every (value type, tag) pair two extern "C" functions are emitted:
//   agg_topk_<tag>          - forwards all arguments to agg_topk_impl<type>.
//   agg_topk_<tag>_skip_val - does nothing when val equals skip_val, otherwise
//                             calls agg_topk_<tag>, passing skip_val in the
//                             empty_val position.
// The helper macros are local to this section and are undefined at its end.
#define DEF_AGG_TOPK(value_t, tag) \
  extern "C" RUNTIME_EXPORT ALWAYS_INLINE DEVICE void agg_topk_##tag( \
      int64_t* agg, value_t val, value_t empty_val, int k, bool inline_buffer) { \
    agg_topk_impl<value_t>(agg, val, empty_val, k, inline_buffer); \
  }

#define DEF_AGG_TOPK_SKIP_VAL(value_t, tag) \
  extern "C" RUNTIME_EXPORT ALWAYS_INLINE DEVICE void agg_topk_##tag##_skip_val( \
      int64_t* agg, value_t val, value_t skip_val, int k, bool inline_buffer) { \
    if (val == skip_val) { \
      return; \
    } \
    agg_topk_##tag(agg, val, skip_val, k, inline_buffer); \
  }

#define DEF_AGG_TOPK_ALL(value_t, tag) \
  DEF_AGG_TOPK(value_t, tag) \
  DEF_AGG_TOPK_SKIP_VAL(value_t, tag)

DEF_AGG_TOPK_ALL(int8_t, int8)
DEF_AGG_TOPK_ALL(int16_t, int16)
DEF_AGG_TOPK_ALL(int32_t, int32)
DEF_AGG_TOPK_ALL(int64_t, int64)
DEF_AGG_TOPK_ALL(float, float)
DEF_AGG_TOPK_ALL(double, double)

#undef DEF_AGG_TOPK_ALL
#undef DEF_AGG_TOPK_SKIP_VAL
#undef DEF_AGG_TOPK

omniscidb/QueryEngine/MemoryLayoutBuilder.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,7 @@ std::unique_ptr<QueryMemoryDescriptor> build_query_memory_descriptor(
831831
executor->getConfigPtr(),
832832
query_infos,
833833
false,
834+
false,
834835
allow_multifrag,
835836
false,
836837
false,
@@ -968,10 +969,12 @@ std::unique_ptr<QueryMemoryDescriptor> build_query_memory_descriptor(
968969

969970
auto approx_quantile =
970971
anyOf(ra_exe_unit.target_exprs, hdk::ir::AggType::kApproxQuantile);
972+
auto topk_agg = anyOf(ra_exe_unit.target_exprs, hdk::ir::AggType::kTopK);
971973
return std::make_unique<QueryMemoryDescriptor>(executor->getDataMgr(),
972974
executor->getConfigPtr(),
973975
query_infos,
974976
approx_quantile,
977+
topk_agg,
975978
allow_multifrag,
976979
keyless_hash,
977980
interleaved_bins_on_gpu,

omniscidb/QueryEngine/NativeCodegen.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,9 @@ std::vector<std::string> get_agg_fnames(
778778
case hdk::ir::AggType::kApproxQuantile:
779779
result.emplace_back("agg_approx_quantile");
780780
break;
781+
case hdk::ir::AggType::kTopK:
782+
result.emplace_back("agg_topk");
783+
break;
781784
default:
782785
CHECK(false);
783786
}

omniscidb/QueryEngine/OutputBufferInitialization.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ int64_t get_agg_initial_val(hdk::ir::AggType agg,
162162
case hdk::ir::AggType::kAvg:
163163
case hdk::ir::AggType::kCount:
164164
case hdk::ir::AggType::kApproxCountDistinct:
165+
case hdk::ir::AggType::kTopK:
165166
return 0;
166167
case hdk::ir::AggType::kApproxQuantile:
167168
return {}; // Init value is a quantile::TDigest* set elsewhere.

0 commit comments

Comments
 (0)