From 1cc9ea896221370591b83f7b6d28d958a34254a4 Mon Sep 17 00:00:00 2001 From: Leonid Dubinsky Date: Sat, 24 Dec 2022 22:20:47 -0500 Subject: [PATCH] `scala.xml.XML` allows overriding the SAXParser used via the `withSAXParser()` method. Some XML parsing customizations require changing the XMLReader contained inside every SAXParser (e.g., adding an XMLFilter). This pull request introduces an additional extension point `XMLLoader.reader` and a method `XML.withXMLReader()` for such a purpose. Also, ErrorHandler and EntityResolver configured externally are no longer wiped out before parsing the XML. --- jvm/src/test/scala/scala/xml/XMLTest.scala | 28 +++++++ shared/src/main/scala/scala/xml/XML.scala | 8 +- .../scala/scala/xml/factory/XMLLoader.scala | 74 ++++++++++++------- shared/src/main/scala/scala/xml/package.scala | 1 + .../scala/xml/parsing/MarkupParser.scala | 5 +- 5 files changed, 84 insertions(+), 32 deletions(-) diff --git a/jvm/src/test/scala/scala/xml/XMLTest.scala b/jvm/src/test/scala/scala/xml/XMLTest.scala index 812b53680..032b341a6 100644 --- a/jvm/src/test/scala/scala/xml/XMLTest.scala +++ b/jvm/src/test/scala/scala/xml/XMLTest.scala @@ -657,6 +657,34 @@ class XMLTestJVM { def namespaceAware2: Unit = roundtrip(namespaceAware = true, """""") + @UnitTest + def useXMLReaderWithXMLFilter(): Unit = { + val parent: org.xml.sax.XMLReader = javax.xml.parsers.SAXParserFactory.newInstance.newSAXParser.getXMLReader + val filter: org.xml.sax.XMLFilter = new org.xml.sax.helpers.XMLFilterImpl(parent) { + override def characters(ch: Array[Char], start: Int, length: Int): Unit = { + for (i <- 0 until length) if (ch(start+i) == 'a') ch(start+i) = 'b' + super.characters(ch, start, length) + } + } + assertEquals(XML.withXMLReader(filter).loadString("caffeeaaay").toString, "cbffeebbby") + } + + @UnitTest + def checkThatErrorHandlerIsNotOverwritten(): Unit = { + var gotAnError: Boolean = false + XML.reader.setErrorHandler(new org.xml.sax.ErrorHandler { + override def 
warning(e: SAXParseException): Unit = gotAnError = true + override def error(e: SAXParseException): Unit = gotAnError = true + override def fatalError(e: SAXParseException): Unit = gotAnError = true + }) + try { + XML.loadString("") + } catch { + case _: org.xml.sax.SAXParseException => + } + assertTrue(gotAnError) + } + @UnitTest def nodeSeqNs: Unit = { val x = { diff --git a/shared/src/main/scala/scala/xml/XML.scala b/shared/src/main/scala/scala/xml/XML.scala index 46b5a0ec6..531afcc97 100755 --- a/shared/src/main/scala/scala/xml/XML.scala +++ b/shared/src/main/scala/scala/xml/XML.scala @@ -14,8 +14,8 @@ package scala package xml import factory.XMLLoader -import java.io.{ File, FileDescriptor, FileInputStream, FileOutputStream } -import java.io.{ InputStream, Reader, StringReader } +import java.io.{File, FileDescriptor, FileInputStream, FileOutputStream} +import java.io.{InputStream, Reader, StringReader} import java.nio.channels.Channels import scala.util.control.Exception.ultimately @@ -72,6 +72,10 @@ object XML extends XMLLoader[Elem] { def withSAXParser(p: SAXParser): XMLLoader[Elem] = new XMLLoader[Elem] { override val parser: SAXParser = p } + /** Returns an XMLLoader whose load* methods will use the supplied XMLReader. */ + def withXMLReader(r: XMLReader): XMLLoader[Elem] = + new XMLLoader[Elem] { override val reader: XMLReader = r } + /** * Saves a node to a file with given filename using given encoding * optionally with xmldecl and doctype declaration. 
diff --git a/shared/src/main/scala/scala/xml/factory/XMLLoader.scala b/shared/src/main/scala/scala/xml/factory/XMLLoader.scala index 620e1b6e0..5cdc9461e 100644 --- a/shared/src/main/scala/scala/xml/factory/XMLLoader.scala +++ b/shared/src/main/scala/scala/xml/factory/XMLLoader.scala @@ -14,7 +14,7 @@ package scala package xml package factory -import org.xml.sax.SAXNotRecognizedException +import org.xml.sax.{SAXNotRecognizedException, XMLReader} import javax.xml.parsers.SAXParserFactory import parsing.{FactoryAdapter, NoBindingFactoryAdapter} import java.io.{File, FileDescriptor, InputStream, Reader} @@ -46,59 +46,77 @@ trait XMLLoader[T <: Node] { /* Override this to use a different SAXParser. */ def parser: SAXParser = parserInstance.get + /* Override this to use a different XMLReader. */ + def reader: XMLReader = parser.getXMLReader + /** * Loads XML from the given InputSource, using the supplied parser. * The methods available in scala.xml.XML use the XML parser in the JDK. */ - def loadXML(source: InputSource, parser: SAXParser): T = { - val result: FactoryAdapter = parse(source, parser) + def loadXML(source: InputSource, parser: SAXParser): T = loadXML(source, parser.getXMLReader) + + def loadXMLNodes(source: InputSource, parser: SAXParser): Seq[Node] = loadXMLNodes(source, parser.getXMLReader) + + private def loadXML(source: InputSource, reader: XMLReader): T = { + val result: FactoryAdapter = parse(source, reader) result.rootElem.asInstanceOf[T] } - - def loadXMLNodes(source: InputSource, parser: SAXParser): Seq[Node] = { - val result: FactoryAdapter = parse(source, parser) + + private def loadXMLNodes(source: InputSource, reader: XMLReader): Seq[Node] = { + val result: FactoryAdapter = parse(source, reader) result.prolog ++ (result.rootElem :: result.epilogue) } - private def parse(source: InputSource, parser: SAXParser): FactoryAdapter = { + private def parse(source: InputSource, reader: XMLReader): FactoryAdapter = { + if (source == null) throw new 
IllegalArgumentException("InputSource cannot be null") + val result: FactoryAdapter = adapter + reader.setContentHandler(result) + reader.setDTDHandler(result) + /* Do not overwrite pre-configured EntityResolver. */ + if (reader.getEntityResolver == null) reader.setEntityResolver(result) + /* Do not overwrite pre-configured ErrorHandler. */ + if (reader.getErrorHandler == null) reader.setErrorHandler(result) + try { - parser.setProperty("http://xml.org/sax/properties/lexical-handler", result) + reader.setProperty("http://xml.org/sax/properties/lexical-handler", result) } catch { case _: SAXNotRecognizedException => } result.scopeStack = TopScope :: result.scopeStack - parser.parse(source, result) + reader.parse(source) result.scopeStack = result.scopeStack.tail result } + /** loads XML from given InputSource. */ + def load(source: InputSource): T = loadXML(source, reader) + /** Loads XML from the given file, file descriptor, or filename. */ - def loadFile(file: File): T = loadXML(fromFile(file), parser) - def loadFile(fd: FileDescriptor): T = loadXML(fromFile(fd), parser) - def loadFile(name: String): T = loadXML(fromFile(name), parser) + def loadFile(file: File): T = load(fromFile(file)) + def loadFile(fd: FileDescriptor): T = load(fromFile(fd)) + def loadFile(name: String): T = load(fromFile(name)) - /** loads XML from given InputStream, Reader, sysID, InputSource, or URL. */ - def load(is: InputStream): T = loadXML(fromInputStream(is), parser) - def load(reader: Reader): T = loadXML(fromReader(reader), parser) - def load(sysID: String): T = loadXML(fromSysId(sysID), parser) - def load(source: InputSource): T = loadXML(source, parser) - def load(url: URL): T = loadXML(fromInputStream(url.openStream()), parser) + /** loads XML from given InputStream, Reader, sysID, or URL. 
*/ + def load(is: InputStream): T = load(fromInputStream(is)) + def load(reader: Reader): T = load(fromReader(reader)) + def load(sysID: String): T = load(fromSysId(sysID)) + def load(url: URL): T = load(fromInputStream(url.openStream())) /** Loads XML from the given String. */ - def loadString(string: String): T = loadXML(fromString(string), parser) + def loadString(string: String): T = load(fromString(string)) /** Load XML nodes, including comments and processing instructions that precede and follow the root element. */ - def loadFileNodes(file: File): Seq[Node] = loadXMLNodes(fromFile(file), parser) - def loadFileNodes(fd: FileDescriptor): Seq[Node] = loadXMLNodes(fromFile(fd), parser) - def loadFileNodes(name: String): Seq[Node] = loadXMLNodes(fromFile(name), parser) - def loadNodes(is: InputStream): Seq[Node] = loadXMLNodes(fromInputStream(is), parser) - def loadNodes(reader: Reader): Seq[Node] = loadXMLNodes(fromReader(reader), parser) - def loadNodes(sysID: String): Seq[Node] = loadXMLNodes(fromSysId(sysID), parser) - def loadNodes(source: InputSource): Seq[Node] = loadXMLNodes(source, parser) - def loadNodes(url: URL): Seq[Node] = loadXMLNodes(fromInputStream(url.openStream()), parser) - def loadStringNodes(string: String): Seq[Node] = loadXMLNodes(fromString(string), parser) + def loadNodes(source: InputSource): Seq[Node] = loadXMLNodes(source, reader) + def loadFileNodes(file: File): Seq[Node] = loadNodes(fromFile(file)) + def loadFileNodes(fd: FileDescriptor): Seq[Node] = loadNodes(fromFile(fd)) + def loadFileNodes(name: String): Seq[Node] = loadNodes(fromFile(name)) + def loadNodes(is: InputStream): Seq[Node] = loadNodes(fromInputStream(is)) + def loadNodes(reader: Reader): Seq[Node] = loadNodes(fromReader(reader)) + def loadNodes(sysID: String): Seq[Node] = loadNodes(fromSysId(sysID)) + def loadNodes(url: URL): Seq[Node] = loadNodes(fromInputStream(url.openStream())) + def loadStringNodes(string: String): Seq[Node] = loadNodes(fromString(string)) } 
diff --git a/shared/src/main/scala/scala/xml/package.scala b/shared/src/main/scala/scala/xml/package.scala index 7847f63ba..d25a80e27 100644 --- a/shared/src/main/scala/scala/xml/package.scala +++ b/shared/src/main/scala/scala/xml/package.scala @@ -80,5 +80,6 @@ package object xml { type SAXParseException = org.xml.sax.SAXParseException type EntityResolver = org.xml.sax.EntityResolver type InputSource = org.xml.sax.InputSource + type XMLReader = org.xml.sax.XMLReader type SAXParser = javax.xml.parsers.SAXParser } diff --git a/shared/src/main/scala/scala/xml/parsing/MarkupParser.scala b/shared/src/main/scala/scala/xml/parsing/MarkupParser.scala index 853897117..4fe936a4b 100755 --- a/shared/src/main/scala/scala/xml/parsing/MarkupParser.scala +++ b/shared/src/main/scala/scala/xml/parsing/MarkupParser.scala @@ -98,8 +98,9 @@ trait MarkupParser extends MarkupParserCommon with TokenTests { var extIndex = -1 /** holds temporary values of pos */ - // Note: this is clearly an override, but if marked as such it causes a "...cannot override a mutable variable" - // error with Scala 3; does it work with Scala 3 if not explicitly marked as an override remains to be seen... + // Note: if marked as an override, this causes a "...cannot override a mutable variable" error with Scala 3; + // SethTisue noted on Oct 14, 2021 that lampepfl/dotty#13744 should fix it - and it probably did, + // but Scala XML still builds against a Scala 3 version that has this bug, so this still cannot be marked as an override :( var tmppos: Int = _ /** holds the next character */