Commit cedefd7

add pickling for GenericRow
1 parent fee85b8 commit cedefd7

File tree

2 files changed: +65, -10 lines

src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfComplexTypesTests.cs

Lines changed: 43 additions & 8 deletions
@@ -158,22 +158,57 @@ public void TestUdfWithRowType()
         [Fact]
         public void TestUdfWithReturnAsRowType()
         {
-            var schema = new StructType(new[]
+            // Single GenericRow
+            var schema1 = new StructType(new[]
             {
                 new StructField("col1", new IntegerType()),
                 new StructField("col2", new StringType())
             });
-            Func<Column, Column> udf = Udf<string>(
-                str => new GenericRow(new object[] { 1, "abc" }), schema);
+            Func<Column, Column> udf1 = Udf<string>(
+                str => new GenericRow(new object[] { 1, "abc" }), schema1);
 
-            Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();
-            Assert.Equal(3, rows.Length);
+            Row[] rows1 = _df.Select(udf1(_df["name"])).Collect().ToArray();
+            Assert.Equal(3, rows1.Length);
+
+            foreach (Row row in rows1)
+            {
+                Assert.Equal(2, row.Size());
+                Assert.Equal(1, row.GetAs<int>("col1"));
+                Assert.Equal("abc", row.GetAs<string>("col2"));
+            }
 
-            foreach (Row row in rows)
+            // Nested GenericRow
+            var subSchema1 = new StructType(new[]
+            {
+                new StructField("subCol1", new IntegerType())
+            });
+            var subSchema2 = new StructType(new[]
+            {
+                new StructField("subCol2", new StringType())
+            });
+            var schema2 = new StructType(new[]
+            {
+                new StructField("col1", subSchema1),
+                new StructField("col2", subSchema2)
+            });
+            Func<Column, Column> udf2 = Udf<string>(
+                str => new GenericRow(
+                    new object[]
+                    {
+                        new GenericRow(new object[] { 1 }),
+                        new GenericRow(new object[] { "abc" })
+                    }), schema2);
+
+            Row[] rows2 = _df.Select(udf2(_df["name"])).Collect().ToArray();
+            Assert.Equal(3, rows2.Length);
+
+            foreach (Row row in rows2)
             {
                 Assert.Equal(2, row.Size());
-                Assert.Equal(1, row[0]);
-                Assert.Equal("abc", row[1]);
+                Assert.IsType<Row>(row.Get("col1"));
+                Assert.IsType<Row>(row.Get("col2"));
+                Assert.Equal(new Row(new object[] { 1 }, subSchema1), row.GetAs<Row>("col1"));
+                Assert.Equal(new Row(new object[] { "abc" }, subSchema2), row.GetAs<Row>("col2"));
             }
         }
     }
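
Note (illustrative sketch, not part of this commit): the updated test exercises two shapes, a flat GenericRow and a GenericRow whose columns are themselves GenericRows. A minimal application-side use of the flat pattern could look like the following sketch; it assumes a DataFrame named df with a string column "name" and relies on the same Udf overload the test uses (one that accepts a StructType return schema).

    using System;
    using Microsoft.Spark.Sql;
    using Microsoft.Spark.Sql.Types;
    using static Microsoft.Spark.Sql.Functions;

    internal static class GenericRowUdfSketch
    {
        // Illustrative only: pack each input string into a two-column struct
        // by returning a GenericRow that matches an explicit schema.
        internal static void Run(DataFrame df)
        {
            var schema = new StructType(new[]
            {
                new StructField("col1", new IntegerType()),
                new StructField("col2", new StringType())
            });

            Func<Column, Column> packUdf = Udf<string>(
                str => new GenericRow(new object[] { str.Length, str }), schema);

            // Each resulting value is a struct with an int "col1" and a string "col2".
            df.Select(packUdf(df["name"]).Alias("packed")).Show();
        }
    }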

src/csharp/Microsoft.Spark.Worker/Command/SqlCommandExecutor.cs

Lines changed: 22 additions & 2 deletions
@@ -165,13 +165,33 @@ private void WriteOutput(Stream stream, IEnumerable<object> rows, int sizeHint)
             if (s_outputBuffer == null)
                 s_outputBuffer = new byte[sizeHint];
 
+            // Cast GenericRow to object if needed.
+            var newRows = new List<object>();
             if (rows.FirstOrDefault() is GenericRow)
             {
-                rows = rows.Select(r => (object)(r as GenericRow).Values).AsEnumerable();
+                for (int i = 0; i < rows.Count(); ++i)
+                {
+                    object[] cols = (rows.ElementAt(i) as GenericRow).Values;
+                    if (cols.FirstOrDefault() is GenericRow)
+                    {
+                        object[] newCols = new object[cols.Length];
+                        for (int j = 0; j < cols.Length; ++j)
+                        {
+                            newCols[j] = (cols[j] as GenericRow).Values;
+                        }
+                        newRows.Add(newCols);
+                    }
+                    else
+                    {
+                        newRows.Add(cols);
+                    }
+                }
             }
 
             Pickler pickler = s_pickler ?? (s_pickler = new Pickler(false));
-            pickler.dumps(rows, ref s_outputBuffer, out int bytesWritten);
+            pickler.dumps(
+                newRows.Count() == 0 ? rows : newRows,
+                ref s_outputBuffer, out int bytesWritten);
 
             if (bytesWritten <= 0)
             {
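
As a side note (illustrative only, not part of this commit): the conversion above unwraps GenericRow values one nesting level deep before handing them to the Pickler, which is what the new test requires. A recursive helper along the following lines could handle arbitrarily deep nesting; GenericRowConversionSketch and ToPicklableValue are hypothetical names.

    using Microsoft.Spark.Sql;

    internal static class GenericRowConversionSketch
    {
        // Recursively replace every GenericRow with its underlying object[]
        // so that nested rows of any depth become plain arrays the Pickler
        // can serialize.
        internal static object ToPicklableValue(object value)
        {
            if (value is GenericRow genericRow)
            {
                object[] cols = genericRow.Values;
                var converted = new object[cols.Length];
                for (int i = 0; i < cols.Length; ++i)
                {
                    converted[i] = ToPicklableValue(cols[i]);
                }
                return converted;
            }

            return value;
        }
    }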
