@@ -158,22 +158,57 @@ public void TestUdfWithRowType()
158
158
[ Fact ]
159
159
public void TestUdfWithReturnAsRowType ( )
160
160
{
161
- var schema = new StructType ( new [ ]
161
+ // Single GenericRow
162
+ var schema1 = new StructType ( new [ ]
162
163
{
163
164
new StructField ( "col1" , new IntegerType ( ) ) ,
164
165
new StructField ( "col2" , new StringType ( ) )
165
166
} ) ;
166
- Func < Column , Column > udf = Udf < string > (
167
- str => new GenericRow ( new object [ ] { 1 , "abc" } ) , schema ) ;
167
+ Func < Column , Column > udf1 = Udf < string > (
168
+ str => new GenericRow ( new object [ ] { 1 , "abc" } ) , schema1 ) ;
168
169
169
- Row [ ] rows = _df . Select ( udf ( _df [ "name" ] ) ) . Collect ( ) . ToArray ( ) ;
170
- Assert . Equal ( 3 , rows . Length ) ;
170
+ Row [ ] rows1 = _df . Select ( udf1 ( _df [ "name" ] ) ) . Collect ( ) . ToArray ( ) ;
171
+ Assert . Equal ( 3 , rows1 . Length ) ;
172
+
173
+ foreach ( Row row in rows1 )
174
+ {
175
+ Assert . Equal ( 2 , row . Size ( ) ) ;
176
+ Assert . Equal ( 1 , row . GetAs < int > ( "col1" ) ) ;
177
+ Assert . Equal ( "abc" , row . GetAs < string > ( "col2" ) ) ;
178
+ }
171
179
172
- foreach ( Row row in rows )
180
+ // Nested GenericRow
181
+ var subSchema1 = new StructType ( new [ ]
182
+ {
183
+ new StructField ( "subCol1" , new IntegerType ( ) )
184
+ } ) ;
185
+ var subSchema2 = new StructType ( new [ ]
186
+ {
187
+ new StructField ( "subCol2" , new StringType ( ) )
188
+ } ) ;
189
+ var schema2 = new StructType ( new [ ]
190
+ {
191
+ new StructField ( "col1" , subSchema1 ) ,
192
+ new StructField ( "col2" , subSchema2 )
193
+ } ) ;
194
+ Func < Column , Column > udf2 = Udf < string > (
195
+ str => new GenericRow (
196
+ new object [ ]
197
+ {
198
+ new GenericRow ( new object [ ] { 1 } ) ,
199
+ new GenericRow ( new object [ ] { "abc" } )
200
+ } ) , schema2 ) ;
201
+
202
+ Row [ ] rows2 = _df . Select ( udf2 ( _df [ "name" ] ) ) . Collect ( ) . ToArray ( ) ;
203
+ Assert . Equal ( 3 , rows2 . Length ) ;
204
+
205
+ foreach ( Row row in rows2 )
173
206
{
174
207
Assert . Equal ( 2 , row . Size ( ) ) ;
175
- Assert . Equal ( 1 , row [ 0 ] ) ;
176
- Assert . Equal ( "abc" , row [ 1 ] ) ;
208
+ Assert . IsType < Row > ( row . Get ( "col1" ) ) ;
209
+ Assert . IsType < Row > ( row . Get ( "col2" ) ) ;
210
+ Assert . Equal ( new Row ( new object [ ] { 1 } , subSchema1 ) , row . GetAs < Row > ( "col1" ) ) ;
211
+ Assert . Equal ( new Row ( new object [ ] { "abc" } , subSchema2 ) , row . GetAs < Row > ( "col2" ) ) ;
177
212
}
178
213
}
179
214
}
0 commit comments