Commit a8db985

Follow up on UDF that takes in and returns Row (#406)
1 parent 96d0fed commit a8db985
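
For context, here is a minimal sketch (not part of this commit) of the two patterns the tests below exercise: a UDF that takes a Row-typed (struct) column, and a UDF that returns a Row by pairing a GenericRow with an explicit schema. The column names mirror the test data in this commit; the JSON path and the surrounding program scaffolding are hypothetical.

using System;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using static Microsoft.Spark.Sql.Functions;

class RowUdfSketch
{
    static void Main()
    {
        SparkSession spark = SparkSession.Builder().GetOrCreate();
        // Hypothetical path; any JSON with struct columns like "info1" works.
        DataFrame df = spark.Read().Json("people.json");

        // UDF that takes a Row: read a field out of the struct column "info1".
        Func<Column, Column> takesRow =
            Udf<Row, string>(row => row.GetAs<string>("city"));
        df.Select(takesRow(df["info1"])).Show();

        // UDF that returns a Row: declare the result schema and wrap the
        // values in a GenericRow.
        var schema = new StructType(new[]
        {
            new StructField("col1", new IntegerType()),
            new StructField("col2", new StringType())
        });
        Func<Column, Column> returnsRow = Udf<string>(
            str => new GenericRow(new object[] { 1, str }), schema);
        df.Select(returnsRow(df["name"]).As("col")).Show();
    }
}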

File tree

5 files changed: +101, -52 lines

Lines changed: 3 additions & 3 deletions
@@ -1,3 +1,3 @@
-{"name":"Michael", "ids":[1], "info":{"city":"Burdwan", "state":"Paschimbanga"}}
-{"name":"Andy", "age":30, "ids":[3,5], "info":{"city":"Los Angeles", "state":"California"}}
-{"name":"Justin", "age":19, "ids":[2,4], "info":{"city":"Seattle"}}
+{"name":"Michael", "ids":[1], "info1":{"city":"Burdwan"}, "info2":{"state":"Paschimbanga"}, "info3":{"company":{"job":"Developer"}}}
+{"name":"Andy", "age":30, "ids":[3,5], "info1":{"city":"Los Angeles"}, "info2":{"state":"California"}, "info3":{"company":{"job":"Developer"}}}
+{"name":"Justin", "age":19, "ids":[2,4], "info1":{"city":"Seattle"}, "info2":{"state":"Washington"}, "info3":{"company":{"job":"Developer"}}}

src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfComplexTypesTests.cs

Lines changed: 87 additions & 20 deletions
@@ -136,20 +136,59 @@ public void TestUdfWithReturnAsMapType()
         [Fact]
         public void TestUdfWithRowType()
         {
-            Func<Column, Column> udf = Udf<Row, string>(
-                (row) =>
-                {
-                    string city = row.GetAs<string>("city");
-                    string state = row.GetAs<string>("state");
-                    return $"{city},{state}";
-                });
+            // Single Row
+            {
+                Func<Column, Column> udf = Udf<Row, string>(
+                    (row) => row.GetAs<string>("city"));
 
-            Row[] rows = _df.Select(udf(_df["info"])).Collect().ToArray();
-            Assert.Equal(3, rows.Length);
+                Row[] rows = _df.Select(udf(_df["info1"])).Collect().ToArray();
+                Assert.Equal(3, rows.Length);
 
-            var expected = new[] { "Burdwan,Paschimbanga", "Los Angeles,California", "Seattle," };
-            string[] actual = rows.Select(x => x[0].ToString()).ToArray();
-            Assert.Equal(expected, actual);
+                var expected = new[] { "Burdwan", "Los Angeles", "Seattle" };
+                string[] actual = rows.Select(x => x[0].ToString()).ToArray();
+                Assert.Equal(expected, actual);
+            }
+
+            // Multiple Rows
+            {
+                Func<Column, Column, Column, Column> udf = Udf<Row, Row, string, string>(
+                    (row1, row2, str) =>
+                    {
+                        string city = row1.GetAs<string>("city");
+                        string state = row2.GetAs<string>("state");
+                        return $"{str}:{city},{state}";
+                    });
+
+                Row[] rows = _df
+                    .Select(udf(_df["info1"], _df["info2"], _df["name"]))
+                    .Collect()
+                    .ToArray();
+                Assert.Equal(3, rows.Length);
+
+                var expected = new[] {
+                    "Michael:Burdwan,Paschimbanga",
+                    "Andy:Los Angeles,California",
+                    "Justin:Seattle,Washington" };
+                string[] actual = rows.Select(x => x[0].ToString()).ToArray();
+                Assert.Equal(expected, actual);
+            }
+
+            // Nested Row
+            {
+                Func<Column, Column> udf = Udf<Row, string>(
+                    (row) =>
+                    {
+                        Row outerCol = row.GetAs<Row>("company");
+                        return outerCol.GetAs<string>("job");
+                    });
+
+                Row[] rows = _df.Select(udf(_df["info3"])).Collect().ToArray();
+                Assert.Equal(3, rows.Length);
+
+                var expected = new[] { "Developer", "Developer", "Developer" };
+                string[] actual = rows.Select(x => x[0].ToString()).ToArray();
+                Assert.Equal(expected, actual);
+            }
         }
 
         /// <summary>
@@ -168,14 +207,40 @@ public void TestUdfWithReturnAsRowType()
                 Func<Column, Column> udf = Udf<string>(
                     str => new GenericRow(new object[] { 1, "abc" }), schema);
 
-                Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();
+                Row[] rows = _df.Select(udf(_df["name"]).As("col")).Collect().ToArray();
+                Assert.Equal(3, rows.Length);
+                foreach (Row row in rows)
+                {
+                    Assert.Equal(1, row.Size());
+                    Row outerCol = row.GetAs<Row>("col");
+                    Assert.Equal(2, outerCol.Size());
+                    Assert.Equal(1, outerCol.GetAs<int>("col1"));
+                    Assert.Equal("abc", outerCol.GetAs<string>("col2"));
+                }
+            }
+
+            // Generic row is a part of top-level column.
+            {
+                var schema = new StructType(new[]
+                {
+                    new StructField("col1", new IntegerType())
+                });
+                Func<Column, Column> udf = Udf<string>(
+                    str => new GenericRow(new object[] { 111 }), schema);
+
+                Column nameCol = _df["name"];
+                Row[] rows = _df.Select(udf(nameCol).As("col"), nameCol).Collect().ToArray();
                 Assert.Equal(3, rows.Length);
 
                 foreach (Row row in rows)
                 {
                     Assert.Equal(2, row.Size());
-                    Assert.Equal(1, row.GetAs<int>("col1"));
-                    Assert.Equal("abc", row.GetAs<string>("col2"));
+                    Row col1 = row.GetAs<Row>("col");
+                    Assert.Equal(1, col1.Size());
+                    Assert.Equal(111, col1.GetAs<int>("col1"));
+
+                    string col2 = row.GetAs<string>("name");
+                    Assert.NotEmpty(col2);
                 }
             }
 
@@ -211,21 +276,23 @@ public void TestUdfWithReturnAsRowType()
                     }),
                     schema);
 
-                Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();
+                Row[] rows = _df.Select(udf(_df["name"]).As("col")).Collect().ToArray();
                 Assert.Equal(3, rows.Length);
 
                 foreach (Row row in rows)
                 {
-                    Assert.Equal(3, row.Size());
-                    Assert.Equal(1, row.GetAs<int>("col1"));
+                    Assert.Equal(1, row.Size());
+                    Row outerCol = row.GetAs<Row>("col");
+                    Assert.Equal(3, outerCol.Size());
+                    Assert.Equal(1, outerCol.GetAs<int>("col1"));
                     Assert.Equal(
                         new Row(new object[] { 1 }, subSchema1),
-                        row.GetAs<Row>("col2"));
+                        outerCol.GetAs<Row>("col2"));
                     Assert.Equal(
                         new Row(
                             new object[] { "abc", new Row(new object[] { 10 }, subSchema1) },
                             subSchema2),
-                        row.GetAs<Row>("col3"));
+                        outerCol.GetAs<Row>("col3"));
                 }
             }
         }
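
The updated assertions above also show how a UDF-produced struct comes back to the driver: the result column is aliased with .As("col") and its fields are read through GetAs<Row>. A short, hedged sketch of that read-back pattern, reusing df and returnsRow from the sketch near the top of this page (System.Linq is assumed for ToArray):

// Sketch only: collect a Row-returning UDF result and read the nested fields.
Row[] rows = df.Select(returnsRow(df["name"]).As("col")).Collect().ToArray();
foreach (Row row in rows)
{
    Row nested = row.GetAs<Row>("col");   // the aliased struct column
    int col1 = nested.GetAs<int>("col1");
    string col2 = nested.GetAs<string>("col2");
    Console.WriteLine($"{col1}, {col2}");
}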

src/csharp/Microsoft.Spark.Worker/Command/SqlCommandExecutor.cs

Lines changed: 10 additions & 1 deletion
@@ -135,8 +135,17 @@ protected override CommandExecutorStat ExecuteCore(
 
                 for (int i = 0; i < inputRows.Length; ++i)
                 {
+                    object row = inputRows[i];
+                    // The following can happen if a UDF takes Row object(s).
+                    // The JVM Spark side sends a Row object that wraps all the columns used
+                    // in the UDF; thus, it is normalized below (the extra layer is removed).
+                    if (row is RowConstructor rowConstructor)
+                    {
+                        row = rowConstructor.GetRow().Values;
+                    }
+
                     // Split id is not used for SQL UDFs, so 0 is passed.
-                    outputRows.Add(commandRunner.Run(0, inputRows[i]));
+                    outputRows.Add(commandRunner.Run(0, row));
                 }
 
                 // The initial (estimated) buffer size for pickling rows is set to the size of
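
The normalization above is what lets a single UDF take several Row columns at once: the JVM side ships all of the UDF's inputs wrapped in one Row, and the worker unwraps it back into the plain argument array before invoking the UDF. A hedged sketch of the user-facing shape, again reusing df from the sketch near the top of this page:

// Sketch only: a UDF over two struct columns plus a string column. On the
// worker, these three inputs arrive wrapped in a single Row, which the change
// above unwraps before the UDF runs.
Func<Column, Column, Column, Column> combine = Udf<Row, Row, string, string>(
    (row1, row2, str) =>
    {
        string city = row1.GetAs<string>("city");
        string state = row2.GetAs<string>("state");
        return $"{str}:{city},{state}";
    });
df.Select(combine(df["info1"], df["info2"], df["name"])).Show();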

src/csharp/Microsoft.Spark/Sql/RowCollector.cs

Lines changed: 1 addition & 19 deletions
@@ -2,7 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-using System;
 using System.Collections.Generic;
 using System.IO;
 using Microsoft.Spark.Interop.Ipc;
@@ -34,24 +33,7 @@ public IEnumerable<Row> Collect(ISocketWrapper socket)
 
                 foreach (object unpickled in unpickledObjects)
                 {
-                    // Unpickled object can be either a RowConstructor object (not materialized),
-                    // or a Row object (materialized). Refer to RowConstruct.construct() to see how
-                    // Row objects are unpickled.
-                    switch (unpickled)
-                    {
-                        case RowConstructor rc:
-                            yield return rc.GetRow();
-                            break;
-
-                        case object[] objs when objs.Length == 1 && (objs[0] is Row row):
-                            yield return row;
-                            break;
-
-                        default:
-                            throw new NotSupportedException(
-                                string.Format("Unpickle type {0} is not supported",
-                                    unpickled.GetType()));
-                    }
+                    yield return (unpickled as RowConstructor).GetRow();
                 }
             }
         }

src/csharp/Microsoft.Spark/Sql/RowConstructor.cs

Lines changed: 0 additions & 9 deletions
@@ -65,15 +65,6 @@ public object construct(object[] args)
                 s_schemaCache = new Dictionary<string, StructType>();
             }
 
-            // When a row is ready to be materialized, then construct() is called
-            // on the RowConstructor which represents the row.
-            if ((args.Length == 1) && (args[0] is RowConstructor rowConstructor))
-            {
-                // Construct the Row and return args containing the Row.
-                args[0] = rowConstructor.GetRow();
-                return args;
-            }
-
             // Return a new RowConstructor where the args either represent the
             // schema or the row data. The parent becomes important when calling
             // GetRow() on the RowConstructor containing the row data.
