/*
 * Copyright (c) 2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.nvidia.spark.rapids.iceberg

import java.lang.Math.toIntExact

import scala.collection.JavaConverters._

import ai.rapids.cudf.{ColumnVector => CudfColumnVector, OrderByArg, Scalar, Table}
import com.nvidia.spark.rapids.{GpuBoundReference, GpuColumnVector, GpuExpression, GpuLiteral, RapidsHostColumnVector, SpillableColumnarBatch, SpillPriorities}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
import com.nvidia.spark.rapids.SpillPriorities.ACTIVE_ON_DECK_PRIORITY
import com.nvidia.spark.rapids.iceberg.GpuIcebergPartitioner.toPartitionKeys
import org.apache.iceberg.{PartitionField, PartitionSpec, Schema, StructLike}
import org.apache.iceberg.spark.{GpuTypeToSparkType, SparkStructLike}
import org.apache.iceberg.spark.functions.GpuBucketExpression
import org.apache.iceberg.types.Types

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.expressions.NamedExpression.newExprId
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

/**
 * A GPU-based Iceberg partitioner that splits an input columnar batch into multiple
 * columnar batches, one per partition key, according to the given partition spec.
 *
 * @param spec the Iceberg partition spec
 * @param dataSparkType the Spark schema of the input data
 */
class GpuIcebergPartitioner(val spec: PartitionSpec,
                            val dataSparkType: StructType) {
  require(spec.isPartitioned, "Should not create a partitioner for unpartitioned table")
  private val inputSchema: Schema = spec.schema()
  private val sparkType: Array[DataType] = dataSparkType.fields.map(_.dataType)
  private val partitionSparkType: StructType = GpuTypeToSparkType.toSparkType(spec.partitionType())

  private val partitionExprs: Seq[GpuExpression] = spec.fields().asScala.map(getPartitionExpr).toSeq

  private val keyColNum: Int = spec.fields().size()
  private val keyColIndices: Array[Int] = (0 until keyColNum).toArray
  private val keySortOrders: Array[OrderByArg] = (0 until keyColNum)
    .map(OrderByArg.asc(_, true))
    .toArray

  /**
   * Partitions the `input` columnar batch using Iceberg's partition spec, returning one
   * batch per distinct partition key.
   * <br/>
   * This method takes ownership of the input columnar batch, which must not be used after
   * this call.
   */
  def partition(input: ColumnarBatch): Seq[ColumnarBatchWithPartition] = {
    if (input.numRows() == 0) {
      return Seq.empty
    }

    val numRows = input.numRows()

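    // Make the input spillable so the retry framework can release GPU memory and
    // re-materialize the batch if an allocation fails below.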
    val spillableInput = closeOnExcept(input) { _ =>
      SpillableColumnarBatch(input, ACTIVE_ON_DECK_PRIORITY)
    }

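    // Plan: evaluate the partition transforms to get per-row keys, sort rows by key,
    // find the boundary of each distinct key, and slice the sorted data at those
    // boundaries so each output batch holds exactly one partition's rows.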
    val (partitionKeys, partitions) = withRetryNoSplit(spillableInput) { scb =>
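      // Evaluate each partition transform expression against the input batch,
      // yielding one key column per partition field.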
      val parts = withResource(scb.getColumnarBatch()) { inputBatch =>
        partitionExprs.safeMap(_.columnarEval(inputBatch))
      }
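      // Collect the key columns into a cudf Table; the Table constructor bumps each
      // column's ref count, so the evaluated vectors can be closed afterwards.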
      val keysTable = withResource(parts) { _ =>
        val arr = new Array[CudfColumnVector](partitionExprs.size)
        for (i <- partitionExprs.indices) {
          arr(i) = parts(i).getBase
        }
        new Table(arr: _*)
      }

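      // Append a row-index column and sort by the key columns (nulls first) so rows
      // with the same partition key become contiguous; the index column records each
      // row's original position for the gather below.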
      val sortedKeyTableWithRowIdx = withResource(keysTable) { _ =>
        withResource(Scalar.fromInt(0)) { zero =>
          withResource(CudfColumnVector.sequence(zero, numRows)) { rowIdxCol =>
            val totalColCount = keysTable.getNumberOfColumns + 1
            val allCols = new Array[CudfColumnVector](totalColCount)

            for (i <- 0 until keysTable.getNumberOfColumns) {
              allCols(i) = keysTable.getColumn(i)
            }
            allCols(keysTable.getNumberOfColumns) = rowIdxCol

            withResource(new Table(allCols: _*)) { allColsTable =>
              allColsTable.orderBy(keySortOrders: _*)
            }
          }
        }
      }

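      // Deduplicate the sorted keys: group by all key columns with no aggregations to
      // get one row per distinct partition key, then sort those keys the same way.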
      val (sortedPartitionKeys, splitIds, rowIdxCol) = withResource(sortedKeyTableWithRowIdx) { _ =>
        val uniqueKeysTable = sortedKeyTableWithRowIdx.groupBy(keyColIndices: _*)
          .aggregate()

        val sortedUniqueKeysTable = withResource(uniqueKeysTable) { _ =>
          uniqueKeysTable.orderBy(keySortOrders: _*)
        }

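        // Convert each distinct key row into an Iceberg partition key, and use
        // upperBound to find the end offset of each key's run in the sorted table;
        // those offsets are the split points between partitions.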
        val (sortedPartitionKeys, splitIds) = withResource(sortedUniqueKeysTable) { _ =>
          val partitionKeys = toPartitionKeys(spec.partitionType(),
            partitionSparkType,
            sortedUniqueKeysTable)

          val splitIdsCv = sortedKeyTableWithRowIdx.upperBound(
            sortedUniqueKeysTable,
            keySortOrders: _*)

          val splitIds = withResource(splitIdsCv) { _ =>
            GpuColumnVector.toIntArray(splitIdsCv)
          }

          (partitionKeys, splitIds)
        }

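        // The last column holds the original row indices; bump its ref count so it
        // outlives the enclosing withResource on sortedKeyTableWithRowIdx.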
        val rowIdxCol = sortedKeyTableWithRowIdx.getColumn(keyColNum).incRefCount()
        (sortedPartitionKeys, splitIds, rowIdxCol)
      }

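      // Gather the input rows into key-sorted order, then cut the result at the split
      // points, producing one contiguous buffer per distinct partition key.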
      withResource(rowIdxCol) { _ =>
        val inputTable = withResource(scb.getColumnarBatch()) { inputBatch =>
          GpuColumnVector.from(inputBatch)
        }

        val sortedDataTable = withResource(inputTable) { _ =>
          inputTable.gather(rowIdxCol)
        }

        val partitions = withResource(sortedDataTable) { _ =>
          sortedDataTable.contiguousSplit(splitIds: _*)
        }

        (sortedPartitionKeys, partitions)
      }
    }

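    // Pair each contiguous slice with its partition key, wrapping the slice as a
    // spillable batch so it need not stay resident on the GPU until it is consumed.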
    withResource(partitions) { _ =>
      partitionKeys.zip(partitions).map { case (partKey, partition) =>
        ColumnarBatchWithPartition(SpillableColumnarBatch(partition, sparkType,
          SpillPriorities.ACTIVE_BATCHING_PRIORITY), partKey)
      }.toSeq
    }
  }

  private def getPartitionExpr(field: PartitionField): GpuExpression = {
    val transform = field.transform()
    val inputIndex = fieldIndex(inputSchema, field.sourceId())
    val sparkField = dataSparkType.fields(inputIndex)
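    // Build a bound reference to the transform's source column so the expression can
    // be evaluated directly against the input batch by ordinal.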
    val inputRefExpr = GpuBoundReference(inputIndex, sparkField.dataType,
      sparkField.nullable)(newExprId, s"input $inputIndex")

    transform.toString match {
      // a bucket transform's string form is like "bucket[16]"
      case s if s.startsWith("bucket") =>
        val bucket = s.substring("bucket[".length, s.length - 1).toInt
        GpuBucketExpression(GpuLiteral.create(bucket), inputRefExpr)
      case other =>
        throw new IllegalArgumentException(s"Unsupported transform: $other")
    }
  }
}

case class ColumnarBatchWithPartition(batch: SpillableColumnarBatch, partition: StructLike)
  extends AutoCloseable {
  override def close(): Unit = {
    batch.close()
  }
}

object GpuIcebergPartitioner {

  private def toPartitionKeys(icebergType: Types.StructType,
      sparkType: StructType,
      table: Table): Array[SparkStructLike] = {
    val numCols = table.getNumberOfColumns
    val numRows = toIntExact(table.getRowCount)

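    // Copy the key columns to the host; this table holds one row per distinct
    // partition key, so it is small and cheap to iterate row by row.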
    val hostColsArray = closeOnExcept(new Array[ColumnVector](numCols)) { hostCols =>
      for (colIdx <- 0 until numCols) {
        hostCols(colIdx) = new RapidsHostColumnVector(sparkType.fields(colIdx).dataType,
          table.getColumn(colIdx).copyToHost())
      }
      hostCols
    }

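    // Iterate the host rows, converting each into a Spark row and wrapping it as an
    // Iceberg StructLike to serve as the partition key.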
    withResource(new ColumnarBatch(hostColsArray, numRows)) { hostBatch =>
      hostBatch.rowIterator()
        .asScala
        .map { internalRow =>
          val row = new GenericRowWithSchema(internalRow.toSeq(sparkType).toArray, sparkType)
          new SparkStructLike(icebergType).wrap(row)
        }.toArray
    }
  }
}