diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index 00bf83aba..f07e2f9ea 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -16,12 +16,15 @@ #[cfg(not(feature = "std"))] use alloc::{boxed::Box, string::String, vec::Vec}; use core::fmt; +use core::ops::ControlFlow; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::ast::value::escape_single_quote_string; -use crate::ast::{display_comma_separated, display_separated, DataType, Expr, Ident, ObjectName}; +use crate::ast::{ + display_comma_separated, display_separated, DataType, Expr, Ident, ObjectName, Visit, Visitor, +}; use crate::tokenizer::Token; /// An `ALTER TABLE` (`Statement::AlterTable`) operation @@ -200,6 +203,35 @@ impl fmt::Display for AlterTableOperation { } } +impl Visit for AlterTableOperation { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + AlterTableOperation::DropConstraint { .. } + | AlterTableOperation::DropColumn { .. } + | AlterTableOperation::DropPrimaryKey + | AlterTableOperation::RenameColumn { .. } + | AlterTableOperation::RenameTable { .. } + | AlterTableOperation::RenameConstraint { .. } => ControlFlow::Continue(()), + AlterTableOperation::AddConstraint(c) => c.visit(visitor), + AlterTableOperation::AddColumn { column_def, .. } => column_def.visit(visitor), + AlterTableOperation::RenamePartitions { + old_partitions, + new_partitions, + } => { + old_partitions.visit(visitor)?; + new_partitions.visit(visitor) + } + AlterTableOperation::AddPartitions { new_partitions, .. } => { + new_partitions.visit(visitor) + } + AlterTableOperation::DropPartitions { partitions, .. } => partitions.visit(visitor), + AlterTableOperation::ChangeColumn { options, .. } => options.visit(visitor), + + AlterTableOperation::AlterColumn { op, .. } => op.visit(visitor), + } + } +} + /// An `ALTER COLUMN` (`Statement::AlterTable`) operation #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -242,6 +274,18 @@ impl fmt::Display for AlterColumnOperation { } } +impl Visit for AlterColumnOperation { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + AlterColumnOperation::SetNotNull + | AlterColumnOperation::DropNotNull + | AlterColumnOperation::DropDefault => ControlFlow::Continue(()), + AlterColumnOperation::SetDefault { value } => value.visit(visitor), + AlterColumnOperation::SetDataType { using, .. } => using.visit(visitor), + } + } +} + /// A table-level constraint, specified in a `CREATE TABLE` or an /// `ALTER TABLE ADD ` statement. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] @@ -400,6 +444,18 @@ impl fmt::Display for TableConstraint { } } +impl Visit for TableConstraint { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + TableConstraint::Unique { .. } + | TableConstraint::ForeignKey { .. } + | TableConstraint::Index { .. } + | TableConstraint::FulltextOrSpatial { .. } => ControlFlow::Continue(()), + TableConstraint::Check { expr, .. } => expr.visit(visitor), + } + } +} + /// Representation whether a definition can can contains the KEY or INDEX keywords with the same /// meaning. /// @@ -479,6 +535,12 @@ impl fmt::Display for ColumnDef { } } +impl Visit for ColumnDef { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.options.visit(visitor) + } +} + /// An optionally-named `ColumnOption`: `[ CONSTRAINT ] `. /// /// Note that implementations are substantially more permissive than the ANSI @@ -508,6 +570,12 @@ impl fmt::Display for ColumnOptionDef { } } +impl Visit for ColumnOptionDef { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.option.visit(visitor) + } +} + /// `ColumnOption`s are modifiers that follow a column definition in a `CREATE /// TABLE` statement. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] @@ -580,6 +648,21 @@ impl fmt::Display for ColumnOption { } } +impl Visit for ColumnOption { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + ColumnOption::Null + | ColumnOption::NotNull + | ColumnOption::Unique { .. } + | ColumnOption::ForeignKey { .. } + | ColumnOption::DialectSpecific(_) + | ColumnOption::CharacterSet(_) + | ColumnOption::Comment(_) => ControlFlow::Continue(()), + ColumnOption::Default(e) | ColumnOption::Check(e) => e.visit(visitor), + } + } +} + fn display_constraint_name(name: &'_ Option) -> impl fmt::Display + '_ { struct ConstraintName<'a>(&'a Option); impl<'a> fmt::Display for ConstraintName<'a> { diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 7f3d15438..b3354bdf9 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -18,6 +18,7 @@ use alloc::{ vec::Vec, }; use core::fmt; +use core::ops::ControlFlow; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -37,6 +38,7 @@ pub use self::query::{ Values, WildcardAdditionalOptions, With, }; pub use self::value::{escape_quoted_string, DateTimeField, TrimWhereField, Value}; +pub use visitor::*; mod data_type; mod ddl; @@ -44,6 +46,7 @@ pub mod helpers; mod operator; mod query; mod value; +mod visitor; struct DisplaySeparated<'a, T> where @@ -176,6 +179,12 @@ impl fmt::Display for Array { } } +impl Visit for Array { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.elem.visit(visitor) + } +} + /// JsonOperator #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -858,6 +867,147 @@ impl fmt::Display for Expr { } } +impl Visit for Expr { + fn visit(&self, visitor: &mut V) -> ControlFlow { + visitor.visit_expr(self)?; + match self { + Expr::Identifier(_) + | Expr::CompoundIdentifier(_) + | Expr::TypedString { .. } + | Expr::MatchAgainst { .. } + | Expr::Value(_) => ControlFlow::Continue(()), + Expr::JsonAccess { left, right, .. } => { + left.visit(visitor)?; + right.visit(visitor) + } + Expr::CompositeAccess { expr, .. } => expr.visit(visitor), + Expr::IsFalse(e) + | Expr::IsNotFalse(e) + | Expr::IsTrue(e) + | Expr::IsNotTrue(e) + | Expr::IsNull(e) + | Expr::IsNotNull(e) + | Expr::IsUnknown(e) + | Expr::IsNotUnknown(e) => e.visit(visitor), + Expr::IsDistinctFrom(l, r) | Expr::IsNotDistinctFrom(l, r) => { + l.visit(visitor)?; + r.visit(visitor) + } + Expr::InList { expr, list, .. } => { + expr.visit(visitor)?; + list.visit(visitor) + } + Expr::InSubquery { expr, subquery, .. } => { + expr.visit(visitor)?; + subquery.visit(visitor) + } + Expr::InUnnest { + expr, array_expr, .. + } => { + expr.visit(visitor)?; + array_expr.visit(visitor) + } + Expr::Between { + expr, low, high, .. + } => { + expr.visit(visitor)?; + low.visit(visitor)?; + high.visit(visitor) + } + Expr::BinaryOp { left, right, .. } => { + left.visit(visitor)?; + right.visit(visitor) + } + Expr::Like { expr, pattern, .. } + | Expr::ILike { expr, pattern, .. } + | Expr::SimilarTo { expr, pattern, .. } => { + expr.visit(visitor)?; + pattern.visit(visitor) + } + Expr::AtTimeZone { timestamp, .. } => timestamp.visit(visitor), + Expr::AnyOp(expr) + | Expr::AllOp(expr) + | Expr::UnaryOp { expr, .. } + | Expr::Cast { expr, .. } + | Expr::TryCast { expr, .. } + | Expr::SafeCast { expr, .. } + | Expr::Extract { expr, .. } + | Expr::Ceil { expr, .. } + | Expr::Floor { expr, .. } => expr.visit(visitor), + Expr::Position { expr, r#in } => { + expr.visit(visitor)?; + r#in.visit(visitor) + } + Expr::Substring { + expr, + substring_from, + substring_for, + } => { + expr.visit(visitor)?; + substring_from.visit(visitor)?; + substring_for.visit(visitor) + } + Expr::Trim { + expr, + trim_where, + trim_what, + } => { + expr.visit(visitor)?; + trim_where.visit(visitor)?; + trim_what.visit(visitor) + } + Expr::Overlay { + expr, + overlay_what, + overlay_from, + overlay_for, + } => { + expr.visit(visitor)?; + overlay_what.visit(visitor)?; + overlay_from.visit(visitor)?; + overlay_for.visit(visitor) + } + Expr::Collate { expr, .. } => expr.visit(visitor), + Expr::Nested(e) => e.visit(visitor), + Expr::MapAccess { column, keys } => { + column.visit(visitor)?; + keys.visit(visitor) + } + Expr::Function(f) => f.visit(visitor), + Expr::AggregateExpressionWithFilter { expr, filter } => { + expr.visit(visitor)?; + filter.visit(visitor) + } + Expr::Case { + operand, + conditions, + results, + else_result, + } => { + operand.visit(visitor)?; + conditions.visit(visitor)?; + results.visit(visitor)?; + else_result.visit(visitor) + } + Expr::Exists { subquery, .. } => subquery.visit(visitor), + Expr::Subquery(query) => query.visit(visitor), + Expr::ArraySubquery(query) => query.visit(visitor), + Expr::ListAgg(list) => list.visit(visitor), + Expr::ArrayAgg(array) => array.visit(visitor), + Expr::GroupingSets(exprs) => exprs.visit(visitor), + Expr::Cube(exprs) => exprs.visit(visitor), + Expr::Rollup(exprs) => exprs.visit(visitor), + Expr::Tuple(exprs) => exprs.visit(visitor), + Expr::ArrayIndex { obj, indexes } => { + obj.visit(visitor)?; + indexes.visit(visitor) + } + Expr::Array(array) => array.visit(visitor), + Expr::Interval { value, .. } => value.visit(visitor), + } + } +} + /// A window specification (i.e. `OVER (PARTITION BY .. ORDER BY .. etc.)`) #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -899,6 +1049,14 @@ impl fmt::Display for WindowSpec { } } +impl Visit for WindowSpec { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.partition_by.visit(visitor)?; + self.order_by.visit(visitor)?; + self.window_frame.visit(visitor) + } +} + /// Specifies the data processed by a window function, e.g. /// `RANGE UNBOUNDED PRECEDING` or `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`. /// @@ -929,6 +1087,13 @@ impl Default for WindowFrame { } } +impl Visit for WindowFrame { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.start_bound.visit(visitor)?; + self.end_bound.visit(visitor) + } +} + #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum WindowFrameUnits { @@ -971,6 +1136,15 @@ impl fmt::Display for WindowFrameBound { } } +impl Visit for WindowFrameBound { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + WindowFrameBound::CurrentRow => ControlFlow::Continue(()), + WindowFrameBound::Preceding(e) | WindowFrameBound::Following(e) => e.visit(visitor), + } + } +} + #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum AddDropSync { @@ -2615,6 +2789,206 @@ impl fmt::Display for Statement { } } +impl Visit for Statement { + fn visit(&self, visitor: &mut V) -> ControlFlow { + visitor.visit_statement(self)?; + match self { + Statement::Close { .. } + | Statement::Discard { .. } + | Statement::SetRole { .. } + | Statement::SetNames { .. } + | Statement::SetNamesDefault { .. } + | Statement::ShowVariable { .. } + | Statement::Use { .. } + | Statement::StartTransaction { .. } + | Statement::SetTransaction { .. } + | Statement::Commit { .. } + | Statement::Rollback { .. } + | Statement::CreateSchema { .. } + | Statement::CreateDatabase { .. } + | Statement::Grant { .. } + | Statement::Revoke { .. } + | Statement::Deallocate { .. } + | Statement::Kill { .. } + | Statement::Savepoint { .. } => ControlFlow::Continue(()), + Statement::Msck { table_name, .. } | Statement::Copy { table_name, .. } => { + visitor.visit_table(table_name) + } + Statement::CreateVirtualTable { name, .. } => visitor.visit_table(name), + Statement::Analyze { + table_name, + partitions, + .. + } => { + visitor.visit_table(table_name)?; + partitions.visit(visitor) + } + Statement::Truncate { + table_name, + partitions, + .. + } => { + visitor.visit_table(table_name)?; + partitions.visit(visitor) + } + Statement::Query(query) => query.visit(visitor), + Statement::Insert { + table_name, + partitioned, + on, + returning, + .. + } => { + visitor.visit_table(table_name)?; + partitioned.visit(visitor)?; + on.visit(visitor)?; + returning.visit(visitor) + } + Statement::Directory { source, .. } => source.visit(visitor), + Statement::Update { + table, + assignments, + from, + selection, + returning, + } => { + table.visit(visitor)?; + assignments.visit(visitor)?; + from.visit(visitor)?; + selection.visit(visitor)?; + returning.visit(visitor) + } + Statement::Delete { + table_name, + using, + selection, + returning, + } => { + table_name.visit(visitor)?; + using.visit(visitor)?; + selection.visit(visitor)?; + returning.visit(visitor) + } + Statement::CreateView { name, query, .. } => { + visitor.visit_table(name)?; + query.visit(visitor) + } + Statement::CreateTable { + name, + columns, + constraints, + hive_distribution, + hive_formats, + query, + .. + } => { + visitor.visit_table(name)?; + columns.visit(visitor)?; + constraints.visit(visitor)?; + hive_distribution.visit(visitor)?; + hive_formats.visit(visitor)?; + query.visit(visitor) + } + Statement::CreateIndex { + table_name, + columns, + .. + } => { + visitor.visit_table(table_name)?; + columns.visit(visitor) + } + Statement::CreateRole { + connection_limit, .. + } => connection_limit.visit(visitor), + Statement::AlterTable { + name, operation, .. + } => { + visitor.visit_table(name)?; + operation.visit(visitor) + } + Statement::Drop { + object_type, names, .. + } => { + if matches!(object_type, ObjectType::Table | ObjectType::View) { + names + .iter() + .try_for_each(|name| visitor.visit_table(name))? + } + ControlFlow::Continue(()) + } + Statement::Declare { query, .. } => query.visit(visitor), + Statement::Fetch { into, .. } => { + if let Some(into) = into { + visitor.visit_table(into)? + } + ControlFlow::Continue(()) + } + Statement::SetVariable { value, .. } => value.visit(visitor), + Statement::SetTimeZone { value, .. } => value.visit(visitor), + Statement::ShowFunctions { filter, .. } + | Statement::ShowVariables { filter, .. } + | Statement::ShowTables { filter, .. } + | Statement::ShowCollation { filter, .. } => filter.visit(visitor), + Statement::ShowCreate { obj_type, obj_name } => { + if matches!(obj_type, ShowCreateObject::View | ShowCreateObject::Table) { + visitor.visit_table(obj_name)? + } + ControlFlow::Continue(()) + } + Statement::ShowColumns { + table_name, filter, .. + } => { + visitor.visit_table(table_name)?; + filter.visit(visitor) + } + Statement::Comment { + object_type, + object_name, + .. + } => { + if matches!(object_type, CommentObject::Table) { + visitor.visit_table(object_name)?; + } + ControlFlow::Continue(()) + } + Statement::CreateFunction { args, params, .. } => { + args.visit(visitor)?; + params.visit(visitor) + } + Statement::Assert { condition, message, .. } => { + condition.visit(visitor)?; + message.visit(visitor) + }, + Statement::Execute { parameters, .. } => parameters.visit(visitor), + Statement::Prepare { statement, .. } => statement.visit(visitor), + Statement::ExplainTable { table_name, .. } => visitor.visit_table(table_name), + Statement::Explain { statement, .. } => statement.visit(visitor), + Statement::Merge { + table, + source, + on, + clauses, + .. + } => { + table.visit(visitor)?; + source.visit(visitor)?; + on.visit(visitor)?; + clauses.visit(visitor) + } + Statement::Cache { + table_name, query, .. + } => { + visitor.visit_table(table_name)?; + query.visit(visitor) + } + Statement::UNCache { table_name, .. } => visitor.visit_table(table_name), + Statement::CreateSequence { + sequence_options, .. + } => sequence_options.visit(visitor), + } + } +} + /// Can use to describe options in create sequence or table column type identity /// [ INCREMENT [ BY ] increment ] /// [ MINVALUE minvalue | NO MINVALUE ] [ MAXVALUE maxvalue | NO MAXVALUE ] @@ -2681,6 +3055,16 @@ impl fmt::Display for SequenceOptions { } } +impl Visit for SequenceOptions { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + Self::IncrementBy(e, _) | Self::StartWith(e, _) | Self::Cache(e) => e.visit(visitor), + Self::MinValue(e) | Self::MaxValue(e) => e.visit(visitor), + Self::Cycle(_) => ControlFlow::Continue(()), + } + } +} + /// Can use to describe options in create sequence or table column type identity /// [ MINVALUE minvalue | NO MINVALUE ] [ MAXVALUE maxvalue | NO MAXVALUE ] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] @@ -2694,6 +3078,15 @@ pub enum MinMaxValue { Some(Expr), } +impl Visit for MinMaxValue { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + Self::Empty | Self::None => ControlFlow::Continue(()), + MinMaxValue::Some(e) => e.visit(visitor), + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[non_exhaustive] @@ -2704,28 +3097,6 @@ pub enum OnInsert { OnConflict(OnConflict), } -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct OnConflict { - pub conflict_target: Vec, - pub action: OnConflictAction, -} -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum OnConflictAction { - DoNothing, - DoUpdate(DoUpdate), -} - -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct DoUpdate { - /// Column assignments - pub assignments: Vec, - /// WHERE - pub selection: Option, -} - impl fmt::Display for OnInsert { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -2738,6 +3109,23 @@ impl fmt::Display for OnInsert { } } } + +impl Visit for OnInsert { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + OnInsert::DuplicateKeyUpdate(assignment) => assignment.visit(visitor), + OnInsert::OnConflict(c) => c.visit(visitor), + } + } +} + +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct OnConflict { + pub conflict_target: Vec, + pub action: OnConflictAction, +} + impl fmt::Display for OnConflict { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, " ON CONFLICT")?; @@ -2747,6 +3135,20 @@ impl fmt::Display for OnConflict { write!(f, " {}", self.action) } } + +impl Visit for OnConflict { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.action.visit(visitor) + } +} + +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum OnConflictAction { + DoNothing, + DoUpdate(DoUpdate), +} + impl fmt::Display for OnConflictAction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -2769,6 +3171,31 @@ impl fmt::Display for OnConflictAction { } } +impl Visit for OnConflictAction { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + Self::DoNothing => ControlFlow::Continue(()), + Self::DoUpdate(u) => u.visit(visitor), + } + } +} + +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct DoUpdate { + /// Column assignments + pub assignments: Vec, + /// WHERE + pub selection: Option, +} + +impl Visit for DoUpdate { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.assignments.visit(visitor)?; + self.selection.visit(visitor) + } +} + /// Privileges granted in a GRANT statement or revoked in a REVOKE statement. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -2977,6 +3404,12 @@ impl fmt::Display for Assignment { } } +impl Visit for Assignment { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.value.visit(visitor) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum FunctionArgExpr { @@ -2997,6 +3430,17 @@ impl fmt::Display for FunctionArgExpr { } } +impl Visit for FunctionArgExpr { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + FunctionArgExpr::Expr(e) => e.visit(visitor), + FunctionArgExpr::QualifiedWildcard(_) | FunctionArgExpr::Wildcard => { + ControlFlow::Continue(()) + } + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum FunctionArg { @@ -3013,6 +3457,14 @@ impl fmt::Display for FunctionArg { } } +impl Visit for FunctionArg { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + Self::Unnamed(arg) | Self::Named { arg, .. } => arg.visit(visitor), + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum CloseCursor { @@ -3043,6 +3495,13 @@ pub struct Function { pub special: bool, } +impl Visit for Function { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.args.visit(visitor)?; + self.over.visit(visitor) + } +} + #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum AnalyzeFormat { @@ -3149,6 +3608,15 @@ impl fmt::Display for ListAgg { } } +impl Visit for ListAgg { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.expr.visit(visitor)?; + self.separator.visit(visitor)?; + self.on_overflow.visit(visitor)?; + self.within_group.visit(visitor) + } +} + /// The `ON OVERFLOW` clause of a LISTAGG invocation #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -3184,6 +3652,15 @@ impl fmt::Display for ListAggOnOverflow { } } +impl Visit for ListAggOnOverflow { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + Self::Truncate { filler, .. } => filler.visit(visitor), + Self::Error => ControlFlow::Continue(()), + } + } +} + /// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] [ORDER BY ] [LIMIT ] )` /// Or `ARRAY_AGG( [ DISTINCT ] ) [ WITHIN GROUP ( ORDER BY ) ]` /// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake. @@ -3223,6 +3700,14 @@ impl fmt::Display for ArrayAgg { } } +impl Visit for ArrayAgg { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.expr.visit(visitor)?; + self.order_by.visit(visitor)?; + self.limit.visit(visitor) + } +} + #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum ObjectType { @@ -3286,6 +3771,20 @@ pub enum HiveDistributionStyle { NONE, } +impl Visit for HiveDistributionStyle { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + HiveDistributionStyle::PARTITIONED { columns } => columns.visit(visitor), + HiveDistributionStyle::CLUSTERED { sorted_by, .. } => sorted_by.visit(visitor), + HiveDistributionStyle::SKEWED { columns, on, .. } => { + columns.visit(visitor)?; + on.visit(visitor) + } + HiveDistributionStyle::NONE => ControlFlow::Continue(()), + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum HiveRowFormat { @@ -3307,6 +3806,21 @@ pub enum HiveIOFormat { }, } +impl Visit for HiveIOFormat { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + HiveIOFormat::IOF { + input_format, + output_format, + } => { + input_format.visit(visitor)?; + output_format.visit(visitor) + } + HiveIOFormat::FileFormat { .. } => ControlFlow::Continue(()), + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Default)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct HiveFormat { @@ -3315,6 +3829,12 @@ pub struct HiveFormat { pub location: Option, } +impl Visit for HiveFormat { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.storage.visit(visitor) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct SqlOption { @@ -3402,6 +3922,17 @@ impl fmt::Display for ShowStatementFilter { } } +impl Visit for ShowStatementFilter { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + ShowStatementFilter::Like(_) | ShowStatementFilter::ILike(_) => { + ControlFlow::Continue(()) + } + ShowStatementFilter::Where(e) => e.visit(visitor), + } + } +} + /// Sqlite specific syntax /// /// https://sqlite.org/lang_conflict.html @@ -3643,6 +4174,27 @@ impl fmt::Display for MergeClause { } } +impl Visit for MergeClause { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + MergeClause::MatchedUpdate { + predicate, + assignments, + } => { + predicate.visit(visitor)?; + assignments.visit(visitor) + } + MergeClause::MatchedDelete(e) => e.visit(visitor), + MergeClause::NotMatched { + predicate, values, .. + } => { + predicate.visit(visitor)?; + values.visit(visitor) + } + } + } +} + #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum DiscardObject { @@ -3739,6 +4291,12 @@ impl fmt::Display for CreateFunctionArg { } } +impl Visit for CreateFunctionArg { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.default_expr.visit(visitor) + } +} + /// The mode of an argument in CREATE FUNCTION. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -3816,6 +4374,12 @@ impl fmt::Display for CreateFunctionBody { } } +impl Visit for CreateFunctionBody { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.return_.visit(visitor) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum CreateFunctionUsing { diff --git a/src/ast/query.rs b/src/ast/query.rs index f813f44dd..89fa159dc 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -64,6 +64,18 @@ impl fmt::Display for Query { } } +impl Visit for Query { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.with.visit(visitor)?; + self.body.visit(visitor)?; + self.order_by.visit(visitor)?; + self.limit.visit(visitor)?; + self.offset.visit(visitor)?; + self.fetch.visit(visitor)?; + self.lock.visit(visitor) + } +} + /// A node in a tree, representing a "query body" expression, roughly: /// `SELECT ... [ {UNION|EXCEPT|INTERSECT} SELECT ...]` #[allow(clippy::large_enum_variant)] @@ -115,6 +127,29 @@ impl fmt::Display for SetExpr { } } +impl Visit for SetExpr { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + SetExpr::Select(s) => s.visit(visitor), + SetExpr::Query(q) => q.visit(visitor), + SetExpr::SetOperation { + op, + set_quantifier, + left, + right, + } => { + op.visit(visitor)?; + set_quantifier.visit(visitor)?; + left.visit(visitor)?; + right.visit(visitor) + } + SetExpr::Values(v) => v.visit(visitor), + SetExpr::Insert(s) => s.visit(visitor), + SetExpr::Table(t) => t.visit(visitor), + } + } +} + #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum SetOperator { @@ -133,6 +168,12 @@ impl fmt::Display for SetOperator { } } +impl Visit for SetOperator { + fn visit(&self, _visitor: &mut V) -> ControlFlow { + ControlFlow::Continue(()) + } +} + /// A quantifier for [SetOperator]. // TODO: Restrict parsing specific SetQuantifier in some specific dialects. // For example, BigQuery does not support `DISTINCT` for `EXCEPT` and `INTERSECT` @@ -154,6 +195,12 @@ impl fmt::Display for SetQuantifier { } } +impl Visit for SetQuantifier { + fn visit(&self, _visitor: &mut V) -> ControlFlow { + ControlFlow::Continue(()) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] /// A [`TABLE` command]( https://www.postgresql.org/docs/current/sql-select.html#SQL-TABLE) @@ -178,6 +225,12 @@ impl fmt::Display for Table { } } +impl Visit for Table { + fn visit(&self, _visitor: &mut V) -> ControlFlow { + ControlFlow::Continue(()) + } +} + /// A restricted variant of `SELECT` (without CTEs/`ORDER BY`), which may /// appear either as the only body item of a `Query`, or as an operand /// to a set operation like `UNION`. @@ -264,6 +317,23 @@ impl fmt::Display for Select { } } +impl Visit for Select { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.top.visit(visitor)?; + self.projection.visit(visitor)?; + self.into.visit(visitor)?; + self.from.visit(visitor)?; + self.lateral_views.visit(visitor)?; + self.selection.visit(visitor)?; + self.group_by.visit(visitor)?; + self.cluster_by.visit(visitor)?; + self.distribute_by.visit(visitor)?; + self.sort_by.visit(visitor)?; + self.having.visit(visitor)?; + self.qualify.visit(visitor) + } +} + /// A hive LATERAL VIEW with potential column aliases #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -298,6 +368,12 @@ impl fmt::Display for LateralView { } } +impl Visit for LateralView { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.lateral_view.visit(visitor) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct With { @@ -316,6 +392,12 @@ impl fmt::Display for With { } } +impl Visit for With { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.cte_tables.visit(visitor) + } +} + /// A single CTE (used after `WITH`): `alias [(col1, col2, ...)] AS ( query )` /// The names in the column list before `AS`, when specified, replace the names /// of the columns returned by the query. The parser does not validate that the @@ -338,6 +420,12 @@ impl fmt::Display for Cte { } } +impl Visit for Cte { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.query.visit(visitor) + } +} + /// One item of the comma-separated list following `SELECT` #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -352,6 +440,37 @@ pub enum SelectItem { Wildcard(WildcardAdditionalOptions), } +impl fmt::Display for SelectItem { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match &self { + SelectItem::UnnamedExpr(expr) => write!(f, "{}", expr), + SelectItem::ExprWithAlias { expr, alias } => write!(f, "{} AS {}", expr, alias), + SelectItem::QualifiedWildcard(prefix, additional_options) => { + write!(f, "{}.*", prefix)?; + write!(f, "{additional_options}")?; + Ok(()) + } + SelectItem::Wildcard(additional_options) => { + write!(f, "*")?; + write!(f, "{additional_options}")?; + Ok(()) + } + } + } +} + +impl Visit for SelectItem { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + SelectItem::UnnamedExpr(e) => e.visit(visitor), + SelectItem::ExprWithAlias { expr, .. } => expr.visit(visitor), + SelectItem::QualifiedWildcard(_, options) | SelectItem::Wildcard(options) => { + options.visit(visitor) + } + } + } +} + /// Additional options for wildcards, e.g. Snowflake `EXCLUDE` and Bigquery `EXCEPT`. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Default)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -374,6 +493,13 @@ impl fmt::Display for WildcardAdditionalOptions { } } +impl Visit for WildcardAdditionalOptions { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.opt_exclude.visit(visitor)?; + self.opt_except.visit(visitor) + } +} + /// Snowflake `EXCLUDE` information. /// /// # Syntax @@ -414,6 +540,12 @@ impl fmt::Display for ExcludeSelectItem { } } +impl Visit for ExcludeSelectItem { + fn visit(&self, _visitor: &mut V) -> ControlFlow { + ControlFlow::Continue(()) + } +} + /// Bigquery `EXCEPT` information, with at least one column. /// /// # Syntax @@ -446,22 +578,9 @@ impl fmt::Display for ExceptSelectItem { } } -impl fmt::Display for SelectItem { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match &self { - SelectItem::UnnamedExpr(expr) => write!(f, "{}", expr), - SelectItem::ExprWithAlias { expr, alias } => write!(f, "{} AS {}", expr, alias), - SelectItem::QualifiedWildcard(prefix, additional_options) => { - write!(f, "{}.*", prefix)?; - write!(f, "{additional_options}")?; - Ok(()) - } - SelectItem::Wildcard(additional_options) => { - write!(f, "*")?; - write!(f, "{additional_options}")?; - Ok(()) - } - } +impl Visit for ExceptSelectItem { + fn visit(&self, _visitor: &mut V) -> ControlFlow { + ControlFlow::Continue(()) } } @@ -482,6 +601,13 @@ impl fmt::Display for TableWithJoins { } } +impl Visit for TableWithJoins { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.relation.visit(visitor)?; + self.joins.visit(visitor) + } +} + /// A table name or a parenthesized subquery with an optional alias #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -610,6 +736,52 @@ impl fmt::Display for TableFactor { } } +impl Visit for TableFactor { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + Self::Table { + name, + alias, + args, + with_hints, + } => { + visitor.visit_table(name)?; + alias.visit(visitor)?; + args.visit(visitor)?; + with_hints.visit(visitor) + } + Self::Derived { + lateral: _, + subquery, + alias, + } => { + subquery.visit(visitor)?; + alias.visit(visitor) + } + Self::TableFunction { expr, alias } => { + expr.visit(visitor)?; + alias.visit(visitor) + } + Self::UNNEST { + alias, + array_expr, + with_offset: _, + with_offset_alias: _, + } => { + alias.visit(visitor)?; + array_expr.visit(visitor) + } + Self::NestedJoin { + table_with_joins, + alias, + } => { + table_with_joins.visit(visitor)?; + alias.visit(visitor) + } + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct TableAlias { @@ -627,6 +799,12 @@ impl fmt::Display for TableAlias { } } +impl Visit for TableAlias { + fn visit(&self, _visitor: &mut V) -> ControlFlow { + ControlFlow::Continue(()) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Join { @@ -721,6 +899,13 @@ impl fmt::Display for Join { } } +impl Visit for Join { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.relation.visit(visitor)?; + self.join_operator.visit(visitor) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum JoinOperator { @@ -743,6 +928,24 @@ pub enum JoinOperator { OuterApply, } +impl Visit for JoinOperator { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + JoinOperator::Inner(c) + | JoinOperator::LeftOuter(c) + | JoinOperator::RightOuter(c) + | JoinOperator::FullOuter(c) + | JoinOperator::LeftSemi(c) + | JoinOperator::RightSemi(c) + | JoinOperator::LeftAnti(c) + | JoinOperator::RightAnti(c) => c.visit(visitor), + JoinOperator::CrossJoin | JoinOperator::CrossApply | JoinOperator::OuterApply => { + ControlFlow::Continue(()) + } + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum JoinConstraint { @@ -752,6 +955,17 @@ pub enum JoinConstraint { None, } +impl Visit for JoinConstraint { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + JoinConstraint::On(e) => e.visit(visitor), + JoinConstraint::Using(_) | JoinConstraint::Natural | JoinConstraint::None => { + ControlFlow::Continue(()) + } + } + } +} + /// An `ORDER BY` expression #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -780,6 +994,12 @@ impl fmt::Display for OrderByExpr { } } +impl Visit for OrderByExpr { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.expr.visit(visitor) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Offset { @@ -793,6 +1013,13 @@ impl fmt::Display for Offset { } } +impl Visit for Offset { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.value.visit(visitor)?; + self.rows.visit(visitor) + } +} + /// Stores the keyword after `OFFSET ` #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -813,6 +1040,12 @@ impl fmt::Display for OffsetRows { } } +impl Visit for OffsetRows { + fn visit(&self, _visitor: &mut V) -> ControlFlow { + ControlFlow::Continue(()) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Fetch { @@ -833,6 +1066,12 @@ impl fmt::Display for Fetch { } } +impl Visit for Fetch { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.quantity.visit(visitor) + } +} + #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum LockType { @@ -850,6 +1089,12 @@ impl fmt::Display for LockType { } } +impl Visit for LockType { + fn visit(&self, _visitor: &mut V) -> ControlFlow { + ControlFlow::Continue(()) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Top { @@ -871,6 +1116,12 @@ impl fmt::Display for Top { } } +impl Visit for Top { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.quantity.visit(visitor) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Values { @@ -894,6 +1145,12 @@ impl fmt::Display for Values { } } +impl Visit for Values { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.rows.visit(visitor) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct SelectInto { @@ -912,3 +1169,9 @@ impl fmt::Display for SelectInto { write!(f, "INTO{}{}{} {}", temporary, unlogged, table, self.name) } } + +impl Visit for SelectInto { + fn visit(&self, _visitor: &mut V) -> ControlFlow { + ControlFlow::Continue(()) + } +} diff --git a/src/ast/value.rs b/src/ast/value.rs index 9a356e8bf..fa38f597b 100644 --- a/src/ast/value.rs +++ b/src/ast/value.rs @@ -13,11 +13,13 @@ #[cfg(not(feature = "std"))] use alloc::string::String; use core::fmt; +use core::ops::ControlFlow; #[cfg(feature = "bigdecimal")] use bigdecimal::BigDecimal; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use crate::ast::{Visit, Visitor}; /// Primitive SQL values such as number and string #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] @@ -214,3 +216,9 @@ impl fmt::Display for TrimWhereField { }) } } + +impl Visit for TrimWhereField { + fn visit(&self, _visitor: &mut V) -> ControlFlow { + ControlFlow::Continue(()) + } +} diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs new file mode 100644 index 000000000..5a531ce9e --- /dev/null +++ b/src/ast/visitor.rs @@ -0,0 +1,126 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::ast::{Expr, ObjectName, Statement}; +use core::ops::ControlFlow; + +/// A type that can be visited by a `visitor` +pub trait Visit { + fn visit(&self, visitor: &mut V) -> ControlFlow; +} + +impl Visit for Option { + fn visit(&self, visitor: &mut V) -> ControlFlow { + if let Some(s) = self { + s.visit(visitor)?; + } + ControlFlow::Continue(()) + } +} + +impl Visit for Vec { + fn visit(&self, visitor: &mut V) -> ControlFlow { + for v in self { + v.visit(visitor)?; + } + ControlFlow::Continue(()) + } +} + +impl Visit for Box { + fn visit(&self, visitor: &mut V) -> ControlFlow { + T::visit(self, visitor) + } +} + +/// A visitor that can be used to walk an AST tree +pub trait Visitor { + type Break; + + /// Invoked for any tables, virtual or otherwise that appear in the AST + fn visit_table(&mut self, _table: &ObjectName) -> ControlFlow { + ControlFlow::Continue(()) + } + + /// Invoked for any expressions that appear in the AST + fn visit_expr(&mut self, _expr: &Expr) -> ControlFlow { + ControlFlow::Continue(()) + } + + /// Invoked for any statements that appear in the AST + fn visit_statement(&mut self, _statement: &Statement) -> ControlFlow { + ControlFlow::Continue(()) + } +} + +struct TableVisitor(F); + +impl ControlFlow> Visitor for TableVisitor { + type Break = E; + + fn visit_table(&mut self, table: &ObjectName) -> ControlFlow { + self.0(table) + } +} + +/// Invokes the provided closure on all tables present in v +pub fn visit_tables(v: &V, f: F) -> ControlFlow +where + V: Visit, + F: FnMut(&ObjectName) -> ControlFlow, +{ + let mut visitor = TableVisitor(f); + v.visit(&mut visitor)?; + ControlFlow::Continue(()) +} + +struct ExprVisitor(F); + +impl ControlFlow> Visitor for ExprVisitor { + type Break = E; + + fn visit_expr(&mut self, expr: &Expr) -> ControlFlow { + self.0(expr) + } +} + +/// Invokes the provided closure on all expressions present in v +pub fn visit_expressions(v: &V, f: F) -> ControlFlow +where + V: Visit, + F: FnMut(&Expr) -> ControlFlow, +{ + let mut visitor = ExprVisitor(f); + v.visit(&mut visitor)?; + ControlFlow::Continue(()) +} + +struct StatementVisitor(F); + +impl ControlFlow> Visitor for StatementVisitor { + type Break = E; + + fn visit_statement(&mut self, statement: &Statement) -> ControlFlow { + self.0(statement) + } +} + +/// Invokes the provided closure on all statements present in v +pub fn visit_statements(v: &V, f: F) -> ControlFlow +where + V: Visit, + F: FnMut(&Statement) -> ControlFlow, +{ + let mut visitor = StatementVisitor(f); + v.visit(&mut visitor)?; + ControlFlow::Continue(()) +}