Skip to content

Commit fee85b8

Browse files
Merge branch 'master' into elva/udfReturnRowType
2 parents 9585666 + fdcb049 commit fee85b8

File tree

12 files changed

+350
-53
lines changed

12 files changed

+350
-53
lines changed

src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/SparkSessionTests.cs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.Collections.Generic;
6+
using System.Linq;
57
using Microsoft.Spark.E2ETest.Utils;
68
using Microsoft.Spark.Sql;
79
using Microsoft.Spark.Sql.Catalog;
810
using Microsoft.Spark.Sql.Streaming;
11+
using Microsoft.Spark.Sql.Types;
912
using Xunit;
1013

1114
namespace Microsoft.Spark.E2ETest.IpcTests
@@ -65,5 +68,80 @@ public void TestSignaturesV2_4_X()
6568
{
6669
Assert.IsType<SparkSession>(SparkSession.Active());
6770
}
71+
72+
/// <summary>
/// Exercises the CreateDataFrame overloads: the one taking rows together with an
/// explicit schema, and the single-argument ones for string, int, double and bool
/// sequences.
/// </summary>
[Fact]
public void TestCreateDataFrame()
{
    // Rows supplied together with an explicit schema.
    {
        var rows = new List<GenericRow>
        {
            new GenericRow(new object[] { "Alice", 20 }),
            new GenericRow(new object[] { "Bob", 30 })
        };

        var fields = new List<StructField>
        {
            new StructField("Name", new StringType()),
            new StructField("Age", new IntegerType())
        };
        var schema = new StructType(fields);

        DataFrame df = _spark.CreateDataFrame(rows, schema);
        ValidateDataFrame(df, rows.Select(row => row.Values), schema);
    }

    // IEnumerable<string> input, no schema passed.
    {
        var names = new List<string> { "Alice", "Bob" };
        var schema = SchemaWithSingleColumn(new StringType());

        DataFrame df = _spark.CreateDataFrame(names);
        ValidateDataFrame(df, names.Select(value => new object[] { value }), schema);
    }

    // IEnumerable<int> input, no schema passed.
    {
        var numbers = new List<int> { 1, 2 };
        var schema = SchemaWithSingleColumn(new IntegerType());

        DataFrame df = _spark.CreateDataFrame(numbers);
        ValidateDataFrame(df, numbers.Select(value => new object[] { value }), schema);
    }

    // IEnumerable<double> input, no schema passed.
    {
        var numbers = new List<double> { 1.2, 2.3 };
        var schema = SchemaWithSingleColumn(new DoubleType());

        DataFrame df = _spark.CreateDataFrame(numbers);
        ValidateDataFrame(df, numbers.Select(value => new object[] { value }), schema);
    }

    // IEnumerable<bool> input, no schema passed.
    {
        var flags = new List<bool> { true, false };
        var schema = SchemaWithSingleColumn(new BooleanType());

        DataFrame df = _spark.CreateDataFrame(flags);
        ValidateDataFrame(df, flags.Select(value => new object[] { value }), schema);
    }
}
129+
130+
/// <summary>
/// Asserts that the given DataFrame reports the expected schema and that its
/// collected rows carry the expected values, in order.
/// </summary>
/// <param name="actual">DataFrame under test</param>
/// <param name="expectedRows">Expected column values, one array per row</param>
/// <param name="expectedSchema">Schema the DataFrame is expected to report</param>
private void ValidateDataFrame(
    DataFrame actual,
    IEnumerable<object[]> expectedRows,
    StructType expectedSchema)
{
    // Schema is checked first so a schema mismatch is reported before any row mismatch.
    var actualSchema = actual.Schema();
    Assert.Equal(expectedSchema, actualSchema);

    var actualRows = actual.Collect().Select(row => row.Values);
    Assert.Equal(expectedRows, actualRows);
}
138+
139+
/// <summary>
/// Builds a schema consisting of exactly one column, named "_1", of the given datatype.
/// </summary>
/// <param name="dataType">Datatype of the single column</param>
/// <returns>The one-column schema as a StructType</returns>
private StructType SchemaWithSingleColumn(DataType dataType)
{
    var column = new StructField("_1", dataType);
    return new StructType(new[] { column });
}
68146
}
69147
}

