diff --git a/shared/src/main/scala-2.13+/scala/xml/ScalaVersionSpecific.scala b/shared/src/main/scala-2.13+/scala/xml/ScalaVersionSpecific.scala index bd2eb820d..4258c100f 100644 --- a/shared/src/main/scala-2.13+/scala/xml/ScalaVersionSpecific.scala +++ b/shared/src/main/scala-2.13+/scala/xml/ScalaVersionSpecific.scala @@ -1,7 +1,7 @@ package scala.xml import scala.collection.immutable.StrictOptimizedSeqOps -import scala.collection.{SeqOps, IterableOnce, immutable, mutable} +import scala.collection.{View, SeqOps, IterableOnce, immutable, mutable} import scala.collection.BuildFrom import scala.collection.mutable.Builder @@ -20,6 +20,21 @@ private[xml] trait ScalaVersionSpecificNodeSeq override def fromSpecific(coll: IterableOnce[Node]): NodeSeq = (NodeSeq.newBuilder ++= coll).result() override def newSpecificBuilder: mutable.Builder[Node, NodeSeq] = NodeSeq.newBuilder override def empty: NodeSeq = NodeSeq.Empty + def concat(suffix: IterableOnce[Node]): NodeSeq = + fromSpecific(iterator ++ suffix.iterator) + @inline final def ++ (suffix: Seq[Node]): NodeSeq = concat(suffix) + def appended(base: Node): NodeSeq = + fromSpecific(new View.Appended(this, base)) + def appendedAll(suffix: IterableOnce[Node]): NodeSeq = + concat(suffix) + def prepended(base: Node): NodeSeq = + fromSpecific(new View.Prepended(base, this)) + def prependedAll(prefix: IterableOnce[Node]): NodeSeq = + fromSpecific(prefix.iterator ++ iterator) + def map(f: Node => Node): NodeSeq = + fromSpecific(new View.Map(this, f)) + def flatMap(f: Node => IterableOnce[Node]): NodeSeq = + fromSpecific(new View.FlatMap(this, f)) } private[xml] trait ScalaVersionSpecificNodeBuffer { self: NodeBuffer => diff --git a/shared/src/test/scala/scala/xml/NodeSeqTest.scala b/shared/src/test/scala/scala/xml/NodeSeqTest.scala new file mode 100644 index 000000000..7f80b3366 --- /dev/null +++ b/shared/src/test/scala/scala/xml/NodeSeqTest.scala @@ -0,0 +1,105 @@ +package scala.xml + +import scala.xml.NodeSeq.seqToNodeSeq + +import org.junit.Test +import org.junit.Assert.assertEquals +import org.junit.Assert.fail + +class NodeSeqTest { + + @Test + def testAppend: Unit = { // Bug #392. + val a: NodeSeq = Hello + val b = Hi + a ++ Hi match { + case res: NodeSeq => assertEquals(2, res.size.toLong) + case res: Seq[Node] => fail("Should be NodeSeq was Seq[Node]") // Unreachable code? + } + val res: NodeSeq = a ++ b + val exp = NodeSeq.fromSeq(Seq(Hello, Hi)) + assertEquals(exp, res) + } + + @Test + def testAppendedAll: Unit = { // Bug #392. + val a: NodeSeq = Hello + val b = Hi + a :+ Hi match { + case res: Seq[Node] => assertEquals(2, res.size.toLong) + case res: NodeSeq => fail("Should be Seq[Node] was NodeSeq") // Unreachable code? + } + val res: NodeSeq = a :+ b + val exp = NodeSeq.fromSeq(Seq(Hello, Hi)) + assertEquals(exp, res) + } + + @Test + def testPrepended: Unit = { + val a: NodeSeq = Hello + val b = Hi + a +: Hi match { + case res: Seq[Node] => assertEquals(2, res.size.toLong) + case res: NodeSeq => fail("Should be Seq[Node] was NodeSeq") // Unreachable code? + } + val res: Seq[NodeSeq] = a +: b + val exp: NodeBuffer = { + HelloHi + } + assertEquals(exp, res) + } + + @Test + def testPrependedAll: Unit = { + val a: NodeSeq = Hello + val b = Hi + val c = Hey + a ++: Hi ++: Hey match { + case res: Seq[Node] => assertEquals(3, res.size.toLong) + case res: NodeSeq => fail("Should be Seq[Node] was NodeSeq") // Unreachable code? + } + val res: NodeSeq = a ++: b ++: c + val exp = NodeSeq.fromSeq(Seq(Hello, Hi, Hey)) + assertEquals(exp, res) + } + + @Test + def testMap: Unit = { + val a: NodeSeq = Hello + val exp: NodeSeq = Seq(Hi) + assertEquals(exp, a.map(_ => Hi)) + assertEquals(exp, for { _ <- a } yield { Hi }) + } + + @Test + def testFlatMap: Unit = { + val a: NodeSeq = Hello + val exp: NodeSeq = Seq(Hi) + assertEquals(exp, a.flatMap(_ => Seq(Hi))) + assertEquals(exp, for { b <- a; _ <- b } yield { Hi }) + assertEquals(exp, for { b <- a; c <- b; _ <- c } yield { Hi }) + } + + @Test + def testStringProjection: Unit = { + val a = + + b + + + e + e + + c + + + val res = for { + b <- a \ "b" + c <- b.child + e <- (c \ "e").headOption + } yield { + e.text.trim + } + assertEquals(Seq("e"), res) + } +}