Skip to content

Commit 13151cc

Browse files
committed
move grouping udf
1 parent 3373cd0 commit 13151cc

File tree

7 files changed

+94
-45
lines changed

7 files changed

+94
-45
lines changed

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/functions/src/core/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ pub mod expr_ext;
2727
pub mod getfield;
2828
pub mod greatest;
2929
mod greatest_least_utils;
30-
pub mod grouping;
3130
pub mod least;
3231
pub mod named_struct;
3332
pub mod nullif;
@@ -56,7 +55,6 @@ make_udf_function!(least::LeastFunc, least);
5655
make_udf_function!(union_extract::UnionExtractFun, union_extract);
5756
make_udf_function!(union_tag::UnionTagFunc, union_tag);
5857
make_udf_function!(version::VersionFunc, version);
59-
make_udf_function!(grouping::GroupingFunc, grouping);
6058

6159
pub mod expr_fn {
6260
use datafusion_expr::{Expr, Literal};

datafusion/optimizer/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ arrow = { workspace = true }
4545
chrono = { workspace = true }
4646
datafusion-common = { workspace = true, default-features = true }
4747
datafusion-expr = { workspace = true }
48-
datafusion-functions = { workspace = true }
4948
datafusion-physical-expr = { workspace = true }
5049
indexmap = { workspace = true }
5150
itertools = { workspace = true }

datafusion/functions/src/core/grouping.rs renamed to datafusion/optimizer/src/analyzer/grouping.rs

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,14 @@ use datafusion_expr::{
2626
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
2727
Volatility,
2828
};
29-
use datafusion_macros::user_doc;
3029
use std::any::Any;
31-
use std::sync::Arc;
30+
use std::sync::{Arc, LazyLock};
3231

33-
use crate::utils::make_scalar_function;
32+
pub(crate) fn grouping_func() -> Arc<datafusion_expr::ScalarUDF> {
33+
static INSTANCE: LazyLock<Arc<datafusion_expr::ScalarUDF>> =
34+
LazyLock::new(|| Arc::new(datafusion_expr::ScalarUDF::from(GroupingFunc::new())));
35+
Arc::clone(&INSTANCE)
36+
}
3437

3538
macro_rules! grouping_id {
3639
($grouping_id:expr, $indices:expr, $type:ty, $array_type:ty) => {{
@@ -69,28 +72,6 @@ macro_rules! grouping_id {
6972
}};
7073
}
7174

72-
#[user_doc(
73-
doc_section(label = "Other Functions"),
74-
description = "Developer API: Returns the level of grouping, equals to (((grouping_id >> array[0]) & 1) << (n-1)) + (((grouping_id >> array[1]) & 1) << (n-2)) + ... + (((grouping_id >> array[n-1]) & 1) << 0). Returns grouping_id if indices is not provided.",
75-
syntax_example = "grouping(grouping_id[, indices])",
76-
sql_example = r#"```sql
77-
> SELECT grouping(__grouping_id, make_array(0)) FROM table GROUP BY GROUPING SETS ((a), (b));
78-
+----------------+
79-
| grouping |
80-
+----------------+
81-
| 1 |
82-
| 0 |
83-
+----------------+
84-
```"#,
85-
argument(
86-
name = "grouping_id",
87-
description = "The internal grouping ID column (UInt8/16/32/64)"
88-
),
89-
argument(
90-
name = "indices",
91-
description = "The indices of the column in the grouping set (Int32)"
92-
)
93-
)]
9475
#[derive(Debug)]
9576
pub struct GroupingFunc {
9677
signature: Signature,
@@ -128,7 +109,22 @@ impl ScalarUDFImpl for GroupingFunc {
128109
}
129110

130111
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
131-
make_scalar_function(grouping_inner, vec![])(&args.args)
112+
let mut array_args = Vec::new();
113+
for arg in &args.args {
114+
match arg {
115+
ColumnarValue::Array(array) => {
116+
array_args.push(Arc::clone(array));
117+
}
118+
ColumnarValue::Scalar(scalar) => {
119+
// Convert scalar to array by repeating the value
120+
let array = scalar.to_array_of_size(args.number_rows)?;
121+
array_args.push(array);
122+
}
123+
}
124+
}
125+
126+
let result = grouping_inner(&array_args)?;
127+
Ok(ColumnarValue::Array(result))
132128
}
133129

134130
fn short_circuits(&self) -> bool {
@@ -175,7 +171,7 @@ impl ScalarUDFImpl for GroupingFunc {
175171
}
176172

177173
fn documentation(&self) -> Option<&Documentation> {
178-
self.doc()
174+
None
179175
}
180176
}
181177

datafusion/optimizer/src/analyzer/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ pub mod function_rewrite;
3838
pub mod resolve_grouping_function;
3939
pub mod type_coercion;
4040

41+
mod grouping;
42+
4143
pub mod subquery {
4244
#[deprecated(
4345
since = "44.0.0",

datafusion/optimizer/src/analyzer/resolve_grouping_function.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@
2121
use std::collections::HashMap;
2222
use std::sync::Arc;
2323

24+
use crate::analyzer::grouping::grouping_func;
2425
use crate::analyzer::AnalyzerRule;
2526

2627
use arrow::array::ListArray;
2728
use arrow::datatypes::Int32Type;
2829
use datafusion_common::config::ConfigOptions;
2930
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
3031
use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue};
31-
use datafusion_expr::expr::{AggregateFunction, Alias};
32+
use datafusion_expr::expr::{AggregateFunction, Alias, ScalarFunction};
3233
use datafusion_expr::logical_plan::LogicalPlan;
3334
use datafusion_expr::utils::grouping_set_to_exprlist;
3435
use datafusion_expr::{Aggregate, Expr, Projection};
35-
use datafusion_functions::core::grouping;
3636
use itertools::Itertools;
3737

3838
/// Replaces grouping aggregation function with value derived from internal grouping id
@@ -198,7 +198,10 @@ fn grouping_function_on_id(
198198
.enumerate()
199199
.all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
200200
{
201-
return Ok(grouping().call(vec![grouping_id_column]));
201+
return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
202+
grouping_func(),
203+
vec![grouping_id_column],
204+
)));
202205
}
203206

204207
let args = args
@@ -216,5 +219,8 @@ fn grouping_function_on_id(
216219
))),
217220
None,
218221
);
219-
Ok(grouping().call(vec![grouping_id_column, indices]))
222+
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
223+
grouping_func(),
224+
vec![grouping_id_column, indices],
225+
)))
220226
}

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use datafusion_common::{
2121
assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
2222
TableReference,
2323
};
24+
use datafusion_expr::expr::ScalarFunction;
2425
use datafusion_expr::test::function_stub::{
2526
count_udaf, max_udaf, min_udaf, sum, sum_udaf,
2627
};
@@ -29,7 +30,6 @@ use datafusion_expr::{
2930
LogicalPlan, LogicalPlanBuilder, Union, UserDefinedLogicalNode,
3031
UserDefinedLogicalNodeCore,
3132
};
32-
use datafusion_functions::core::grouping;
3333
use datafusion_functions::unicode;
3434
use datafusion_functions_aggregate::grouping::grouping_udaf;
3535
use datafusion_functions_nested::make_array::make_array_udf;
@@ -2567,6 +2567,47 @@ fn test_not_ilike_filter_with_escape() {
25672567

25682568
#[test]
25692569
fn test_grouping() -> Result<()> {
2570+
#[derive(Debug)]
2571+
struct MockGroupingFunc {
2572+
signature: Signature,
2573+
}
2574+
2575+
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2576+
use std::any::Any;
2577+
2578+
impl MockGroupingFunc {
2579+
fn new() -> Self {
2580+
Self {
2581+
signature: Signature::user_defined(Volatility::Immutable),
2582+
}
2583+
}
2584+
}
2585+
2586+
impl ScalarUDFImpl for MockGroupingFunc {
2587+
fn as_any(&self) -> &dyn Any {
2588+
self
2589+
}
2590+
2591+
fn name(&self) -> &str {
2592+
"grouping"
2593+
}
2594+
2595+
fn signature(&self) -> &Signature {
2596+
&self.signature
2597+
}
2598+
2599+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
2600+
Ok(DataType::Int32)
2601+
}
2602+
2603+
fn invoke_with_args(
2604+
&self,
2605+
_args: datafusion_expr::ScalarFunctionArgs,
2606+
) -> Result<datafusion_expr::ColumnarValue> {
2607+
todo!()
2608+
}
2609+
}
2610+
25702611
let schema = Schema::new(vec![
25712612
Field::new("c1", DataType::Int32, false),
25722613
Field::new("c2", DataType::Int32, false),
@@ -2587,6 +2628,8 @@ fn test_grouping() -> Result<()> {
25872628
)?
25882629
.build()?;
25892630

2631+
let udf = Arc::new(datafusion_expr::ScalarUDF::from(MockGroupingFunc::new()));
2632+
25902633
let group1 =
25912634
ScalarValue::List(Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(
25922635
vec![Some(vec![Some(0)])],
@@ -2601,15 +2644,21 @@ fn test_grouping() -> Result<()> {
26012644
)));
26022645
let project = LogicalPlanBuilder::from(plan)
26032646
.project(vec![
2604-
grouping()
2605-
.call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group1)])
2606-
.alias("grouping(test.c1)"),
2607-
grouping()
2608-
.call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group2)])
2609-
.alias("grouping(test.c2)"),
2610-
grouping()
2611-
.call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group3)])
2612-
.alias("grouping(test.c1,test.c2)"),
2647+
Expr::ScalarFunction(ScalarFunction::new_udf(
2648+
Arc::clone(&udf),
2649+
vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group1)],
2650+
))
2651+
.alias("grouping(test.c1)"),
2652+
Expr::ScalarFunction(ScalarFunction::new_udf(
2653+
Arc::clone(&udf),
2654+
vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group2)],
2655+
))
2656+
.alias("grouping(test.c2)"),
2657+
Expr::ScalarFunction(ScalarFunction::new_udf(
2658+
Arc::clone(&udf),
2659+
vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group3)],
2660+
))
2661+
.alias("grouping(test.c1,test.c2)"),
26132662
])?
26142663
.build()?;
26152664
let unparser = Unparser::new(&UnparserPostgreSqlDialect {});

0 commit comments

Comments
 (0)