@@ -111,16 +111,16 @@ def complex_types(inner_strategy=primitive_types):
111
111
return list_types (inner_strategy ) | struct_types (inner_strategy )
112
112
113
113
114
- def nested_list_types (item_strategy = primitive_types ):
115
- return st .recursive (item_strategy , list_types )
114
+ def nested_list_types (item_strategy = primitive_types , max_leaves = 3 ):
115
+ return st .recursive (item_strategy , list_types , max_leaves = max_leaves )
116
116
117
117
118
- def nested_struct_types (item_strategy = primitive_types ):
119
- return st .recursive (item_strategy , struct_types )
118
+ def nested_struct_types (item_strategy = primitive_types , max_leaves = 3 ):
119
+ return st .recursive (item_strategy , struct_types , max_leaves = max_leaves )
120
120
121
121
122
- def nested_complex_types (inner_strategy = primitive_types ):
123
- return st .recursive (inner_strategy , complex_types )
122
+ def nested_complex_types (inner_strategy = primitive_types , max_leaves = 3 ):
123
+ return st .recursive (inner_strategy , complex_types , max_leaves = max_leaves )
124
124
125
125
126
126
def schemas (type_strategy = primitive_types , max_fields = None ):
@@ -156,14 +156,13 @@ def arrays(draw, type, size=None):
156
156
shape = (size ,)
157
157
158
158
if pa .types .is_list (type ):
159
- # TODO(kszucs) limit the depth
160
159
offsets = draw (npst .arrays (np .uint8 (), shape = shape )).cumsum () // 20
161
160
offsets = np .insert (offsets , 0 , 0 , axis = 0 ) # prepend with zero
162
161
values = draw (arrays (type .value_type , size = int (offsets .sum ())))
163
162
return pa .ListArray .from_arrays (offsets , values )
164
163
165
164
if pa .types .is_struct (type ):
166
- h .assume (len (type ) > 0 ) # TODO(kszucs): create issue -> pa.struct([])
165
+ h .assume (len (type ) > 0 )
167
166
names , child_arrays = [], []
168
167
for field in type :
169
168
names .append (field .name )
@@ -190,12 +189,11 @@ def arrays(draw, type, size=None):
190
189
value = st .binary ()
191
190
elif pa .types .is_string (type ):
192
191
value = st .text ()
193
- # elif pa.types.is_decimal(type):
194
- # # TODO(kszucs): properly limit the precision
195
- # value = st.decimals(places=type.scale, allow_infinity=False)
196
- # type = None # We let arrow infer it from the values
192
+ elif pa .types .is_decimal (type ):
193
+ # TODO(kszucs): properly limit the precision
194
+ # value = st.decimals(places=type.scale, allow_infinity=False)
195
+ h . reject ()
197
196
else :
198
- h .assume (not pa .types .is_decimal (type ))
199
197
raise NotImplementedError (type )
200
198
201
199
values = st .lists (value , min_size = size , max_size = size )
@@ -234,10 +232,27 @@ def record_batches(draw, type, rows=None, max_fields=None):
234
232
235
233
schema = draw (schemas (type , max_fields = max_fields ))
236
234
children = [draw (arrays (field .type , size = rows )) for field in schema ]
235
+ # TODO(kszucs): the names and schame arguments are not consistent with
236
+ # Table.from_array's arguments
237
237
return pa .RecordBatch .from_arrays (children , names = schema )
238
238
239
239
240
+ @st .composite
241
+ def tables (draw , type , rows = None , max_fields = None ):
242
+ if isinstance (rows , st .SearchStrategy ):
243
+ rows = draw (rows )
244
+ elif rows is None :
245
+ rows = draw (_default_array_sizes )
246
+ elif not isinstance (rows , int ):
247
+ raise TypeError ('Rows must be an integer' )
248
+
249
+ schema = draw (schemas (type , max_fields = max_fields ))
250
+ children = [draw (arrays (field .type , size = rows )) for field in schema ]
251
+ return pa .Table .from_arrays (children , schema = schema )
252
+
253
+
240
254
all_arrays = arrays (all_types )
241
255
all_chunked_arrays = chunked_arrays (all_types )
242
256
all_columns = columns (all_types )
243
257
all_record_batches = record_batches (all_types )
258
+ all_tables = tables (all_types )
0 commit comments