src/csharp/Microsoft.Spark.UnitTest/Sql/RowTests.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,5 +139,26 @@ private Pickler CreatePickler()
139139
new RowPickler().Register();
140140
return new Pickler();
141141
}
142+
143+
[Fact]
public void GenericRowTest()
{
    // A two-column row: an int followed by a string.
    object[] values = { 1, "abc" };
    var genericRow = new GenericRow(values);

    // Size() reports the number of columns.
    Assert.Equal(2, genericRow.Size());

    // The indexer returns the value stored in each column.
    Assert.Equal(1, genericRow[0]);
    Assert.Equal("abc", genericRow[1]);

    // Get(int) returns the same values as the indexer.
    Assert.Equal(1, genericRow.Get(0));
    Assert.Equal("abc", genericRow.Get(1));

    // GetAs<T>(int) succeeds for the stored type and throws for a mismatched one.
    Assert.Equal(1, genericRow.GetAs<int>(0));
    Assert.ThrowsAny<Exception>(() => genericRow.GetAs<string>(0));
    Assert.Equal("abc", genericRow.GetAs<string>(1));
    Assert.ThrowsAny<Exception>(() => genericRow.GetAs<int>(1));
}
142163
}
143164
}

src/csharp/Microsoft.Spark/Interop/Ipc/PayloadHelper.cs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using System.Collections.Generic;
99
using System.IO;
1010
using System.Linq;
11+
using Microsoft.Spark.Sql;
1112

