Skip to content

Follow up on UDF that takes in and returns Row #406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/csharp/Microsoft.Spark.E2ETest/Resources/people.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{"name":"Michael", "ids":[1], "info":{"city":"Burdwan", "state":"Paschimbanga"}}
{"name":"Andy", "age":30, "ids":[3,5], "info":{"city":"Los Angeles", "state":"California"}}
{"name":"Justin", "age":19, "ids":[2,4], "info":{"city":"Seattle"}}
{"name":"Michael", "ids":[1], "info1":{"city":"Burdwan"}, "info2":{"state":"Paschimbanga"}, "info3":{"company":{"job":"Developer"}}}
{"name":"Andy", "age":30, "ids":[3,5], "info1":{"city":"Los Angeles"}, "info2":{"state":"California"}, "info3":{"company":{"job":"Developer"}}}
{"name":"Justin", "age":19, "ids":[2,4], "info1":{"city":"Seattle"}, "info2":{"state":"Washington"}, "info3":{"company":{"job":"Developer"}}}
107 changes: 87 additions & 20 deletions src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfComplexTypesTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,59 @@ public void TestUdfWithReturnAsMapType()
[Fact]
public void TestUdfWithRowType()
{
Func<Column, Column> udf = Udf<Row, string>(
(row) =>
{
string city = row.GetAs<string>("city");
string state = row.GetAs<string>("state");
return $"{city},{state}";
});
// Single Row
{
Func<Column, Column> udf = Udf<Row, string>(
(row) => row.GetAs<string>("city"));

Row[] rows = _df.Select(udf(_df["info"])).Collect().ToArray();
Assert.Equal(3, rows.Length);
Row[] rows = _df.Select(udf(_df["info1"])).Collect().ToArray();
Assert.Equal(3, rows.Length);

var expected = new[] { "Burdwan,Paschimbanga", "Los Angeles,California", "Seattle," };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
var expected = new[] { "Burdwan", "Los Angeles", "Seattle" };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
}

// Multiple Rows
{
Func<Column, Column, Column, Column> udf = Udf<Row, Row, string, string>(
(row1, row2, str) =>
{
string city = row1.GetAs<string>("city");
string state = row2.GetAs<string>("state");
return $"{str}:{city},{state}";
});

Row[] rows = _df
.Select(udf(_df["info1"], _df["info2"], _df["name"]))
.Collect()
.ToArray();
Assert.Equal(3, rows.Length);

var expected = new[] {
"Michael:Burdwan,Paschimbanga",
"Andy:Los Angeles,California",
"Justin:Seattle,Washington" };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
}

// Nested Row
{
Func<Column, Column> udf = Udf<Row, string>(
(row) =>
{
Row outerCol = row.GetAs<Row>("company");
return outerCol.GetAs<string>("job");
});

Row[] rows = _df.Select(udf(_df["info3"])).Collect().ToArray();
Assert.Equal(3, rows.Length);

var expected = new[] { "Developer", "Developer", "Developer" };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
}
}

/// <summary>
Expand All @@ -168,14 +207,40 @@ public void TestUdfWithReturnAsRowType()
Func<Column, Column> udf = Udf<string>(
str => new GenericRow(new object[] { 1, "abc" }), schema);

Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();
Row[] rows = _df.Select(udf(_df["name"]).As("col")).Collect().ToArray();
Assert.Equal(3, rows.Length);
foreach (Row row in rows)
{
Assert.Equal(1, row.Size());
Row outerCol = row.GetAs<Row>("col");
Assert.Equal(2, outerCol.Size());
Assert.Equal(1, outerCol.GetAs<int>("col1"));
Assert.Equal("abc", outerCol.GetAs<string>("col2"));
}
}

