Skip to content

Commit fe61624

Browse files
committed
fix: extend recursive protection to prevent stack overflows in additional functions
1 parent aa1e6da commit fe61624

25 files changed

+44
-0
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ recursive_protection = [
7676
"datafusion-common/recursive_protection",
7777
"datafusion-expr/recursive_protection",
7878
"datafusion-optimizer/recursive_protection",
79+
"datafusion-physical-expr/recursive_protection",
7980
"datafusion-physical-optimizer/recursive_protection",
8081
"datafusion-sql/recursive_protection",
8182
"sqlparser/recursive-protection",

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2092,6 +2092,7 @@ impl Normalizeable for Expr {
20922092
}
20932093

20942094
impl NormalizeEq for Expr {
2095+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
20952096
fn normalize_eq(&self, other: &Self) -> bool {
20962097
match (self, other) {
20972098
(

datafusion/expr/src/logical_plan/invariants.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
7171
///
7272
/// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode)
7373
/// for more details of user-provided extension node invariants.
74+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
7475
fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
7576
plan.apply_with_subqueries(|plan: &LogicalPlan| {
7677
if let LogicalPlan::Extension(Extension { node }) = plan {
@@ -372,6 +373,7 @@ fn check_aggregation_in_scalar_subquery(
372373
Ok(())
373374
}
374375

376+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
375377
fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan {
376378
match inner_plan {
377379
LogicalPlan::Projection(projection) => {

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2278,6 +2278,7 @@ impl Filter {
22782278
Self::try_new_internal(predicate, input)
22792279
}
22802280

2281+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
22812282
fn is_allowed_filter_type(data_type: &DataType) -> bool {
22822283
match data_type {
22832284
// Interpret NULL as a missing boolean value.

datafusion/expr/src/utils.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,7 @@ pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
934934
split_conjunction_impl(expr, vec![])
935935
}
936936

937+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
937938
fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
938939
match expr {
939940
Expr::BinaryExpr(BinaryExpr {
@@ -1051,6 +1052,7 @@ pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
10511052
split_binary_owned_impl(expr, op, vec![])
10521053
}
10531054

1055+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
10541056
fn split_binary_owned_impl(
10551057
expr: Expr,
10561058
operator: Operator,
@@ -1078,6 +1080,7 @@ pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
10781080
split_binary_impl(expr, op, vec![])
10791081
}
10801082

1083+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
10811084
fn split_binary_impl<'a>(
10821085
expr: &'a Expr,
10831086
operator: Operator,

datafusion/optimizer/src/decorrelate.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,7 @@ fn can_pullup_over_aggregation(expr: &Expr) -> bool {
445445
}
446446
}
447447

448+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
448449
fn collect_local_correlated_cols(
449450
plan: &LogicalPlan,
450451
all_cols_map: &HashMap<LogicalPlan, BTreeSet<Column>>,

datafusion/optimizer/src/decorrelate_predicate_subquery.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
5555
true
5656
}
5757

58+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
5859
fn rewrite(
5960
&self,
6061
plan: LogicalPlan,

datafusion/optimizer/src/eliminate_cross_join.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ fn rewrite_children(
228228
/// Assumes can_flatten_join_inputs has returned true and thus the plan can be
229229
/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to
230230
/// possible_join_keys
231+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
231232
fn flatten_join_inputs(
232233
plan: LogicalPlan,
233234
possible_join_keys: &mut JoinKeySet,
@@ -264,6 +265,7 @@ fn flatten_join_inputs(
264265
/// `flatten_join_inputs`
265266
///
266267
/// Must stay in sync with `flatten_join_inputs`
268+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
267269
fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
268270
// can only flatten inner / cross joins
269271
match plan {
@@ -368,6 +370,7 @@ fn find_inner_join(
368370
}
369371

370372
/// Extract join keys from a WHERE clause
373+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
371374
fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
372375
if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
373376
match op {
@@ -399,6 +402,7 @@ fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
399402
/// # Returns
400403
/// * `Some()` when there are few remaining predicates in filter_expr
401404
/// * `None` otherwise
405+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
402406
fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
403407
match expr {
404408
Expr::BinaryExpr(BinaryExpr {

datafusion/optimizer/src/eliminate_group_by_constant.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ impl OptimizerRule for EliminateGroupByConstant {
9595
///
9696
/// Intended to be used only within this rule, helper function, which heavily
9797
/// relies on `SimplifyExpressions` result.
98+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
9899
fn is_constant_expression(expr: &Expr) -> bool {
99100
match expr {
100101
Expr::Alias(e) => is_constant_expression(&e.expr),

datafusion/optimizer/src/eliminate_limit.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ impl OptimizerRule for EliminateLimit {
5353
true
5454
}
5555

56+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
5657
fn rewrite(
5758
&self,
5859
plan: LogicalPlan,

datafusion/optimizer/src/eliminate_outer_join.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ pub fn eliminate_outer(
172172
/// For IS NOT NULL/NOT expr, always returns false for NULL input.
173173
/// extracts columns from these exprs.
174174
/// For all other exprs, fall through
175+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
175176
fn extract_non_nullable_columns(
176177
expr: &Expr,
177178
non_nullable_cols: &mut Vec<Column>,

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ fn extract_or_clauses_for_join<'a>(
366366
/// Otherwise, return None.
367367
///
368368
/// For other clause, apply the rule above to extract clause.
369+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
369370
fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
370371
let mut predicate = None;
371372

@@ -764,6 +765,7 @@ impl OptimizerRule for PushDownFilter {
764765
true
765766
}
766767

768+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
767769
fn rewrite(
768770
&self,
769771
plan: LogicalPlan,

datafusion/optimizer/src/push_down_limit.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ impl OptimizerRule for PushDownLimit {
4848
true
4949
}
5050

51+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
5152
fn rewrite(
5253
&self,
5354
plan: LogicalPlan,

datafusion/optimizer/src/scalar_subquery_to_join.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ impl OptimizerRule for ScalarSubqueryToJoin {
7474
true
7575
}
7676

77+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
7778
fn rewrite(
7879
&self,
7980
plan: LogicalPlan,

datafusion/optimizer/src/simplify_expressions/utils.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ pub static POWS_OF_TEN: [i128; 38] = [
6767

6868
/// returns true if `needle` is found in a chain of search_op
6969
/// expressions. Such as: (A AND B) AND C
70+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
7071
fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
7172
match expr {
7273
Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => {
@@ -86,6 +87,7 @@ pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
8687
/// expressions. Such as: A ^ (A ^ (B ^ A))
8788
pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr {
8889
/// Deletes recursively 'needles' in a chain of xor expressions
90+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
8991
fn recursive_delete_xor_in_expr(
9092
expr: &Expr,
9193
needle: &Expr,
@@ -266,6 +268,7 @@ pub fn as_bool_lit(expr: &Expr) -> Result<Option<bool>> {
266268
/// For Between, not (A between B and C) ===> (A not between B and C)
267269
/// not (A not between B and C) ===> (A between B and C)
268270
/// For others, use Not clause
271+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
269272
pub fn negate_clause(expr: Expr) -> Expr {
270273
match expr {
271274
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
@@ -335,6 +338,7 @@ pub fn negate_clause(expr: Expr) -> Expr {
335338
/// For Negative:
336339
/// ~(~A) ===> A
337340
/// For others, use Negative clause
341+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
338342
pub fn distribute_negation(expr: Expr) -> Expr {
339343
match expr {
340344
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {

datafusion/physical-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ workspace = true
3737
[lib]
3838
name = "datafusion_physical_expr"
3939

40+
[features]
41+
recursive_protection = ["dep:recursive"]
42+
4043
[dependencies]
4144
ahash = { workspace = true }
4245
arrow = { workspace = true }
@@ -52,6 +55,7 @@ itertools = { workspace = true, features = ["use_std"] }
5255
log = { workspace = true }
5356
paste = "^1.0"
5457
petgraph = "0.8.2"
58+
recursive = { workspace = true, optional = true }
5559

5660
[dev-dependencies]
5761
arrow = { workspace = true, features = ["test_utils"] }

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ impl PhysicalExpr for BinaryExpr {
343343
self
344344
}
345345

346+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
346347
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
347348
BinaryTypeCoercer::new(
348349
&self.left.data_type(input_schema)?,
@@ -356,6 +357,7 @@ impl PhysicalExpr for BinaryExpr {
356357
Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?)
357358
}
358359

360+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
359361
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
360362
use arrow::compute::kernels::numeric::*;
361363

@@ -648,6 +650,7 @@ impl PhysicalExpr for BinaryExpr {
648650
}
649651
}
650652

653+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
651654
fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
652655
fn write_child(
653656
f: &mut std::fmt::Formatter,

datafusion/physical-expr/src/intervals/utils.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ use datafusion_expr::Operator;
3535
/// We do not support every type of [`Operator`]s either. Over time, this check
3636
/// will relax as more types of `PhysicalExpr`s and `Operator`s are supported.
3737
/// Currently, [`CastExpr`], [`NegativeExpr`], [`BinaryExpr`], [`Column`] and [`Literal`] are supported.
38+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
3839
pub fn check_support(expr: &Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> bool {
3940
let expr_any = expr.as_any();
4041
if let Some(binary_expr) = expr_any.downcast_ref::<BinaryExpr>() {

datafusion/physical-expr/src/planner.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ use datafusion_expr::{
105105
/// * `e` - The logical expression
106106
/// * `input_dfschema` - The DataFusion schema for the input, used to resolve `Column` references
107107
/// to qualified or unqualified fields by name.
108+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
108109
pub fn create_physical_expr(
109110
e: &Expr,
110111
input_dfschema: &DFSchema,
@@ -385,6 +386,7 @@ pub fn create_physical_expr(
385386
}
386387

387388
/// Create vector of Physical Expression from a vector of logical expression
389+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
388390
pub fn create_physical_exprs<'a, I>(
389391
exprs: I,
390392
input_dfschema: &DFSchema,

datafusion/physical-optimizer/src/enforce_distribution.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,7 @@ pub fn reorder_join_keys_to_inputs(
691691
}
692692

693693
/// Reorder the current join keys ordering based on either left partition or right partition
694+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
694695
fn reorder_current_join_keys(
695696
join_keys: JoinKeyPairs,
696697
left_partition: Option<&Partitioning>,
@@ -1011,6 +1012,7 @@ fn remove_dist_changing_operators(
10111012
/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2",
10121013
/// " DataSourceExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC], file_type=parquet",
10131014
/// ```
1015+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
10141016
pub fn replace_order_preserving_variants(
10151017
mut context: DistributionContext,
10161018
) -> Result<DistributionContext> {

datafusion/physical-optimizer/src/filter_pushdown.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ enum ParentPredicateStates {
428428
Supported,
429429
}
430430

431+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
431432
fn push_down_filters(
432433
node: Arc<dyn ExecutionPlan>,
433434
parent_predicates: Vec<Arc<dyn PhysicalExpr>>,

datafusion/physical-optimizer/src/limit_pushdown.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ pub fn pushdown_limit_helper(
262262
}
263263

264264
/// Pushes down the limit through the plan.
265+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
265266
pub(crate) fn pushdown_limits(
266267
pushdown_plan: Arc<dyn ExecutionPlan>,
267268
global_state: GlobalRequirements,

datafusion/substrait/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ itertools = { workspace = true }
3939
object_store = { workspace = true }
4040
pbjson-types = { workspace = true }
4141
prost = { workspace = true }
42+
recursive = { workspace = true, optional = true }
4243
substrait = { version = "0.57", features = ["serde"] }
4344
url = { workspace = true }
4445
tokio = { workspace = true, features = ["fs"] }
@@ -54,6 +55,7 @@ insta = { workspace = true }
5455
default = ["physical"]
5556
physical = ["datafusion/parquet"]
5657
protoc = ["substrait/protoc"]
58+
recursive_protection = ["dep:recursive", "datafusion/recursive_protection"]
5759

5860
[package.metadata.docs.rs]
5961
# Use default features ("physical") for docs, plus "protoc". "protoc" is needed

datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ fn arg_list_to_binary_op_tree(op: Operator, mut args: Vec<Expr>) -> Result<Expr>
128128
///
129129
/// `take_len` represents the number of elements to take from `args` before returning.
130130
/// We use `take_len` to avoid recursively building a `Take<Take<Take<...>>>` type.
131+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
131132
fn arg_list_to_binary_op_tree_inner(
132133
op: Operator,
133134
args: &mut Drain<Expr>,

0 commit comments

Comments
 (0)