[SPARK-52188] Fix for StateDataSource where StreamExecution.RUN_ID_KEY is not set #50924

Closed · wants to merge 7 commits

@@ -630,6 +630,7 @@ object StateStoreProvider extends Logging {
       hadoopConf: Configuration,
       useMultipleValuesPerKey: Boolean,
       stateSchemaProvider: Option[StateSchemaProvider]): StateStoreProvider = {
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, providerId.queryRunId.toString)
     val provider = create(storeConf.providerClass)
     provider.init(providerId.storeId, keySchema, valueSchema, keyStateEncoderSpec,
       useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey, stateSchemaProvider)
@@ -669,12 +670,8 @@
    */
   private[state] def getRunId(hadoopConf: Configuration): String = {
     val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY)
-    if (runId != null) {
-      runId
-    } else {
-      assert(Utils.isTesting, "Failed to find query id/batch Id in task context")
-      UUID.randomUUID().toString
-    }
+    assert(runId != null)
+    runId
   }

   /**
@@ -968,7 +965,6 @@ object StateStore extends Logging {
     if (version < 0) {
       throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
     }
-    hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString)
Contributor, commenting on the removed line:

Oh this is redundant, nice finding.
     val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
       keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey,
       stateSchemaBroadcast)
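With this change, StateStoreProvider.createAndInit is the single place that stamps the query run ID into the Hadoop configuration, and getRunId asserts instead of falling back to a random UUID. A minimal sketch of the resulting contract, assuming a caller that initializes a provider without going through createAndInit (the standalone wiring below is illustrative, not part of the patch):

```scala
import java.util.UUID

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.execution.streaming.StreamExecution

// Any code path that calls provider.init directly must set the run ID first,
// otherwise StateStoreProvider.getRunId's assert(runId != null) will fire.
val hadoopConf = new Configuration()
hadoopConf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
// provider.init(storeId, keySchema, valueSchema, ..., hadoopConf, ...)
```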
@@ -17,11 +17,13 @@

 package org.apache.spark.sql.execution.datasources.v2.state

+import java.util.UUID
+
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.Assertions

 import org.apache.spark.sql.Row
-import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamExecution}
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
@@ -71,14 +73,16 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestBase
    */
   private def getNewStateStoreProvider(checkpointDir: String): StateStoreProvider = {
     val provider = newStateStoreProvider()
+    val conf = new Configuration
+    conf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
     provider.init(
       StateStoreId(checkpointDir, 0, 0),
       keySchema,
       valueSchema,
       NoPrefixKeyStateEncoderSpec(keySchema),
       useColumnFamilies = false,
       StateStoreConf(spark.sessionState.conf),
-      new Configuration)
+      conf)
     provider
   }
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.state

 import java.io.{File, FileWriter}
 import java.nio.ByteOrder
+import java.util.UUID

 import org.apache.hadoop.conf.Configuration
 import org.scalatest.Assertions
@@ -28,7 +29,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, Row}
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
-import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream, OffsetSeqLog}
+import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream, OffsetSeqLog, StreamExecution}
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
@@ -588,14 +589,16 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Assertions
    */
   private def getNewStateStoreProvider(checkpointDir: String): StateStoreProvider = {
     val provider = newStateStoreProvider()
+    val conf = new Configuration()
+    conf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
     provider.init(
       StateStoreId(checkpointDir, 0, 0),
       keySchema,
       valueSchema,
       NoPrefixKeyStateEncoderSpec(keySchema),
       useColumnFamilies = false,
       StateStoreConf(spark.sessionState.conf),
-      new Configuration)
+      conf)
     provider
   }
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.streaming.state

 import java.io.File
+import java.util.UUID

 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
@@ -26,7 +27,7 @@ import org.scalatest.Tag
 import org.apache.spark.{SparkContext, SparkException, TaskContext}
 import org.apache.spark.sql.{DataFrame, ForeachWriter}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream}
+import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream, StreamExecution}
 import org.apache.spark.sql.execution.streaming.state.StateStoreTestsHelper
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -154,6 +155,7 @@ class CkptIdCollectingStateStoreProviderWrapper extends StateStoreProvider {
       hadoopConf: Configuration,
       useMultipleValuesPerKey: Boolean = false,
       stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
     innerProvider.init(
       stateStoreId,
       keySchema,
@@ -33,7 +33,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.quietly
-import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
+import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamExecution}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
@@ -2098,6 +2098,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider]
       useMultipleValuesPerKey: Boolean = false): RocksDBStateStoreProvider = {
     val provider = new RocksDBStateStoreProvider()
     val testStateSchemaProvider = new TestStateSchemaProvider
+    conf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
Contributor:

@ericm-db - is it possible to also add a data source reader test?

Contributor (author):

Because we no longer have the isTesting check and unconditionally assert that runId is not null, those tests would just fail. Without this fix, all of those tests were failing with runId being equal to null.

Contributor:

cool thanks

     provider.init(
       storeId,
       keySchema,
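For context on the reader test discussed above: the state data source reads a checkpoint outside of a running query, which is exactly the path that previously saw a null run ID. A hedged sketch of such a read, assuming the built-in "statestore" format and a checkpointDir supplied by the test fixture:

```scala
// Sketch only: read a stateful operator's state from its checkpoint location.
// `spark` is an active SparkSession; `checkpointDir` is an assumption here.
val stateDf = spark.read
  .format("statestore")
  .option("batchId", 0L) // state as of the first committed micro-batch
  .load(checkpointDir)

// The reader exposes the store's key and value as struct columns.
stateDf.select("key", "value").show()
```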
@@ -1058,6 +1058,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
       minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
       numOfVersToRetainInMemory: Int = SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get,
       hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = {
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
     val sqlConf = getDefaultSQLConf(minDeltasForSnapshot, numOfVersToRetainInMemory)
     val provider = new HDFSBackedStateStoreProvider()
     provider.init(
@@ -28,7 +28,7 @@ import org.scalatest.BeforeAndAfter
 import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
-import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, ValueStateImplWithTTL}
+import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StreamExecution, ValueStateImplWithTTL}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{TimeMode, TTLConfig, ValueState}
 import org.apache.spark.sql.test.SharedSparkSession
@@ -461,6 +461,7 @@ abstract class StateVariableSuiteBase extends SharedSparkSession
       conf: Configuration = new Configuration,
       useColumnFamilies: Boolean = false): RocksDBStateStoreProvider = {
     val provider = new RocksDBStateStoreProvider()
+    conf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
     provider.init(
       storeId, schemaForKeyRow, schemaForValueRow, keyStateEncoderSpec,
       useColumnFamilies,
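Each of the six test files touched here applies the same two-line setup before provider.init. If the duplication grows, a small shared helper could centralize it; the utility below is hypothetical and not part of this PR:

```scala
import java.util.UUID

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.execution.streaming.StreamExecution

// Hypothetical test utility: a Configuration pre-populated with a fresh run ID,
// so provider.init never trips getRunId's null assert in tests.
def newConfWithRunId(base: Configuration = new Configuration()): Configuration = {
  base.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
  base
}
```

Call sites would then reduce to provider.init(..., newConfWithRunId(), ...).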