
Commit 6d78d43

olaky authored and allisonport-db committed

Use storeAssignmentPolicy for casts in DML commands

Follow spark.sql.storeAssignmentPolicy instead of spark.sql.ansi.enabled for casting behaviour in UPDATE and MERGE. This will by default error out at runtime when an overflow happens.

Closes #1938

GitOrigin-RevId: c960a0521df27daa6ee231e0a1022d8756496785

1 parent 0626664 commit 6d78d43
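For illustration, a minimal sketch of the new default behaviour; the table name and values below are hypothetical, not part of this commit:

    // With the default spark.sql.storeAssignmentPolicy=ANSI, an implicit cast that
    // overflows during an UPDATE now fails at runtime with
    // DELTA_CAST_OVERFLOW_IN_TABLE_WRITE instead of silently wrapping around.
    spark.sql("CREATE TABLE target (value SMALLINT) USING delta")
    spark.sql("INSERT INTO target VALUES (1)")
    spark.sql("UPDATE target SET value = 40000") // 40000 does not fit into SMALLINT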

File tree

8 files changed, +458 -7 lines changed

spark/src/main/resources/error/delta-error-classes.json

Lines changed: 8 additions & 0 deletions

@@ -272,6 +272,14 @@
     ],
     "sqlState" : "0A000"
   },
+  "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE" : {
+    "message" : [
+      "Failed to write a value of <sourceType> type into the <targetType> type column <columnName> due to an overflow.",
+      "Use `try_cast` on the input value to tolerate overflow and return NULL instead.",
+      "If necessary, set <storeAssignmentPolicyFlag> to \"LEGACY\" to bypass this error or set <updateAndMergeCastingFollowsAnsiEnabledFlag> to true to revert to the old behaviour and follow <ansiEnabledFlag> in UPDATE and MERGE."
+    ],
+    "sqlState" : "22003"
+  },
   "DELTA_CDC_NOT_ALLOWED_IN_THIS_VERSION" : {
     "message" : [
       "Configuration delta.enableChangeDataFeed cannot be set. Change data feed from Delta is not yet available."

spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala

Lines changed: 19 additions & 1 deletion

@@ -44,6 +44,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructField, StructType}

@@ -118,7 +119,8 @@ trait DocsPath
  */
 trait DeltaErrorsBase
   extends DocsPath
-  with DeltaLogging {
+  with DeltaLogging
+  with QueryErrorsBase {

   def baseDocsPath(spark: SparkSession): String = baseDocsPath(spark.sparkContext.getConf)

@@ -618,6 +620,22 @@ trait DeltaErrorsBase
     )
   }

+  def castingCauseOverflowErrorInTableWrite(
+      from: DataType,
+      to: DataType,
+      columnName: String): ArithmeticException = {
+    new DeltaArithmeticException(
+      errorClass = "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE",
+      messageParameters = Map(
+        "sourceType" -> toSQLType(from),
+        "targetType" -> toSQLType(to),
+        "columnName" -> toSQLId(columnName),
+        "storeAssignmentPolicyFlag" -> SQLConf.STORE_ASSIGNMENT_POLICY.key,
+        "updateAndMergeCastingFollowsAnsiEnabledFlag" ->
+          DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key,
+        "ansiEnabledFlag" -> SQLConf.ANSI_ENABLED.key))
+  }
+
   def notADeltaTable(table: String): Throwable = {
     new DeltaAnalysisException(errorClass = "DELTA_NOT_A_DELTA_TABLE",
       messageParameters = Array(table))

spark/src/main/scala/org/apache/spark/sql/delta/DeltaSharedExceptions.scala

Lines changed: 8 additions & 0 deletions

@@ -81,3 +81,11 @@ class DeltaParseException(
     ParserUtils.position(ctx.getStop)
   ) with DeltaThrowable

+class DeltaArithmeticException(
+    errorClass: String,
+    messageParameters: Map[String, String]) extends ArithmeticException with DeltaThrowable {
+  override def getErrorClass: String = errorClass
+
+  override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava
+}

spark/src/main/scala/org/apache/spark/sql/delta/UpdateExpressionsSupport.scala

Lines changed: 107 additions & 2 deletions

@@ -21,10 +21,13 @@ import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.util.AnalysisHelper

 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

 /**

@@ -405,7 +408,109 @@ trait UpdateExpressionsSupport extends SQLConfHelper with AnalysisHelper {
     }
   }

+  /**
+   * Replaces 'CastSupport.cast'. Selects a cast based on 'spark.sql.storeAssignmentPolicy' if
+   * 'spark.databricks.delta.updateAndMergeCastingFollowsAnsiEnabledFlag' is false, and based on
+   * 'spark.sql.ansi.enabled' otherwise.
+   */
   private def cast(child: Expression, dataType: DataType, columnName: String): Expression = {
-    Cast(child, dataType, Option(conf.sessionLocalTimeZone))
+    if (conf.getConf(DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG)) {
+      return Cast(child, dataType, Option(conf.sessionLocalTimeZone))
+    }
+
+    conf.storeAssignmentPolicy match {
+      case SQLConf.StoreAssignmentPolicy.LEGACY =>
+        Cast(child, dataType, Some(conf.sessionLocalTimeZone), ansiEnabled = false)
+      case SQLConf.StoreAssignmentPolicy.ANSI =>
+        val cast = Cast(child, dataType, Some(conf.sessionLocalTimeZone), ansiEnabled = true)
+        if (canCauseCastOverflow(cast)) {
+          CheckOverflowInTableWrite(cast, columnName)
+        } else {
+          cast
+        }
+      case SQLConf.StoreAssignmentPolicy.STRICT =>
+        UpCast(child, dataType)
+    }
+  }
+
+  private def containsIntegralOrDecimalType(dt: DataType): Boolean = dt match {
+    case _: IntegralType | _: DecimalType => true
+    case a: ArrayType => containsIntegralOrDecimalType(a.elementType)
+    case m: MapType =>
+      containsIntegralOrDecimalType(m.keyType) || containsIntegralOrDecimalType(m.valueType)
+    case s: StructType =>
+      s.fields.exists(sf => containsIntegralOrDecimalType(sf.dataType))
+    case _ => false
+  }
+
+  private def canCauseCastOverflow(cast: Cast): Boolean = {
+    containsIntegralOrDecimalType(cast.dataType) &&
+      !Cast.canUpCast(cast.child.dataType, cast.dataType)
+  }
+}
+
+case class CheckOverflowInTableWrite(child: Expression, columnName: String)
+  extends UnaryExpression {
+  override protected def withNewChildInternal(newChild: Expression): Expression = {
+    copy(child = newChild)
   }
+
+  private def getCast: Option[Cast] = child match {
+    case c: Cast => Some(c)
+    case ExpressionProxy(c: Cast, _, _) => Some(c)
+    case _ => None
+  }
+
+  override def eval(input: InternalRow): Any = try {
+    child.eval(input)
+  } catch {
+    case e: ArithmeticException =>
+      getCast match {
+        case Some(cast) =>
+          throw DeltaErrors.castingCauseOverflowErrorInTableWrite(
+            cast.child.dataType,
+            cast.dataType,
+            columnName)
+        case None => throw e
+      }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    getCast match {
+      case Some(child) => doGenCodeWithBetterErrorMsg(ctx, ev, child)
+      case None => child.genCode(ctx)
+    }
+  }
+
+  def doGenCodeWithBetterErrorMsg(ctx: CodegenContext, ev: ExprCode, child: Cast): ExprCode = {
+    val childGen = child.genCode(ctx)
+    val exceptionClass = classOf[ArithmeticException].getCanonicalName
+    assert(child.isInstanceOf[Cast])
+    val cast = child.asInstanceOf[Cast]
+    val fromDt =
+      ctx.addReferenceObj("from", cast.child.dataType, cast.child.dataType.getClass.getName)
+    val toDt = ctx.addReferenceObj("to", child.dataType, child.dataType.getClass.getName)
+    val col = ctx.addReferenceObj("colName", columnName, "java.lang.String")
+    // scalastyle:off line.size.limit
+    ev.copy(code =
+      code"""
+        boolean ${ev.isNull} = true;
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        try {
+          ${childGen.code}
+          ${ev.isNull} = ${childGen.isNull};
+          ${ev.value} = ${childGen.value};
+        } catch ($exceptionClass e) {
+          throw org.apache.spark.sql.delta.DeltaErrors
+            .castingCauseOverflowErrorInTableWrite($fromDt, $toDt, $col);
+        }"""
+    )
+    // scalastyle:on line.size.limit
+  }
+
+  override def dataType: DataType = child.dataType
+
+  override def sql: String = child.sql
+
+  override def toString: String = child.toString
 }
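A note on the ANSI branch above: the CheckOverflowInTableWrite wrapper is only added when the cast can actually overflow, which canCauseCastOverflow decides via Cast.canUpCast. A small sketch of that predicate (standalone snippet, assuming Spark is on the classpath):

    import org.apache.spark.sql.catalyst.expressions.Cast
    import org.apache.spark.sql.types.{IntegerType, ShortType}

    // INT -> SHORT can lose data, so canUpCast is false and the ANSI cast gets
    // wrapped in CheckOverflowInTableWrite.
    assert(!Cast.canUpCast(IntegerType, ShortType))
    // SHORT -> INT is always safe, so the plain ANSI cast is used unwrapped.
    assert(Cast.canUpCast(ShortType, IntegerType))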

spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala

Lines changed: 9 additions & 1 deletion

@@ -23,7 +23,6 @@ import org.apache.spark.internal.config.ConfigBuilder
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils

 /**
  * [[SQLConf]] entries for Delta features.

@@ -1254,6 +1253,15 @@
     .intConf
     .createWithDefault(100 * 1000)

+  val UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG =
+    buildConf("updateAndMergeCastingFollowsAnsiEnabledFlag")
+      .internal()
+      .doc("""If false, casting behaviour in implicit casts in UPDATE and MERGE follows
+             |'spark.sql.storeAssignmentPolicy'. If true, these casts follow 'ansi.enabled'. This
+             |was the default before Delta 3.5.""".stripMargin)
+      .booleanConf
+      .createWithDefault(false)
+
 }

 object DeltaSQLConf extends DeltaSQLConfBase
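Since the conf is built with Delta's buildConf helper, the full key carries the "spark.databricks.delta." prefix (as the doc comment in UpdateExpressionsSupport spells out). A sketch of the two escape hatches the new error message offers:

    // Revert to the old behaviour: follow spark.sql.ansi.enabled in UPDATE and MERGE.
    spark.conf.set("spark.databricks.delta.updateAndMergeCastingFollowsAnsiEnabledFlag", "true")
    // Or keep the new code path but allow silent overflow on write.
    spark.conf.set("spark.sql.storeAssignmentPolicy", "LEGACY")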

spark/src/test/scala/org/apache/spark/sql/delta/DeltaErrorsSuite.scala

Lines changed: 21 additions & 1 deletion

@@ -52,6 +52,7 @@ import org.apache.spark.sql.catalyst.expressions.Uuid
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
 import org.apache.spark.sql.types.{CalendarIntervalType, DataTypes, DateType, IntegerType, StringType, StructField, StructType, TimestampNTZType}

@@ -60,7 +61,8 @@ trait DeltaErrorsSuiteBase
   extends QueryTest
   with SharedSparkSession with GivenWhenThen
   with DeltaSQLCommandTest
-  with SQLTestUtils {
+  with SQLTestUtils
+  with QueryErrorsBase {

   val MAX_URL_ACCESS_RETRIES = 3
   val path = "/sample/path"

@@ -288,6 +290,24 @@ trait DeltaErrorsSuiteBase
       assert(
         e.getMessage == s"$table is a view. Writes to a view are not supported.")
     }
+    {
+      val sourceType = IntegerType
+      val targetType = DateType
+      val columnName = "column_name"
+      val e = intercept[DeltaArithmeticException] {
+        throw DeltaErrors.castingCauseOverflowErrorInTableWrite(sourceType, targetType, columnName)
+      }
+      assert(e.getErrorClass == "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE")
+      assert(e.getSqlState == "22003")
+      assert(e.getMessageParameters.get("sourceType") == toSQLType(sourceType))
+      assert(e.getMessageParameters.get("targetType") == toSQLType(targetType))
+      assert(e.getMessageParameters.get("columnName") == toSQLId(columnName))
+      assert(e.getMessageParameters.get("storeAssignmentPolicyFlag")
+        == SQLConf.STORE_ASSIGNMENT_POLICY.key)
+      assert(e.getMessageParameters.get("updateAndMergeCastingFollowsAnsiEnabledFlag")
+        == DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key)
+      assert(e.getMessageParameters.get("ansiEnabledFlag") == SQLConf.ANSI_ENABLED.key)
+    }
     {
       val e = intercept[DeltaAnalysisException] {
         throw DeltaErrors.invalidColumnName(name = "col-1")
