[SPARK-51919][PYTHON] Allow overwriting statically registered Python Data Source #50716

Open
wants to merge 1 commit into
base: master
4 changes: 3 additions & 1 deletion python/docs/source/user_guide/sql/python_data_source.rst
@@ -520,4 +520,6 @@ The following example demonstrates how to implement a basic Data Source using Ar
 Usage Notes
 -----------

-- During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with the other Data Sources.
+- During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with the other non-Python Data Sources.
+- It is allowed to register multiple Python Data Sources with the same name. Later registrations will overwrite earlier ones.
+- To automatically register a data source, export it as ``DefaultSource`` in a top level module with name prefix ``pyspark_``. See `pyspark_huggingface <https://github.com/huggingface/pyspark_huggingface>`_ for an example.
@wengh (Contributor, Author), Apr 25, 2025:

Should we mention the DefaultSource feature which was previously undocumented?
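
To make the naming convention in the new documentation bullet concrete, here is a toy sketch in plain Python. It does not use or reproduce Spark's discovery code, which this PR does not show; `find_default_sources`, `FakeSource`, and the module names are hypothetical and exist only for illustration. The sketch models one rule: a top-level module whose name starts with `pyspark_` and that exports `DefaultSource` is picked up, and anything else is ignored.

```python
import types


def find_default_sources(modules):
    """Return {module_name: DefaultSource} for modules that follow the convention.

    Hypothetical helper: it only models the documented naming rule
    (top-level module named ``pyspark_*`` that exports ``DefaultSource``).
    """
    found = {}
    for module in modules:
        # Skip submodules and modules without the required prefix.
        if "." in module.__name__ or not module.__name__.startswith("pyspark_"):
            continue
        default_source = getattr(module, "DefaultSource", None)
        if default_source is not None:
            found[module.__name__] = default_source
    return found


class FakeSource:
    """Stand-in for a pyspark.sql.datasource.DataSource subclass."""


# Build in-memory modules instead of installing real packages.
plugin = types.ModuleType("pyspark_example")
plugin.DefaultSource = FakeSource
unrelated = types.ModuleType("example_plugin")
unrelated.DefaultSource = FakeSource

sources = find_default_sources([plugin, unrelated])
```

Only `pyspark_example` ends up in `sources`; `example_plugin` exports the same attribute but lacks the prefix.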

@@ -48,14 +48,13 @@ class DataSourceManager extends Logging {
    */
   def registerDataSource(name: String, source: UserDefinedPythonDataSource): Unit = {
     val normalizedName = normalize(name)
-    if (staticDataSourceBuilders.contains(normalizedName)) {
-      // Cannot overwrite static Python Data Sources.
-      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
-    }
     val previousValue = runtimeDataSourceBuilders.put(normalizedName, source)
     if (previousValue != null) {
       logWarning(log"The data source ${MDC(DATA_SOURCE, name)} replaced a previously " +
         log"registered data source.")
+    } else if (staticDataSourceBuilders.contains(normalizedName)) {
+      logWarning(log"The data source ${MDC(DATA_SOURCE, name)} replaced a statically " +
+        log"registered data source.")
     }
   }

@@ -64,11 +63,7 @@ class DataSourceManager extends Logging {
    * it does not exist.
    */
   def lookupDataSource(name: String): UserDefinedPythonDataSource = {
-    if (dataSourceExists(name)) {
-      val normalizedName = normalize(name)
-      staticDataSourceBuilders.getOrElse(
-        normalizedName, runtimeDataSourceBuilders.get(normalizedName))
-    } else {
+    getDataSource(name).getOrElse {
       throw QueryCompilationErrors.dataSourceDoesNotExist(name)
     }
   }
@@ -77,9 +72,14 @@ class DataSourceManager extends Logging {
    * Checks if a data source with the specified name exists (case-insensitive).
    */
   def dataSourceExists(name: String): Boolean = {
+    getDataSource(name).isDefined
+  }
+
+  private def getDataSource(name: String): Option[UserDefinedPythonDataSource] = {
     val normalizedName = normalize(name)
-    staticDataSourceBuilders.contains(normalizedName) ||
-      runtimeDataSourceBuilders.containsKey(normalizedName)
+    // Runtime registration takes precedence over static.
+    Option(runtimeDataSourceBuilders.get(normalizedName))
+      .orElse(staticDataSourceBuilders.get(normalizedName))
   }

   override def clone(): DataSourceManager = {
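
The lookup order introduced by this change can be modeled in a few lines of Python. The sketch below is a simplified, hypothetical stand-in for `DataSourceManager`: the class name `DataSourceManagerSketch`, its method names, the `KeyError`, and the lower-casing in `_normalize` are assumptions made for illustration, not Spark's code. It only shows that a runtime registration shadows a static one, that a warning is logged in both replacement cases, and that lookups fall back to the static map.

```python
import logging

logger = logging.getLogger("data_source_manager_sketch")


class DataSourceManagerSketch:
    """Toy model of the lookup order in this diff; not Spark's implementation."""

    def __init__(self, static_sources=None):
        # Statically registered sources (for example, discovered at startup).
        self._static = {
            self._normalize(name): source
            for name, source in (static_sources or {}).items()
        }
        # Sources registered at runtime by the user.
        self._runtime = {}

    @staticmethod
    def _normalize(name):
        # Assumption: case-insensitive matching is modeled by lower-casing.
        return name.lower()

    def register(self, name, source):
        key = self._normalize(name)
        previous = self._runtime.get(key)
        self._runtime[key] = source
        if previous is not None:
            logger.warning(
                "The data source %s replaced a previously registered data source.", name)
        elif key in self._static:
            logger.warning(
                "The data source %s replaced a statically registered data source.", name)

    def _get(self, name):
        key = self._normalize(name)
        # Runtime registration takes precedence over static.
        if key in self._runtime:
            return self._runtime[key]
        return self._static.get(key)

    def exists(self, name):
        return self._get(name) is not None

    def lookup(self, name):
        source = self._get(name)
        if source is None:
            raise KeyError(f"data source does not exist: {name}")
        return source


manager = DataSourceManagerSketch({"Example": "static source"})
manager.register("example", "runtime source")  # logs the "statically registered" warning
print(manager.lookup("EXAMPLE"))  # prints "runtime source"
```

Routing both `exists` and `lookup` through a single `_get` helper mirrors the refactoring in the diff, so the precedence rule lives in one place.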
@@ -126,6 +126,24 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
     assume(shouldTestPandasUDFs)
     val df = spark.read.format(staticSourceName).load()
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
+
+    // Overwrite the static source
+    val errorText = "static source overwritten"
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |
+         |class $staticSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        raise Exception("$errorText")
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(
+      name = staticSourceName, pythonScript = dataSourceScript)
+    spark.dataSource.registerPython(staticSourceName, dataSource)
+    val err = intercept[AnalysisException] {
+      spark.read.format(staticSourceName).load()
+    }
+    assert(err.getMessage.contains(errorText))
   }

   test("simple data source") {