Skip to content

Commit 594c29b

Browse files
committed
[SPARK-52494] Support colon-sign opeorator syntax to access Variant fields.
1 parent 7b49919 commit 594c29b

File tree

11 files changed

+290
-6
lines changed

11 files changed

+290
-6
lines changed

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,7 @@ primaryExpression
12081208
| constant #constantDefault
12091209
| ASTERISK exceptClause? #star
12101210
| qualifiedName DOT ASTERISK exceptClause? #star
1211+
| col=primaryExpression COLON path=semiStructuredExtractionPath #semiStructuredExtract
12111212
| LEFT_PAREN namedExpression (COMMA namedExpression)+ RIGHT_PAREN #rowConstructor
12121213
| LEFT_PAREN query RIGHT_PAREN #subqueryExpression
12131214
| functionName LEFT_PAREN (setQuantifier? argument+=functionArgument
@@ -1230,6 +1231,35 @@ primaryExpression
12301231
FROM position=valueExpression (FOR length=valueExpression)? RIGHT_PAREN #overlay
12311232
;
12321233

1234+
semiStructuredExtractionPath
1235+
: jsonPathFirstPart (jsonPathParts)*
1236+
;
1237+
1238+
jsonPathIdentifier
1239+
: identifier
1240+
| BACKQUOTED_IDENTIFIER
1241+
;
1242+
1243+
jsonPathBracketedIdentifier
1244+
: LEFT_BRACKET stringLit RIGHT_BRACKET
1245+
;
1246+
1247+
jsonPathFirstPart
1248+
: jsonPathIdentifier
1249+
| jsonPathBracketedIdentifier
1250+
| DOT
1251+
| LEFT_BRACKET INTEGER_VALUE RIGHT_BRACKET
1252+
| LEFT_BRACKET ASTERISK RIGHT_BRACKET
1253+
;
1254+
1255+
jsonPathParts
1256+
: DOT jsonPathIdentifier
1257+
| jsonPathBracketedIdentifier
1258+
| LEFT_BRACKET INTEGER_VALUE RIGHT_BRACKET
1259+
| LEFT_BRACKET ASTERISK RIGHT_BRACKET
1260+
| LEFT_BRACKET identifier RIGHT_BRACKET
1261+
;
1262+
12331263
literalType
12341264
: DATE
12351265
| TIME
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.sql.AnalysisException
21+
import org.apache.spark.sql.catalyst.expressions.variant.VariantGet
22+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
23+
import org.apache.spark.sql.catalyst.rules.Rule
24+
import org.apache.spark.sql.catalyst.trees.TreePattern.{SEMI_STRUCTURED_EXTRACT, TreePattern}
25+
import org.apache.spark.sql.types.{DataType, StringType, VariantType}
26+
import org.apache.spark.unsafe.types.UTF8String
27+
28+
/**
29+
* Represents the extraction of data from a field that contains semi-structured data. The
30+
* semi-structured format can be anything (JSON, key-value delimited, etc), and that information
31+
* comes from the child expression's column metadata.
32+
* @param child The semi-structured column
33+
* @param field The field to extract
34+
*/
35+
case class SemiStructuredExtract(
36+
child: Expression, field: String) extends UnaryExpression with Unevaluable {
37+
override lazy val resolved = false
38+
override def dataType: DataType = StringType
39+
40+
final override val nodePatterns: Seq[TreePattern] = Seq(SEMI_STRUCTURED_EXTRACT)
41+
42+
override protected def withNewChildInternal(newChild: Expression): SemiStructuredExtract =
43+
copy(child = newChild)
44+
}
45+
46+
/**
47+
* Replaces SemiStructuredExtract expressions by extracting the specified field from the
48+
* semi-structured column (only VariantType is supported for now).
49+
*/
50+
case object ExtractSemiStructuredFields extends Rule[LogicalPlan] {
51+
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning(
52+
_.containsPattern(SEMI_STRUCTURED_EXTRACT), ruleId) {
53+
case SemiStructuredExtract(column, field) if column.resolved =>
54+
if (column.dataType.isInstanceOf[VariantType]) {
55+
VariantGet(column, Literal(UTF8String.fromString(field)), VariantType, failOnError = true)
56+
} else {
57+
throw new AnalysisException(
58+
errorClass = "COLUMN_IS_NOT_VARIANT_TYPE", messageParameters = Map.empty)
59+
}
60+
}
61+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType, Str
3333
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
3434
import org.apache.spark.util.Utils
3535

36-
private[this] sealed trait PathInstruction
37-
private[this] object PathInstruction {
36+
sealed trait PathInstruction
37+
object PathInstruction {
3838
private[expressions] case object Subscript extends PathInstruction
3939
private[expressions] case object Wildcard extends PathInstruction
4040
private[expressions] case object Key extends PathInstruction
4141
private[expressions] case class Index(index: Long) extends PathInstruction
42-
private[expressions] case class Named(name: String) extends PathInstruction
42+
case class Named(name: String) extends PathInstruction
4343
}
4444

4545
private[this] sealed trait WriteStyle
@@ -49,7 +49,7 @@ private[this] object WriteStyle {
4949
private[expressions] case object FlattenStyle extends WriteStyle
5050
}
5151

52-
private[this] object JsonPathParser extends RegexParsers {
52+
object JsonPathParser extends RegexParsers {
5353
import PathInstruction._
5454

5555
def root: Parser[Char] = '$'

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS
3636
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, ClusterBySpec}
3737
import org.apache.spark.sql.catalyst.expressions._
3838
import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyValue, First, Last}
39+
import org.apache.spark.sql.catalyst.expressions.json.JsonPathParser
40+
import org.apache.spark.sql.catalyst.expressions.json.PathInstruction.Named
3941
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
4042
import org.apache.spark.sql.catalyst.plans._
4143
import org.apache.spark.sql.catalyst.plans.logical._
@@ -3322,6 +3324,24 @@ class AstBuilder extends DataTypeAstBuilder
33223324
}
33233325
}
33243326

3327+
/**
3328+
* Create a [[SemiStructuredExtract]] expression.
3329+
*/
3330+
override def visitSemiStructuredExtract(
3331+
ctx: SemiStructuredExtractContext): Expression = withOrigin(ctx) {
3332+
val field = ctx.path.getText
3333+
// When `field` starts with a bracket, do not add a `.` as the bracket already implies nesting
3334+
// Also the bracket will imply case sensitive field extraction.
3335+
val path = if (field.startsWith("[")) "$" + field else s"$$.$field"
3336+
val parsedPath = JsonPathParser.parse(path)
3337+
if (parsedPath.isEmpty) {
3338+
throw new ParseException(errorClass = "PARSE_SYNTAX_ERROR", ctx = ctx)
3339+
}
3340+
val potentialAlias = parsedPath.get.collect { case Named(name) => name }.lastOption
3341+
val node = SemiStructuredExtract(expression(ctx.col), path)
3342+
potentialAlias.map { colName => Alias(node, colName)() }.getOrElse(node)
3343+
}
3344+
33253345
/**
33263346
* Create an [[UnresolvedAttribute]] expression or a [[UnresolvedRegex]] if it is a regex
33273347
* quoted in ``

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ object RuleIdCollection {
113113
"org.apache.spark.sql.catalyst.expressions.ValidateAndStripPipeExpressions" ::
114114
"org.apache.spark.sql.catalyst.analysis.ResolveUnresolvedHaving" ::
115115
"org.apache.spark.sql.catalyst.analysis.ResolveTableConstraints" ::
116+
"org.apache.spark.sql.catalyst.expressions.ExtractSemiStructuredFields" ::
116117
// Catalyst Optimizer rules
117118
"org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
118119
"org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ object TreePattern extends Enumeration {
8080
val REGEXP_EXTRACT_FAMILY: Value = Value
8181
val REGEXP_REPLACE: Value = Value
8282
val RUNTIME_REPLACEABLE: Value = Value
83+
val SEMI_STRUCTURED_EXTRACT: Value = Value
8384
val SCALAR_SUBQUERY: Value = Value
8485
val SCALAR_SUBQUERY_REFERENCE: Value = Value
8586
val SCALA_UDF: Value = Value

sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.apache.spark.sql.artifact.ArtifactManager
2222
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveDataSource, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry}
2323
import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension
2424
import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
25-
import org.apache.spark.sql.catalyst.expressions.Expression
25+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExtractSemiStructuredFields}
2626
import org.apache.spark.sql.catalyst.optimizer.Optimizer
2727
import org.apache.spark.sql.catalyst.parser.ParserInterface
2828
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -244,6 +244,7 @@ abstract class BaseSessionStateBuilder(
244244
new EvalSubqueriesForTimeTravel +:
245245
new ResolveTranspose(session) +:
246246
new InvokeProcedures(session) +:
247+
ExtractSemiStructuredFields +:
247248
customResolutionRules
248249

249250
override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
-- Automatically generated by SQLQueryTestSuite
2+
-- !query
3+
select parse_json('{ "price": 5 }'):price
4+
-- !query analysis
5+
Project [variant_get(parse_json({ "price": 5 }, true), $.price, VariantType, true, Some(America/Los_Angeles)) AS price#x]
6+
+- OneRowRelation
7+
8+
9+
-- !query
10+
select parse_json('{ "price": 30 }'):price::decimal(5, 2)
11+
-- !query analysis
12+
Project [cast(variant_get(parse_json({ "price": 30 }, true), $.price, VariantType, true, Some(America/Los_Angeles)) as decimal(5,2)) AS price#x]
13+
+- OneRowRelation
14+
15+
16+
-- !query
17+
select parse_json('{ "price": 30 }'):price::string
18+
-- !query analysis
19+
Project [cast(variant_get(parse_json({ "price": 30 }, true), $.price, VariantType, true, Some(America/Los_Angeles)) as string) AS price#x]
20+
+- OneRowRelation
21+
22+
23+
-- !query
24+
select parse_json('{ "price": 12345.678 }'):price::decimal(3, 2)
25+
-- !query analysis
26+
Project [cast(variant_get(parse_json({ "price": 12345.678 }, true), $.price, VariantType, true, Some(America/Los_Angeles)) as decimal(3,2)) AS price#x]
27+
+- OneRowRelation
28+
29+
30+
-- !query
31+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::double
32+
-- !query analysis
33+
Project [cast(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[0].price, VariantType, true, Some(America/Los_Angeles)) as double) AS price#x]
34+
+- OneRowRelation
35+
36+
37+
-- !query
38+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::int
39+
-- !query analysis
40+
Project [cast(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[0].price, VariantType, true, Some(America/Los_Angeles)) as int) AS price#x]
41+
+- OneRowRelation
42+
43+
44+
-- !query
45+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model
46+
-- !query analysis
47+
Project [variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[0].model, VariantType, true, Some(America/Los_Angeles)) AS model#x]
48+
+- OneRowRelation
49+
50+
51+
-- !query
52+
select substr(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model, 2, 3)
53+
-- !query analysis
54+
Project [substr(cast(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[0].model, VariantType, true, Some(America/Los_Angeles)) as string), 2, 3) AS substr(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }), $.item[0].model) AS model, 2, 3)#x]
55+
+- OneRowRelation
56+
57+
58+
-- !query
59+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double
60+
-- !query analysis
61+
Project [cast(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[1].price, VariantType, true, Some(America/Los_Angeles)) as double) AS price#x]
62+
+- OneRowRelation
63+
64+
65+
-- !query
66+
select ceil(sqrt(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double))
67+
-- !query analysis
68+
Project [CEIL(SQRT(cast(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[1].price, VariantType, true, Some(America/Los_Angeles)) as double))) AS CEIL(SQRT(CAST(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }), $.item[1].price) AS price AS DOUBLE)))#xL]
69+
+- OneRowRelation
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
-- Simple field extraction and type casting.
2+
select parse_json('{ "price": 5 }'):price;
3+
select parse_json('{ "price": 30 }'):price::decimal(5, 2);
4+
select parse_json('{ "price": 30 }'):price::string;
5+
-- Applying an invalid function.
6+
select parse_json('{ "price": 12345.678 }'):price::decimal(3, 2);
7+
-- Access field in an array and feed it into functions.
8+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::double;
9+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::int;
10+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model;
11+
select substr(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model, 2, 3);
12+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double;
13+
select ceil(sqrt(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double));
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
-- Automatically generated by SQLQueryTestSuite
2+
-- !query
3+
select parse_json('{ "price": 5 }'):price
4+
-- !query schema
5+
struct<price:variant>
6+
-- !query output
7+
5
8+
9+
10+
-- !query
11+
select parse_json('{ "price": 30 }'):price::decimal(5, 2)
12+
-- !query schema
13+
struct<price:decimal(5,2)>
14+
-- !query output
15+
30.00
16+
17+
18+
-- !query
19+
select parse_json('{ "price": 30 }'):price::string
20+
-- !query schema
21+
struct<price:string>
22+
-- !query output
23+
30
24+
25+
26+
-- !query
27+
select parse_json('{ "price": 12345.678 }'):price::decimal(3, 2)
28+
-- !query schema
29+
struct<>
30+
-- !query output
31+
org.apache.spark.SparkRuntimeException
32+
{
33+
"errorClass" : "INVALID_VARIANT_CAST",
34+
"sqlState" : "22023",
35+
"messageParameters" : {
36+
"dataType" : "\"DECIMAL(3,2)\"",
37+
"value" : "12345.678"
38+
}
39+
}
40+
41+
42+
-- !query
43+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::double
44+
-- !query schema
45+
struct<price:double>
46+
-- !query output
47+
6.12
48+
49+
50+
-- !query
51+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::int
52+
-- !query schema
53+
struct<price:int>
54+
-- !query output
55+
6
56+
57+
58+
-- !query
59+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model
60+
-- !query schema
61+
struct<model:variant>
62+
-- !query output
63+
"basic"
64+
65+
66+
-- !query
67+
select substr(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model, 2, 3)
68+
-- !query schema
69+
struct<substr(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }), $.item[0].model) AS model, 2, 3):string>
70+
-- !query output
71+
asi
72+
73+
74+
-- !query
75+
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double
76+
-- !query schema
77+
struct<price:double>
78+
-- !query output
79+
9.24
80+
81+
82+
-- !query
83+
select ceil(sqrt(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double))
84+
-- !query schema
85+
struct<CEIL(SQRT(CAST(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }), $.item[1].price) AS price AS DOUBLE))):bigint>
86+
-- !query output
87+
4

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.AnalysisException
2828
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveDataSource, ResolveSessionCatalog, ResolveTranspose}
2929
import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension
3030
import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException}
31-
import org.apache.spark.sql.catalyst.expressions.Expression
31+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExtractSemiStructuredFields}
3232
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
3333
import org.apache.spark.sql.catalyst.rules.Rule
3434
import org.apache.spark.sql.classic.{SparkSession, Strategy}
@@ -133,6 +133,7 @@ class HiveSessionStateBuilder(
133133
new DetermineTableStats(session) +:
134134
new ResolveTranspose(session) +:
135135
new InvokeProcedures(session) +:
136+
ExtractSemiStructuredFields +:
136137
customResolutionRules
137138

138139
override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =

0 commit comments

Comments
 (0)