// Generic row is a part of top-level column.
{
var schema = new StructType(new[]
{
new StructField("col1", new IntegerType())
});
Func<Column, Column> udf = Udf<string>(
str => new GenericRow(new object[] { 111 }), schema);

Column nameCol = _df["name"];
Row[] rows = _df.Select(udf(nameCol).As("col"), nameCol).Collect().ToArray();
Assert.Equal(3, rows.Length);

foreach (Row row in rows)
{
Assert.Equal(2, row.Size());
Assert.Equal(1, row.GetAs<int>("col1"));
Assert.Equal("abc", row.GetAs<string>("col2"));
Row col1 = row.GetAs<Row>("col");
Assert.Equal(1, col1.Size());
Assert.Equal(111, col1.GetAs<int>("col1"));

string col2 = row.GetAs<string>("name");
Assert.NotEmpty(col2);
}
}

Expand Down Expand Up @@ -211,21 +276,23 @@ public void TestUdfWithReturnAsRowType()
}),
schema);

Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();
Row[] rows = _df.Select(udf(_df["name"]).As("col")).Collect().ToArray();
Assert.Equal(3, rows.Length);

foreach (Row row in rows)
{
Assert.Equal(3, row.Size());
Assert.Equal(1, row.GetAs<int>("col1"));
Assert.Equal(1, row.Size());
Row outerCol = row.GetAs<Row>("col");
Assert.Equal(3, outerCol.Size());
Assert.Equal(1, outerCol.GetAs<int>("col1"));
Assert.Equal(
new Row(new object[] { 1 }, subSchema1),
row.GetAs<Row>("col2"));
outerCol.GetAs<Row>("col2"));
Assert.Equal(
new Row(
new object[] { "abc", new Row(new object[] { 10 }, subSchema1) },
subSchema2),
row.GetAs<Row>("col3"));
outerCol.GetAs<Row>("col3"));
}
}
}
Expand Down
11 changes: 10 additions & 1 deletion src/csharp/Microsoft.Spark.Worker/Command/SqlCommandExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,17 @@ protected override CommandExecutorStat ExecuteCore(

for (int i = 0; i < inputRows.Length; ++i)
{
object row = inputRows[i];
// The following can happen if a UDF takes Row object(s).
// The JVM Spark side sends a Row object that wraps all the columns used
// in the UDF, thus, it is normalized below (the extra layer is removed).
if (row is RowConstructor rowConstructor)
{
row = rowConstructor.GetRow().Values;
}

// Split id is not used for SQL UDFs, so 0 is passed.
outputRows.Add(commandRunner.Run(0, inputRows[i]));
outputRows.Add(commandRunner.Run(0, row));
}

// The initial (estimated) buffer size for pickling rows is set to the size of
Expand Down
20 changes: 1 addition & 19 deletions src/csharp/Microsoft.Spark/Sql/RowCollector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.Spark.Interop.Ipc;
Expand Down Expand Up @@ -34,24 +33,7 @@ public IEnumerable<Row> Collect(ISocketWrapper socket)

foreach (object unpickled in unpickledObjects)
{
// Unpickled object can be either a RowConstructor object (not materialized),
// or a Row object (materialized). Refer to RowConstructor.construct() to see how
// Row objects are unpickled.
switch (unpickled)
{
case RowConstructor rc:
yield return rc.GetRow();
break;

case object[] objs when objs.Length == 1 && (objs[0] is Row row):
yield return row;
break;

default:
throw new NotSupportedException(
string.Format("Unpickle type {0} is not supported",
unpickled.GetType()));
}
yield return (unpickled as RowConstructor).GetRow();
}
}
}
Expand Down
9 changes: 0 additions & 9 deletions src/csharp/Microsoft.Spark/Sql/RowConstructor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,6 @@ public object construct(object[] args)
s_schemaCache = new Dictionary<string, StructType>();
}

// When a row is ready to be materialized, then construct() is called
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually a breaking change, but @suhsteve is updating the worker version in his PR: https://github.com/dotnet/spark/pull/387/files

// on the RowConstructor which represents the row.
if ((args.Length == 1) && (args[0] is RowConstructor rowConstructor))
{
// Construct the Row and return args containing the Row.
args[0] = rowConstructor.GetRow();
return args;
}

// Return a new RowConstructor where the args either represent the
// schema or the row data. The parent becomes important when calling
// GetRow() on the RowConstructor containing the row data.
Expand Down