Commit 5d88f6a

Add gpu iceberg partitioner (#13415)
Fixes #13377.

### Description

This PR adds a GPU implementation of the Iceberg partitioner.

### Checklists

- [x] This PR has added documentation for new or modified features or behaviors.
- [x] This PR has added new tests or modified existing tests to cover new code paths. (Please explain in the PR description how the new code paths are tested, such as names of the new/existing tests that cover them.)
- [ ] Performance testing has been performed and its results are added in the PR description. Or, an issue has been filed with a link in the PR description.

---------

Signed-off-by: Ray Liu <[email protected]>
1 parent a56530b commit 5d88f6a

File tree

3 files changed (+393, -1 lines)
iceberg/src/main/scala/com/nvidia/spark/rapids/iceberg/GpuIcebergPartitioner.scala

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@

```scala
/*
 * Copyright (c) 2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.nvidia.spark.rapids.iceberg

import java.lang.Math.toIntExact

import scala.collection.JavaConverters._

import ai.rapids.cudf.{ColumnVector => CudfColumnVector, OrderByArg, Scalar, Table}
import com.nvidia.spark.rapids.{GpuBoundReference, GpuColumnVector, GpuExpression, GpuLiteral, RapidsHostColumnVector, SpillableColumnarBatch, SpillPriorities}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
import com.nvidia.spark.rapids.SpillPriorities.ACTIVE_ON_DECK_PRIORITY
import com.nvidia.spark.rapids.iceberg.GpuIcebergPartitioner.toPartitionKeys
import org.apache.iceberg.{PartitionField, PartitionSpec, Schema, StructLike}
import org.apache.iceberg.spark.{GpuTypeToSparkType, SparkStructLike}
import org.apache.iceberg.spark.functions.GpuBucketExpression
import org.apache.iceberg.types.Types

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.expressions.NamedExpression.newExprId
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

/**
 * A GPU based Iceberg partitioner that partitions the input columnar batch into multiple
 * columnar batches based on the given partition spec.
 *
 * @param spec the iceberg partition spec
 * @param dataSparkType the spark schema of the input data
 */
class GpuIcebergPartitioner(val spec: PartitionSpec,
    val dataSparkType: StructType) {
  require(spec.isPartitioned, "Should not create a partitioner for unpartitioned table")
  private val inputSchema: Schema = spec.schema()
  private val sparkType: Array[DataType] = dataSparkType.fields.map(_.dataType)
  private val partitionSparkType: StructType = GpuTypeToSparkType.toSparkType(spec.partitionType())

  private val partitionExprs: Seq[GpuExpression] = spec.fields().asScala.map(getPartitionExpr).toSeq

  private val keyColNum: Int = spec.fields().size()
  private val keyColIndices: Array[Int] = (0 until keyColNum).toArray
  private val keySortOrders: Array[OrderByArg] = (0 until keyColNum)
    .map(OrderByArg.asc(_, true))
    .toArray

  /**
   * Partition the `input` columnar batch using iceberg's partition spec.
   * <br/>
   * This method takes the ownership of the input columnar batch, and it should not be used after
   * this call.
   */
  def partition(input: ColumnarBatch): Seq[ColumnarBatchWithPartition] = {
    if (input.numRows() == 0) {
      return Seq.empty
    }

    val numRows = input.numRows()

    val spillableInput = closeOnExcept(input) { _ =>
      SpillableColumnarBatch(input, ACTIVE_ON_DECK_PRIORITY)
    }

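    // Evaluate each partition transform expression against the input batch to
    // produce the table of partition key columns.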
    val (partitionKeys, partitions) = withRetryNoSplit(spillableInput) { scb =>
      val parts = withResource(scb.getColumnarBatch()) { inputBatch =>
        partitionExprs.safeMap(_.columnarEval(inputBatch))
      }
      val keysTable = withResource(parts) { _ =>
        val arr = new Array[CudfColumnVector](partitionExprs.size)
        for (i <- partitionExprs.indices) {
          arr(i) = parts(i).getBase
        }
        new Table(arr: _*)
      }

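      // Append a row-index column to the keys table and sort by all key columns so
      // that rows sharing a partition key become contiguous.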
      val sortedKeyTableWithRowIdx = withResource(keysTable) { _ =>
        withResource(Scalar.fromInt(0)) { zero =>
          withResource(CudfColumnVector.sequence(zero, numRows)) { rowIdxCol =>
            val totalColCount = keysTable.getNumberOfColumns + 1
            val allCols = new Array[CudfColumnVector](totalColCount)

            for (i <- 0 until keysTable.getNumberOfColumns) {
              allCols(i) = keysTable.getColumn(i)
            }
            allCols(keysTable.getNumberOfColumns) = rowIdxCol

            withResource(new Table(allCols: _*)) { allColsTable =>
              allColsTable.orderBy(keySortOrders: _*)
            }
          }
        }
      }

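      // Deduplicate the sorted keys to get one row per partition, convert those rows
      // to Iceberg partition keys, and use upperBound to find each partition's end
      // offset in the sorted table; those offsets become the split points.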
      val (sortedPartitionKeys, splitIds, rowIdxCol) = withResource(sortedKeyTableWithRowIdx) { _ =>
        val uniqueKeysTable = sortedKeyTableWithRowIdx.groupBy(keyColIndices: _*)
          .aggregate()

        val sortedUniqueKeysTable = withResource(uniqueKeysTable) { _ =>
          uniqueKeysTable.orderBy(keySortOrders: _*)
        }

        val (sortedPartitionKeys, splitIds) = withResource(sortedUniqueKeysTable) { _ =>
          val partitionKeys = toPartitionKeys(spec.partitionType(),
            partitionSparkType,
            sortedUniqueKeysTable)

          val splitIdsCv = sortedKeyTableWithRowIdx.upperBound(
            sortedUniqueKeysTable,
            keySortOrders: _*)

          val splitIds = withResource(splitIdsCv) { _ =>
            GpuColumnVector.toIntArray(splitIdsCv)
          }

          (partitionKeys, splitIds)
        }

        val rowIdxCol = sortedKeyTableWithRowIdx.getColumn(keyColNum).incRefCount()
        (sortedPartitionKeys, splitIds, rowIdxCol)
      }

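      // Gather the input rows into key order using the row-index column, then cut
      // the sorted data into one contiguous buffer per partition.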
      withResource(rowIdxCol) { _ =>
        val inputTable = withResource(scb.getColumnarBatch()) { inputBatch =>
          GpuColumnVector.from(inputBatch)
        }

        val sortedDataTable = withResource(inputTable) { _ =>
          inputTable.gather(rowIdxCol)
        }

        val partitions = withResource(sortedDataTable) { _ =>
          sortedDataTable.contiguousSplit(splitIds: _*)
        }

        (sortedPartitionKeys, partitions)
      }
    }

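    // Pair each partition key with its contiguous batch, wrapping the batch as
    // spillable so it can be evicted from GPU memory under pressure.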
    withResource(partitions) { _ =>
      partitionKeys.zip(partitions).map { case (partKey, partition) =>
        ColumnarBatchWithPartition(SpillableColumnarBatch(partition, sparkType, SpillPriorities
          .ACTIVE_BATCHING_PRIORITY), partKey)
      }.toSeq
    }

  }

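  // Map an Iceberg partition field to a GPU expression bound to its source column
  // in the input schema.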
  private def getPartitionExpr(field: PartitionField): GpuExpression = {
    val transform = field.transform()
    val inputIndex = fieldIndex(inputSchema, field.sourceId())
    val sparkField = dataSparkType.fields(inputIndex)
    val inputRefExpr = GpuBoundReference(inputIndex, sparkField.dataType,
      sparkField.nullable)(newExprId, s"input$inputIndex")

    transform.toString match {
      // bucket transform is like "bucket[16]"
      case s if s.startsWith("bucket") =>
        val bucket = s.substring("bucket[".length, s.length - 1).toInt
        GpuBucketExpression(GpuLiteral.create(bucket), inputRefExpr)
      case other =>
        throw new IllegalArgumentException(s"Unsupported transform: $other")
    }
  }
}

case class ColumnarBatchWithPartition(batch: SpillableColumnarBatch, partition: StructLike)
  extends AutoCloseable {
  override def close(): Unit = {
    batch.close()
  }
}

object GpuIcebergPartitioner {

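  // Copy the unique partition key columns back to the host and wrap each row as an
  // Iceberg StructLike partition key.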
  private def toPartitionKeys(icebergType: Types.StructType,
      sparkType: StructType,
      table: Table): Array[SparkStructLike] = {
    val numCols = table.getNumberOfColumns
    val numRows = toIntExact(table.getRowCount)

    val hostColsArray = closeOnExcept(new Array[ColumnVector](numCols)) { hostCols =>
      for (colIdx <- 0 until numCols) {
        hostCols(colIdx) = new RapidsHostColumnVector(sparkType.fields(colIdx).dataType,
          table.getColumn(colIdx).copyToHost())
      }
      hostCols
    }

    withResource(new ColumnarBatch(hostColsArray, numRows)) { hostBatch =>
      hostBatch.rowIterator()
        .asScala
        .map(internalRow => {
          val row = new GenericRowWithSchema(internalRow.toSeq(sparkType).toArray, sparkType)
          new SparkStructLike(icebergType).wrap(row)
        }).toArray
    }
  }
}
```
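
The shape of the API can be seen from a minimal usage sketch. This is illustrative only, not code from the PR: `spec`, `dataSparkType`, `batch`, and `writePartition` are hypothetical placeholders for whatever the real write path supplies.

```scala
// Hypothetical caller; assumes `spec: PartitionSpec` (partitioned, bucket transforms
// only), `dataSparkType: StructType` matching the batch schema, and
// `batch: ColumnarBatch` holding GPU columns.
val partitioner = new GpuIcebergPartitioner(spec, dataSparkType)

// `partition` takes ownership of `batch`, so the caller must not touch it afterwards.
val partitioned = partitioner.partition(batch)
try {
  partitioned.foreach { p =>
    // p.partition is the Iceberg partition key; p.batch holds only that key's rows.
    writePartition(p.partition, p.batch) // hypothetical per-partition writer hook
  }
} finally {
  partitioned.foreach(_.close())
}
```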

iceberg/src/main/scala/org/apache/iceberg/spark/GpuTypeToSparkType.scala

Lines changed: 5 additions & 1 deletion
```diff
@@ -17,12 +17,16 @@
 package org.apache.iceberg.spark
 
 import org.apache.iceberg.Schema
-import org.apache.iceberg.types.TypeUtil
+import org.apache.iceberg.types.{Types, TypeUtil}
 
 import org.apache.spark.sql.types.StructType
 
 object GpuTypeToSparkType {
   def toSparkType(schema: Schema): StructType = {
     TypeUtil.visit(schema, new TypeToSparkType).asInstanceOf[StructType]
   }
+
+  def toSparkType(icebergStruct: Types.StructType): StructType = {
+    TypeUtil.visit(icebergStruct, new TypeToSparkType).asInstanceOf[StructType]
+  }
 }
```
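
The new overload mirrors the existing `Schema` entry point but accepts a bare Iceberg struct; the partitioner above uses it to derive the Spark schema of the partition keys. A one-line sketch, assuming a `spec: PartitionSpec` in scope:

```scala
// Convert an Iceberg partition struct type to the equivalent Spark StructType.
val partitionSparkType: StructType = GpuTypeToSparkType.toSparkType(spec.partitionType())
```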
