
Commit 6064831

Adding support for Pivot with values (#642)

1 parent: c80b99f

File tree

6 files changed (+71 -0 lines)


src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/DataFrameTests.cs

Lines changed: 5 additions & 0 deletions
@@ -680,10 +680,15 @@ public void TestSignaturesV2_4_X()
 
             {
                 RelationalGroupedDataset df = _df.GroupBy("name");
+                var values = new List<object> { 19, "twenty" };
 
                 Assert.IsType<RelationalGroupedDataset>(df.Pivot("age"));
 
                 Assert.IsType<RelationalGroupedDataset>(df.Pivot(Col("age")));
+
+                Assert.IsType<RelationalGroupedDataset>(df.Pivot("age", values));
+
+                Assert.IsType<RelationalGroupedDataset>(df.Pivot(Col("age"), values));
             }
         }
     }
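
For context, here is a minimal usage sketch of the overloads these asserts exercise, assuming a DataFrame `df` with "name" and "age" columns as in the E2E test data; `Count()` is just a stand-in aggregation:

    using System.Collections.Generic;
    using Microsoft.Spark.Sql;

    // Minimal sketch, not part of the commit. Only the pivot values listed in
    // `values` become output columns; each cell holds the row count for that
    // (name, age) pair. Rows whose age matches none of the values are dropped.
    var values = new List<object> { 19, "twenty" };

    DataFrame pivoted = df
        .GroupBy("name")
        .Pivot("age", values)
        .Count();

    pivoted.Show();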

src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs

Lines changed: 26 additions & 0 deletions
@@ -31,6 +31,7 @@ internal class PayloadHelper
         private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' };
         private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' };
         private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' };
+        private static readonly byte[] s_objectArrTypeId = new[] { (byte)'O' };
 
         private static readonly ConcurrentDictionary<Type, bool> s_isDictionaryTable =
             new ConcurrentDictionary<Type, bool>();

@@ -218,6 +219,26 @@ internal static void ConvertArgsToBytes(
                         destination.Position = posAfterEnumerable;
                         break;
 
+                    case IEnumerable<object> argObjectEnumerable:
+                        posBeforeEnumerable = destination.Position;
+                        destination.Position += sizeof(int);
+                        itemCount = 0;
+                        if (convertArgs == null)
+                        {
+                            convertArgs = new object[1];
+                        }
+                        foreach (object o in argObjectEnumerable)
+                        {
+                            ++itemCount;
+                            convertArgs[0] = o;
+                            ConvertArgsToBytes(destination, convertArgs, true);
+                        }
+                        posAfterEnumerable = destination.Position;
+                        destination.Position = posBeforeEnumerable;
+                        SerDe.Write(destination, itemCount);
+                        destination.Position = posAfterEnumerable;
+                        break;
+
                     case var _ when IsDictionary(arg.GetType()):
                         // Generic dictionary, but we don't have it strongly typed as
                         // Dictionary<T,U>

@@ -333,6 +354,11 @@ internal static byte[] GetTypeId(Type type)
                 return s_rowArrTypeId;
             }
 
+            if (typeof(IEnumerable<object>).IsAssignableFrom(type))
+            {
+                return s_objectArrTypeId;
+            }
+
            if (typeof(Date).IsAssignableFrom(type))
            {
                return s_dateTypeId;
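
A note on the new `IEnumerable<object>` case: it reserves four bytes, recursively serializes each element (each element writes its own type tag), then seeks back to patch in the count, the same back-patching pattern the existing Row-array case uses. A stripped-down, self-contained sketch of that idiom (hypothetical helper names, not the real PayloadHelper API):

    using System;
    using System.Collections.Generic;
    using System.IO;

    static class ObjectArraySketch
    {
        // Writes [int32 count][element #1]...[element #count], back-patching
        // the count once the elements are known. The JVM side reads the count
        // with Java's DataInputStream, which is big-endian, so it is written
        // big-endian here.
        public static void WriteObjectArray(
            Stream dest, IEnumerable<object> items, Action<Stream, object> writeElement)
        {
            long posBefore = dest.Position;
            dest.Position += sizeof(int);   // reserve room for the count

            int count = 0;
            foreach (object o in items)
            {
                ++count;
                writeElement(dest, o);      // element is self-describing (type tag + payload)
            }

            long posAfter = dest.Position;
            dest.Position = posBefore;      // seek back and patch the count
            dest.WriteByte((byte)(count >> 24));
            dest.WriteByte((byte)(count >> 16));
            dest.WriteByte((byte)(count >> 8));
            dest.WriteByte((byte)count);
            dest.Position = posAfter;
        }
    }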

src/csharp/Microsoft.Spark/Sql/RelationalGroupedDataset.cs

Lines changed: 22 additions & 0 deletions
@@ -94,6 +94,17 @@ public RelationalGroupedDataset Pivot(string pivotColumn) =>
             new RelationalGroupedDataset(
                 (JvmObjectReference)_jvmObject.Invoke("pivot", pivotColumn), _dataFrame);
 
+        /// <summary>
+        /// Pivots a column of the current DataFrame and performs the specified aggregation.
+        /// </summary>
+        /// <param name="pivotColumn">Name of the column to pivot of type string</param>
+        /// <param name="values">List of values that will be translated to columns in the
+        /// output DataFrame.</param>
+        /// <returns>New RelationalGroupedDataset object with pivot applied</returns>
+        public RelationalGroupedDataset Pivot(string pivotColumn, IEnumerable<object> values) =>
+            new RelationalGroupedDataset(
+                (JvmObjectReference)_jvmObject.Invoke("pivot", pivotColumn, values), _dataFrame);
+
         /// <summary>
         /// Pivots a column of the current DataFrame and performs the specified aggregation.
         /// </summary>

@@ -103,6 +114,17 @@ public RelationalGroupedDataset Pivot(Column pivotColumn) =>
             new RelationalGroupedDataset(
                 (JvmObjectReference)_jvmObject.Invoke("pivot", pivotColumn), _dataFrame);
 
+        /// <summary>
+        /// Pivots a column of the current DataFrame and performs the specified aggregation.
+        /// </summary>
+        /// <param name="pivotColumn">The column to pivot of type <see cref="Column"/></param>
+        /// <param name="values">List of values that will be translated to columns in the
+        /// output DataFrame.</param>
+        /// <returns>New RelationalGroupedDataset object with pivot applied</returns>
+        public RelationalGroupedDataset Pivot(Column pivotColumn, IEnumerable<object> values) =>
+            new RelationalGroupedDataset(
+                (JvmObjectReference)_jvmObject.Invoke("pivot", pivotColumn, values), _dataFrame);
+
         internal DataFrame Apply(StructType returnType, Func<FxDataFrame, FxDataFrame> func)
         {
             DataFrameGroupedMapWorkerFunction.ExecuteDelegate wrapper =
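
Design note: as with Spark's Scala `RelationalGroupedDataset.pivot(pivotColumn, values)`, supplying the values up front lets Spark skip the extra pass it would otherwise run to collect the distinct values of the pivot column, and it fixes the output schema even when some listed values are absent from the data. A quick sketch (the "salary" column and `Sum` aggregation are illustrative, not from this commit):

    // Explicit pivot values: no distinct-value scan, stable output schema.
    DataFrame result = df
        .GroupBy("name")
        .Pivot(Col("age"), new List<object> { 19, "twenty" })
        .Agg(Sum(Col("salary")));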

src/scala/microsoft-spark-2.3.x/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala

Lines changed: 6 additions & 0 deletions
@@ -42,6 +42,7 @@ object SerDe {
       case 't' => readTime(dis)
       case 'j' => JVMObjectTracker.getObject(readString(dis))
       case 'R' => readRowArr(dis)
+      case 'O' => readObjectArr(dis)
       case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
     }
   }

@@ -138,6 +139,11 @@ object SerDe {
     (0 until len).map(_ => readRow(in)).toList.asJava
   }
 
+  def readObjectArr(in: DataInputStream): Seq[Any] = {
+    val len = readInt(in)
+    (0 until len).map(_ => readObject(in))
+  }
+
   def readList(dis: DataInputStream): Array[_] = {
     val arrType = readObjectType(dis)
     arrType match {
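
On the wire, an 'O' payload is the big-endian int32 count written by PayloadHelper followed by that many self-describing objects, which is exactly what readObjectArr consumes; this commit only adds the C#-to-JVM direction. The same readObjectArr is added verbatim to the 2.4.x and 3.0.x SerDe files below. Purely for illustration, a C# mirror of the reader (hypothetical `readObject` callback; same usings as the sketch above):

    // Illustrative C# equivalent of readObjectArr; not part of the commit.
    // Wire format: [int32 count, big-endian][object #1]...[object #count]
    static object[] ReadObjectArray(Stream input, Func<Stream, object> readObject)
    {
        int len = ReadInt32BigEndian(input);    // matches Java's readInt
        var result = new object[len];
        for (int i = 0; i < len; ++i)
        {
            result[i] = readObject(input);      // each object carries its own type tag
        }
        return result;
    }

    static int ReadInt32BigEndian(Stream s)
    {
        int b0 = s.ReadByte(), b1 = s.ReadByte(), b2 = s.ReadByte(), b3 = s.ReadByte();
        if (b3 < 0)
        {
            throw new EndOfStreamException();
        }
        return (b0 << 24) | (b1 << 16) | (b2 << 8) | b3;
    }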

src/scala/microsoft-spark-2.4.x/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala

Lines changed: 6 additions & 0 deletions
@@ -42,6 +42,7 @@ object SerDe {
       case 't' => readTime(dis)
       case 'j' => JVMObjectTracker.getObject(readString(dis))
       case 'R' => readRowArr(dis)
+      case 'O' => readObjectArr(dis)
       case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
     }
   }

@@ -138,6 +139,11 @@ object SerDe {
     (0 until len).map(_ => readRow(in)).toList.asJava
   }
 
+  def readObjectArr(in: DataInputStream): Seq[Any] = {
+    val len = readInt(in)
+    (0 until len).map(_ => readObject(in))
+  }
+
   def readList(dis: DataInputStream): Array[_] = {
     val arrType = readObjectType(dis)
     arrType match {

src/scala/microsoft-spark-3.0.x/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala

Lines changed: 6 additions & 0 deletions
@@ -42,6 +42,7 @@ object SerDe {
       case 't' => readTime(dis)
       case 'j' => JVMObjectTracker.getObject(readString(dis))
       case 'R' => readRowArr(dis)
+      case 'O' => readObjectArr(dis)
       case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
     }
   }

@@ -138,6 +139,11 @@ object SerDe {
     (0 until len).map(_ => readRow(in)).toList.asJava
   }
 
+  def readObjectArr(in: DataInputStream): Seq[Any] = {
+    val len = readInt(in)
+    (0 until len).map(_ => readObject(in))
+  }
+
   def readList(dis: DataInputStream): Array[_] = {
     val arrType = readObjectType(dis)
     arrType match {