1213
namespace Microsoft.Spark.Interop.Ipc
1314
{
@@ -23,8 +24,9 @@ internal class PayloadHelper
2324
private static readonly byte[] s_doubleTypeId = new[] { (byte)'d' };
2425
private static readonly byte[] s_jvmObjectTypeId = new[] { (byte)'j' };
2526
private static readonly byte[] s_byteArrayTypeId = new[] { (byte)'r' };
26-
private static readonly byte[] s_intArrayTypeId = new[] { (byte)'l' };
27+
private static readonly byte[] s_arrayTypeId = new[] { (byte)'l' };
2728
private static readonly byte[] s_dictionaryTypeId = new[] { (byte)'e' };
29+
private static readonly byte[] s_rowArrTypeId = new[] { (byte)'R' };
2830

2931
private static readonly ConcurrentDictionary<Type, bool> s_isDictionaryTable =
3032
new ConcurrentDictionary<Type, bool>();
@@ -183,6 +185,22 @@ internal static void ConvertArgsToBytes(
183185
destination.Position = posAfterEnumerable;
184186
break;
185187

188+
case IEnumerable<GenericRow> argRowEnumerable:
189+
posBeforeEnumerable = destination.Position;
190+
destination.Position += sizeof(int);
191+
itemCount = 0;
192+
foreach (GenericRow r in argRowEnumerable)
193+
{
194+
++itemCount;
195+
SerDe.Write(destination, (int)r.Values.Length);
196+
ConvertArgsToBytes(destination, r.Values, true);
197+
}
198+
posAfterEnumerable = destination.Position;
199+
destination.Position = posBeforeEnumerable;
200+
SerDe.Write(destination, itemCount);
201+
destination.Position = posAfterEnumerable;
202+
break;
203+
186204
case var _ when IsDictionary(arg.GetType()):
187205
// Generic dictionary, but we don't have it strongly typed as
188206
// Dictionary<T,U>
@@ -271,7 +289,7 @@ internal static byte[] GetTypeId(Type type)
271289
typeof(IEnumerable<byte[]>).IsAssignableFrom(type) ||
272290
typeof(IEnumerable<string>).IsAssignableFrom(type))
273291
{
274-
return s_intArrayTypeId;
292+
return s_arrayTypeId;
275293
}
276294

277295
if (IsDictionary(type))
@@ -281,7 +299,12 @@ internal static byte[] GetTypeId(Type type)
281299

282300
if (typeof(IEnumerable<IJvmObjectReferenceProvider>).IsAssignableFrom(type))
283301
{
284-
return s_intArrayTypeId;
302+
return s_arrayTypeId;
303+
}
304+
305+
if (typeof(IEnumerable<GenericRow>).IsAssignableFrom(type))
306+
{
307+
return s_rowArrTypeId;
285308
}
286309
break;
287310
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
9+
namespace Microsoft.Spark.Sql
10+
{
11+
/// <summary>
/// Represents a row object in RDD, equivalent to GenericRow in Spark.
/// Two rows are equal when their column values are pairwise equal, and equal rows
/// produce the same hash code.
/// </summary>
public sealed class GenericRow
{
    /// <summary>
    /// Constructor for the GenericRow class.
    /// </summary>
    /// <param name="values">Column values for a row</param>
    public GenericRow(object[] values)
    {
        Values = values;
    }

    /// <summary>
    /// Values representing this row.
    /// </summary>
    public object[] Values { get; }

    /// <summary>
    /// Returns the number of columns in this row.
    /// </summary>
    /// <returns>Number of columns in this row</returns>
    public int Size() => Values.Length;

    /// <summary>
    /// Returns the column value at the given index.
    /// </summary>
    /// <param name="index">Index to look up</param>
    /// <returns>A column value</returns>
    /// <exception cref="IndexOutOfRangeException">
    /// The index is negative, or not smaller than the number of columns.
    /// </exception>
    public object this[int index] => Get(index);

    /// <summary>
    /// Returns the column value at the given index.
    /// </summary>
    /// <param name="index">Index to look up</param>
    /// <returns>A column value</returns>
    /// <exception cref="IndexOutOfRangeException">
    /// The index is negative, or not smaller than the number of columns.
    /// </exception>
    public object Get(int index)
    {
        if (index >= Size())
        {
            throw new IndexOutOfRangeException($"index ({index}) >= column counts ({Size()})");
        }

        if (index < 0)
        {
            throw new IndexOutOfRangeException($"index ({index}) < 0");
        }

        return Values[index];
    }

    /// <summary>
    /// Returns the string version of this row: the column values separated by commas
    /// and wrapped in square brackets. A null column is rendered as an empty string.
    /// </summary>
    /// <returns>String version of this row</returns>
    public override string ToString()
    {
        IEnumerable<string> cols = Values.Select(item => item?.ToString() ?? string.Empty);
        return $"[{string.Join(",", cols)}]";
    }

    /// <summary>
    /// Returns the column value at the given index, as a type T.
    /// TODO: If the original type is "long" and its value can be
    /// fit into the "int", Pickler will serialize the value as int.
    /// Since the value is boxed, <see cref="GetAs{T}(int)"/> will throw an exception.
    /// </summary>
    /// <typeparam name="T">Type to convert to</typeparam>
    /// <param name="index">Index to look up</param>
    /// <returns>A column value as a type T</returns>
    public T GetAs<T>(int index) => (T)Get(index);

    /// <summary>
    /// Checks if the given object is same as the current object.
    /// Two rows are the same when they hold equal values in the same order.
    /// </summary>
    /// <param name="obj">Other object to compare against</param>
    /// <returns>True if the other object is equal.</returns>
    public override bool Equals(object obj) =>
        ReferenceEquals(this, obj) ||
        ((obj is GenericRow row) && Values.SequenceEqual(row.Values));

    /// <summary>
    /// Returns the hash code of the current object.
    /// The hash is derived from the column values so that rows which compare equal
    /// through <see cref="Equals(object)"/> also hash equally. Because it depends on the
    /// contents of <see cref="Values"/>, modifying that array changes the hash code.
    /// </summary>
    /// <returns>The hash code of the current object</returns>
    public override int GetHashCode()
    {
        // Overflow is expected while mixing; wrap around instead of throwing.
        unchecked
        {
            int hash = 17;
            foreach (object item in Values)
            {
                // A null column contributes 0, mirroring how Equals treats two nulls as equal.
                hash = (hash * 31) + (item?.GetHashCode() ?? 0);
            }

            return hash;
        }
    }
}
103+
}

src/csharp/Microsoft.Spark/Sql/Row.cs

Lines changed: 10 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@ namespace Microsoft.Spark.Sql
1414
/// </summary>
1515
public sealed class Row
1616
{
17+
private readonly GenericRow _genericRow;
18+
1719
/// <summary>
1820
/// Constructor for the Row class.
1921
/// </summary>
2022
/// <param name="values">Column values for a row</param>
2123
/// <param name="schema">Schema associated with a row</param>
2224
internal Row(object[] values, StructType schema)
2325
{
24-
Values = values;
26+
_genericRow = new GenericRow(values);
2527
Schema = schema;
2628

2729
var schemaColumnCount = Schema.Fields.Count;
@@ -42,13 +44,13 @@ internal Row(object[] values, StructType schema)
4244
/// <summary>
4345
/// Values representing this row.
4446
/// </summary>
45-
public object[] Values { get; }
47+
public object[] Values => _genericRow.Values;
4648

4749
/// <summary>
4850
/// Returns the number of columns in this row.
4951
/// </summary>
5052
/// <returns>Number of columns in this row</returns>
51-
public int Size() => Values.Length;
53+
public int Size() => _genericRow.Size();
5254

5355
/// <summary>
5456
/// Returns the column value at the given index.
@@ -62,20 +64,7 @@ internal Row(object[] values, StructType schema)
6264
/// </summary>
6365
/// <param name="index">Index to look up</param>
6466
/// <returns>A column value</returns>
65-
public object Get(int index)
66-
{
67-
if (index >= Size())
68-
{
69-
throw new IndexOutOfRangeException($"index ({index}) >= column counts ({Size()})");
70-
}
71-
else if (index < 0)
72-
{
73-
throw new IndexOutOfRangeException($"index ({index}) < 0)");
74-
}
75-
76-
return Values[index];
77-
}
78-
67+
public object Get(int index) => _genericRow.Get(index);
7968
/// <summary>
8069
/// Returns the column value whose column name is given.
8170
/// </summary>
@@ -88,16 +77,7 @@ public object Get(string columnName) =>
8877
/// Returns the string version of this row.
8978
/// </summary>
9079
/// <returns>String version of this row</returns>
91-
public override string ToString()
92-
{
93-
var cols = new List<string>();
94-
foreach (object item in Values)
95-
{
96-
cols.Add(item?.ToString() ?? string.Empty);
97-
}
98-
99-
return $"[{(string.Join(",", cols.ToArray()))}]";
100-
}
80+
public override string ToString() => _genericRow.ToString();
10181

10282
/// <summary>
10383
/// Returns the column value at the given index, as a type T.
@@ -126,26 +106,9 @@ public override string ToString()
126106
/// </summary>
127107
/// <param name="obj">Other object to compare against</param>
128108
/// <returns>True if the other object is equal.</returns>
129-
public override bool Equals(object obj)
130-
{
131-
if (obj is null)
132-
{
133-
return false;
134-
}
135-
136-
if (ReferenceEquals(this, obj))
137-
{
138-
return true;
139-
}
140-
141-
if (obj is Row otherRow)
142-
{
143-
return Values.SequenceEqual(otherRow.Values) &&
144-
Schema.Equals(otherRow.Schema);
145-
}
146-
147-
return false;
148-
}
109+
public override bool Equals(object obj) =>
110+
ReferenceEquals(this, obj) ||
111+
((obj is Row row) && _genericRow.Equals(row._genericRow)) && Schema.Equals(row.Schema);
149112

150113
/// <summary>
151114
/// Returns the hash code of the current object.

0 commit comments

Comments
 (0)