diff --git a/python/docs/source/user_guide/sql/python_data_source.rst b/python/docs/source/user_guide/sql/python_data_source.rst
index 22b2a0b5f3c7b..41b76c95d5806 100644
--- a/python/docs/source/user_guide/sql/python_data_source.rst
+++ b/python/docs/source/user_guide/sql/python_data_source.rst
@@ -520,4 +520,6 @@ The following example demonstrates how to implement a basic Data Source using Ar
 Usage Notes
 -----------
 
-- During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with the other Data Sources.
+- During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with non-Python Data Sources.
+- Multiple Python Data Sources may be registered under the same name; later registrations overwrite earlier ones.
+- To automatically register a Data Source, export it as ``DefaultSource`` in a top-level module whose name starts with ``pyspark_``. See `pyspark_huggingface <https://github.com/huggingface/pyspark_huggingface>`_ for an example.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 711e096ebd1f8..7a8dbab35964f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -48,14 +48,13 @@ class DataSourceManager extends Logging {
    */
   def registerDataSource(name: String, source: UserDefinedPythonDataSource): Unit = {
     val normalizedName = normalize(name)
-    if (staticDataSourceBuilders.contains(normalizedName)) {
-      // Cannot overwrite static Python Data Sources.
-      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
-    }
     val previousValue = runtimeDataSourceBuilders.put(normalizedName, source)
     if (previousValue != null) {
       logWarning(log"The data source ${MDC(DATA_SOURCE, name)} replaced a previously " +
         log"registered data source.")
+    } else if (staticDataSourceBuilders.contains(normalizedName)) {
+      logWarning(log"The data source ${MDC(DATA_SOURCE, name)} replaced a statically " +
+        log"registered data source.")
     }
   }
 
@@ -64,11 +63,7 @@ class DataSourceManager extends Logging {
    * it does not exist.
    */
   def lookupDataSource(name: String): UserDefinedPythonDataSource = {
-    if (dataSourceExists(name)) {
-      val normalizedName = normalize(name)
-      staticDataSourceBuilders.getOrElse(
-        normalizedName, runtimeDataSourceBuilders.get(normalizedName))
-    } else {
+    getDataSource(name).getOrElse {
       throw QueryCompilationErrors.dataSourceDoesNotExist(name)
     }
   }
@@ -77,9 +72,14 @@ class DataSourceManager extends Logging {
    * Checks if a data source with the specified name exists (case-insensitive).
    */
   def dataSourceExists(name: String): Boolean = {
+    getDataSource(name).isDefined
+  }
+
+  private def getDataSource(name: String): Option[UserDefinedPythonDataSource] = {
     val normalizedName = normalize(name)
-    staticDataSourceBuilders.contains(normalizedName) ||
-      runtimeDataSourceBuilders.containsKey(normalizedName)
+    // Runtime registration takes precedence over static registration.
+    Option(runtimeDataSourceBuilders.get(normalizedName))
+      .orElse(staticDataSourceBuilders.get(normalizedName))
   }
 
   override def clone(): DataSourceManager = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index f9eb01c10edee..d201f1890dbdb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -126,6 +126,24 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
     assume(shouldTestPandasUDFs)
     val df = spark.read.format(staticSourceName).load()
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
+
+    // Overwrite the static source and verify the new source takes effect.
+    val errorText = "static source overwritten"
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |
+         |class $staticSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        raise Exception("$errorText")
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(
+      name = staticSourceName, pythonScript = dataSourceScript)
+    spark.dataSource.registerPython(staticSourceName, dataSource)
+    val err = intercept[AnalysisException] {
+      spark.read.format(staticSourceName).load()
+    }
+    assert(err.getMessage.contains(errorText))
   }
 
   test("simple data source") {
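
For illustration, a minimal sketch of the new overwrite semantics from the Python side. It assumes an active Spark 4.0+ session bound to ``spark``; the ``DemoV1``/``DemoV2`` classes and the ``demo`` short name are invented for this example and are not part of the patch:

    from pyspark.sql.datasource import DataSource, DataSourceReader

    class ReaderV1(DataSourceReader):
        def read(self, partition):
            yield (1,)   # payload for the first registration

    class DemoV1(DataSource):
        @classmethod
        def name(cls):
            return "demo"   # invented short name

        def schema(self):
            return "v int"

        def reader(self, schema):
            return ReaderV1()

    class ReaderV2(DataSourceReader):
        def read(self, partition):
            yield (2,)   # payload for the second registration

    class DemoV2(DataSource):
        @classmethod
        def name(cls):
            return "demo"   # same name on purpose

        def schema(self):
            return "v int"

        def reader(self, schema):
            return ReaderV2()

    spark.dataSource.register(DemoV1)
    # With this patch, re-registering "demo" no longer throws; a warning is
    # logged on the JVM side and the later registration wins.
    spark.dataSource.register(DemoV2)

    spark.read.format("demo").load().show()   # single row (2), from DemoV2

Before this change, only runtime-over-runtime replacement was tolerated; colliding with a statically registered source raised ``dataSourceAlreadyExists``. The patch downgrades that collision to a warning and makes ``getDataSource`` prefer the runtime entry.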
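
For the auto-registration path described in the new usage note, a sketch of such a package. All names here are hypothetical; the ``pyspark_`` module-name prefix and the exported ``DefaultSource`` symbol are the documented requirements, while the short name returned by ``name()`` is an assumption chosen to match the module suffix:

    # pyspark_demo/__init__.py -- hypothetical package; the "pyspark_" module
    # name prefix is what makes Spark's startup scan discover it.
    from pyspark.sql.datasource import DataSource, DataSourceReader

    class _DemoReader(DataSourceReader):
        def read(self, partition):
            yield (0,)

    # Exporting the class under the exact name "DefaultSource" is what makes
    # the scan of "pyspark_"-prefixed modules register it automatically.
    class DefaultSource(DataSource):
        @classmethod
        def name(cls):
            return "demo"

        def schema(self):
            return "v int"

        def reader(self, schema):
            return _DemoReader()

Once such a package is installed, ``spark.read.format("demo").load()`` works in a fresh session with no explicit register call, and per the new lookup order an explicit runtime registration under the same name now shadows this static entry instead of failing.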