
Commit dcd2b1a

davidm-db authored and cloud-fan committed
[SPARK-52134] Move execution logic to SqlScriptingExecution and enable Spark Connect path
### What changes were proposed in this pull request?

#### Spark Connect execution

Move the script execution from `SparkSession#sql` to `QueryExecution#lazyAnalyzed`. This lets `QueryExecution` receive the original parsed logical plan for scripting, which is used to detect script execution in Spark Connect and treat it as a command.

#### executeSqlScript refactor

Move the `executeSqlScript` logic from `SparkSession` to the `SqlScriptingExecution` object.

### Why are the changes needed?

SQL Scripting improvements.

### Does this PR introduce _any_ user-facing change?

No. The PR does enable new functionality (execution through Spark Connect), but the results remain the same.

### How was this patch tested?

Existing tests confirm that the refactoring of the execution logic does not change behavior. A test was added to confirm that execution through Spark Connect does not fail.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#50895 from davidm-db/execute_sql_script_refactor.

Lead-authored-by: David Milicevic <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent f087cc4 commit dcd2b1a

6 files changed: +164 −96 lines
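Five of the six changed files are reproduced below; the hunk that adds `executeSqlScript` to the `SqlScriptingExecution` object in sql/core is not included in this excerpt. As a rough guide only, here is a sketch of what the relocated helper plausibly looks like, reconstructed from the method removed from `SparkSession` (shown further down) and the call sites added in `QueryExecution#lazyAnalyzed`. It is an assumption, not the committed code, and is meant to sit in the same file as the existing `class SqlScriptingExecution`:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.types.StructType

// Sketch only: shape inferred from the removed SparkSession#executeSqlScript and the
// SqlScriptingExecution.executeSqlScript(...) calls added in QueryExecution.
object SqlScriptingExecution {

  /** Runs `script` statement by statement and returns the last result as a local relation. */
  def executeSqlScript(
      session: SparkSession,
      script: CompoundBody,
      args: Map[String, Expression] = Map.empty): LogicalPlan = {
    val sse = new SqlScriptingExecution(script, session, args)
    sse.withLocalVariableManager {
      var result: Option[Seq[Row]] = None
      var resultSchema: Option[StructType] = None

      // getNextResult advances the script, so each returned DataFrame must be collected
      // immediately to preserve the statement execution order.
      var df = sse.getNextResult
      while (df.isDefined) {
        sse.withErrorHandling {
          result = Some(df.get.collect().toSeq)
          resultSchema = Some(df.get.schema)
        }
        df = sse.getNextResult
      }

      result match {
        // The script produced no result rows: return an empty relation.
        case None => LocalRelation(Seq.empty[Attribute])
        // Otherwise wrap the collected rows in a plan that QueryExecution can analyze.
        case Some(rows) =>
          LocalRelation.fromExternalRows(DataTypeUtils.toAttributes(resultSchema.get), rows)
      }
    }
  }
}
```

Returning a `LogicalPlan` (rather than a `DataFrame`, as the removed `SparkSession` method did) matches the new call sites, which feed the result straight into the analyzer.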

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 4 additions & 3 deletions
@@ -54,7 +54,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TimeModes, TransformWithState, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateEventTimeWatermarkColumn, UpdateStarAction}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, CompoundBody, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TimeModes, TransformWithState, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateEventTimeWatermarkColumn, UpdateStarAction}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -2795,8 +2795,9 @@ class SparkConnectPlanner(
          s"SQL command expects either a SQL or a WithRelations, but got $other")
    }

-   // Check if commands have been executed.
+   // Check if command or SQL Script has been executed.
    val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
+   val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody]
    val rows = df.logicalPlan match {
      case lr: LocalRelation => lr.data
      case cr: CommandResult => cr.rows
@@ -2808,7 +2809,7 @@ class SparkConnectPlanner(
    val result = SqlCommandResult.newBuilder()
    // Only filled when isCommand
    val metrics = ExecutePlanResponse.Metrics.newBuilder()
-   if (isCommand) {
+   if (isCommand || isSqlScript) {
      // Convert the results to Arrow.
      val schema = df.schema
      val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala

Lines changed: 57 additions & 0 deletions
@@ -16,16 +16,19 @@
  */
 package org.apache.spark.sql.connect

+import java.io.ByteArrayInputStream
 import java.util.{TimeZone, UUID}

 import scala.reflect.runtime.universe.TypeTag

 import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.scalatest.concurrent.{Eventually, TimeLimits}
 import org.scalatest.time.Span
 import org.scalatest.time.SpanSugar._

 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.ExecutePlanResponse
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient, SparkConnectStubState}
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
@@ -143,6 +146,21 @@ trait SparkConnectServerTest extends SharedSparkSession {
     proto.Plan.newBuilder().setRoot(dsl.sql(query)).build()
   }

+  protected def buildSqlCommandPlan(sqlCommand: String) = {
+    proto.Plan
+      .newBuilder()
+      .setCommand(
+        proto.Command
+          .newBuilder()
+          .setSqlCommand(
+            proto.SqlCommand
+              .newBuilder()
+              .setSql(sqlCommand)
+              .build())
+          .build())
+      .build()
+  }
+
   protected def buildLocalRelation[A <: Product: TypeTag](data: Seq[A]) = {
     val encoder = ScalaReflection.encoderFor[A]
     val arrowData =
@@ -305,4 +323,43 @@ trait SparkConnectServerTest extends SharedSparkSession {
     val plan = buildPlan(query)
     runQuery(plan, queryTimeout, iterSleep)
   }
+
+  protected def checkSqlCommandResponse(
+      result: ExecutePlanResponse.SqlCommandResult,
+      expected: Seq[Seq[Any]]): Unit = {
+    // Extract the serialized Arrow data as a byte array.
+    val dataBytes = result.getRelation.getLocalRelation.getData.toByteArray
+
+    // Create an ArrowStreamReader to deserialize the data.
+    val allocator = new RootAllocator(Long.MaxValue)
+    val inputStream = new ByteArrayInputStream(dataBytes)
+    val reader = new ArrowStreamReader(inputStream, allocator)
+
+    try {
+      // Read the schema and data.
+      val root = reader.getVectorSchemaRoot
+      // Load the first batch of data.
+      reader.loadNextBatch()
+
+      // Get dimensions.
+      val rowCount = root.getRowCount
+      val colCount = root.getFieldVectors.size
+      assert(rowCount == expected.length, "Row count mismatch")
+      assert(colCount == expected.head.length, "Column count mismatch")
+
+      // Compare to expected.
+      for (i <- 0 until rowCount) {
+        for (j <- 0 until colCount) {
+          val col = root.getFieldVectors.get(j)
+          val value = col.getObject(i)
+          print(value)
+          assert(value == expected(i)(j), s"Value mismatch at ($i, $j)")
+        }
+      }
+    } finally {
+      // Clean up resources.
+      reader.close()
+      allocator.close()
+    }
+  }
 }

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala

Lines changed: 21 additions & 0 deletions
@@ -33,6 +33,27 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
   // were all already in the buffer.
   val BIG_ENOUGH_QUERY = "select * from range(1000000)"

+  test("SQL Script over Spark Connect.") {
+    val sessionId = UUID.randomUUID.toString()
+    val userId = "ScriptUser"
+    val sqlScriptText =
+      """BEGIN
+        |IF 1 = 1 THEN
+        |  SELECT 1;
+        |ELSE
+        |  SELECT 2;
+        |END IF;
+        |END
+        """.stripMargin
+    withClient(sessionId = sessionId, userId = userId) { client =>
+      // this will create the session, and then ReleaseSession at the end of withClient.
+      val enableSqlScripting = client.execute(buildPlan("SET spark.sql.scripting.enabled=true"))
+      enableSqlScripting.hasNext // trigger execution
+      val query = client.execute(buildSqlCommandPlan(sqlScriptText))
+      checkSqlCommandResponse(query.next().getSqlCommandResult, Seq(Seq(1)))
+    }
+  }
+
   test("Execute is sent eagerly to the server upon iterator creation") {
     // This behavior changed with grpc upgrade from 1.56.0 to 1.59.0.
     // Testing to be aware of future changes.

sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala

Lines changed: 16 additions & 88 deletions
@@ -42,10 +42,9 @@ import org.apache.spark.sql.artifact.ArtifactManager
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.parser.ParserInterface
-import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, LogicalPlan, Range}
-import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, Range}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.classic.SparkSession.applyAndLoadExtensions
@@ -56,7 +55,6 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.internal._
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
-import org.apache.spark.sql.scripting.SqlScriptingExecution
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.ExecutionListenerManager
@@ -432,50 +430,6 @@ class SparkSession private(
    | Everything else |
    * ----------------- */

-  /**
-   * Executes given script and return the result of the last statement.
-   * If script contains no queries, an empty `DataFrame` is returned.
-   *
-   * @param script A SQL script to execute.
-   * @param args A map of parameter names to SQL literal expressions.
-   *
-   * @return The result as a `DataFrame`.
-   */
-  private def executeSqlScript(
-      script: CompoundBody,
-      args: Map[String, Expression] = Map.empty): DataFrame = {
-    val sse = new SqlScriptingExecution(script, this, args)
-    sse.withLocalVariableManager {
-      var result: Option[Seq[Row]] = None
-
-      // We must execute returned df before calling sse.getNextResult again because sse.hasNext
-      // advances the script execution and executes all statements until the next result. We must
-      // collect results immediately to maintain execution order.
-      // This ensures we respect the contract of SqlScriptingExecution API.
-      var df: Option[DataFrame] = sse.getNextResult
-      var resultSchema: Option[StructType] = None
-      while (df.isDefined) {
-        sse.withErrorHandling {
-          // Collect results from the current DataFrame.
-          result = Some(df.get.collect().toSeq)
-          resultSchema = Some(df.get.schema)
-        }
-        df = sse.getNextResult
-      }
-
-      if (result.isEmpty) {
-        emptyDataFrame
-      } else {
-        // If `result` is defined, then `resultSchema` must be defined as well.
-        assert(resultSchema.isDefined)
-
-        val attributes = DataTypeUtils.toAttributes(resultSchema.get)
-        Dataset.ofRows(
-          self, LocalRelation.fromExternalRows(attributes, result.get))
-      }
-    }
-  }
-
   /**
    * Executes a SQL query substituting positional parameters by the given arguments,
    * returning the result as a `DataFrame`.
@@ -495,30 +449,17 @@ class SparkSession private(
    withActive {
      val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
        val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
-       parsedPlan match {
-         case compoundBody: CompoundBody =>
-           if (args.nonEmpty) {
-             // Positional parameters are not supported for SQL scripting.
-             throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
-           }
-           compoundBody
-         case logicalPlan: LogicalPlan =>
-           if (args.nonEmpty) {
-             PosParameterizedQuery(logicalPlan, args.map(lit(_).expr).toImmutableArraySeq)
-           } else {
-             logicalPlan
-           }
+       if (args.nonEmpty) {
+         if (parsedPlan.isInstanceOf[CompoundBody]) {
+           // Positional parameters are not supported for SQL scripting.
+           throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
+         }
+         PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq)
+       } else {
+         parsedPlan
        }
      }
-
-     plan match {
-       case compoundBody: CompoundBody =>
-         // Execute the SQL script.
-         executeSqlScript(compoundBody)
-       case logicalPlan: LogicalPlan =>
-         // Execute the standalone SQL statement.
-         Dataset.ofRows(self, plan, tracker)
-     }
+     Dataset.ofRows(self, plan, tracker)
    }

   /** @inheritdoc */
@@ -549,26 +490,13 @@ class SparkSession private(
    withActive {
      val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
        val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
-       parsedPlan match {
-         case compoundBody: CompoundBody =>
-           compoundBody
-         case logicalPlan: LogicalPlan =>
-           if (args.nonEmpty) {
-             NameParameterizedQuery(logicalPlan, args.transform((_, v) => lit(v).expr))
-           } else {
-             logicalPlan
-           }
+       if (args.nonEmpty) {
+         NameParameterizedQuery(parsedPlan, args.transform((_, v) => lit(v).expr))
+       } else {
+         parsedPlan
        }
      }
-
-     plan match {
-       case compoundBody: CompoundBody =>
-         // Execute the SQL script.
-         executeSqlScript(compoundBody, args.transform((_, v) => lit(v).expr))
-       case logicalPlan: LogicalPlan =>
-         // Execute the standalone SQL statement.
-         Dataset.ofRows(self, plan, tracker)
-     }
+     Dataset.ofRows(self, plan, tracker)
    }

   /** @inheritdoc */

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 15 additions & 4 deletions
@@ -31,10 +31,10 @@ import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row}
 import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{LazyExpression, UnsupportedOperationChecker}
+import org.apache.spark.sql.catalyst.analysis.{LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker}
 import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union}
 import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
 import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
@@ -46,6 +46,7 @@ import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, WatermarkPropagator}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.scripting.SqlScriptingExecution
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.{LazyTry, Utils}
 import org.apache.spark.util.ArrayImplicits._
@@ -93,16 +94,26 @@ class QueryExecution(
   }

   private val lazyAnalyzed = LazyTry {
+    val withScriptExecuted = logical match {
+      // Execute the SQL script. Script doesn't need to go through the analyzer as Spark will run
+      // each statement as individual query.
+      case NameParameterizedQuery(compoundBody: CompoundBody, argNames, argValues) =>
+        val args = argNames.zip(argValues).toMap
+        SqlScriptingExecution.executeSqlScript(sparkSession, compoundBody, args)
+      case compoundBody: CompoundBody =>
+        SqlScriptingExecution.executeSqlScript(sparkSession, compoundBody)
+      case _ => logical
+    }
     try {
       val plan = executePhase(QueryPlanningTracker.ANALYSIS) {
         // We can't clone `logical` here, which will reset the `_analyzed` flag.
-        sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker)
+        sparkSession.sessionState.analyzer.executeAndCheck(withScriptExecuted, tracker)
       }
       tracker.setAnalyzed(plan)
       plan
     } catch {
       case NonFatal(e) =>
-        tracker.setAnalysisFailed(logical)
+        tracker.setAnalysisFailed(withScriptExecuted)
         throw e
     }
   }
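For reference, a minimal end-to-end usage sketch under stated assumptions: a local classic session is used, the config key and script text mirror the E2E test above, and the `.remote(...)` comment stands in for the Spark Connect path this change enables. Names like `SqlScriptOverConnectExample` are illustrative only.

```scala
import org.apache.spark.sql.SparkSession

object SqlScriptOverConnectExample {
  def main(args: Array[String]): Unit = {
    // For the Spark Connect path, build with .remote("sc://localhost") instead of .master(...).
    val spark = SparkSession.builder()
      .master("local[*]")
      .getOrCreate()

    // Same switch the E2E test flips before running the script.
    spark.sql("SET spark.sql.scripting.enabled=true")

    // The script is detected in QueryExecution#lazyAnalyzed and executed statement by
    // statement; only the result of the last statement comes back as a DataFrame.
    val result = spark.sql(
      """BEGIN
        |  IF 1 = 1 THEN
        |    SELECT 1;
        |  ELSE
        |    SELECT 2;
        |  END IF;
        |END""".stripMargin)

    result.show() // expected: a single row containing 1

    spark.stop()
  }
}
```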
