@@ -123,8 +123,9 @@ def nested_complex_types(inner_strategy=primitive_types):
123
123
return st .recursive (inner_strategy , complex_types )
124
124
125
125
126
def schemas(type_strategy=primitive_types, max_fields=None):
    """Return a strategy that generates ``pa.schema`` objects.

    Parameters
    ----------
    type_strategy : strategy
        Passed to ``fields`` to produce the field strategy whose values
        populate the schema.
    max_fields : int, optional
        Forwarded as ``max_size`` to ``st.lists``, capping how many fields
        a generated schema may contain. ``None`` leaves the list size
        unbounded.
    """
    return st.builds(
        pa.schema,
        st.lists(fields(type_strategy), max_size=max_fields),
    )
128
129
129
130
130
131
complex_schemas = schemas (complex_types ())
@@ -135,32 +136,40 @@ def schemas(type_strategy=primitive_types):
135
136
all_schemas = schemas (all_types )
136
137
137
138
139
# Fallback strategy for lengths (0 to 20 inclusive), drawn when a caller
# leaves the array size / row count as None.
_default_array_sizes = st.integers(
    min_value=0,
    max_value=20,
)
140
+
141
+
138
142
@st .composite
139
- def arrays (draw , type , size ):
143
+ def arrays (draw , type , size = None ):
140
144
if isinstance (type , st .SearchStrategy ):
141
145
type = draw (type )
146
+ elif not isinstance (type , pa .DataType ):
147
+ raise TypeError ('Type must be a pyarrow DataType' )
148
+
142
149
if isinstance (size , st .SearchStrategy ):
143
150
size = draw (size )
144
-
145
- if not isinstance (type , pa .DataType ):
146
- raise TypeError ('Type must be a pyarrow DataType' )
147
- if not isinstance (size , int ):
151
+ elif size is None :
152
+ size = draw (_default_array_sizes )
153
+ elif not isinstance (size , int ):
148
154
raise TypeError ('Size must be an integer' )
149
155
150
156
shape = (size ,)
151
157
152
158
if pa .types .is_list (type ):
159
+ # TODO(kszucs) limit the depth
153
160
offsets = draw (npst .arrays (np .uint8 (), shape = shape )).cumsum () // 20
154
161
offsets = np .insert (offsets , 0 , 0 , axis = 0 ) # prepend with zero
155
162
values = draw (arrays (type .value_type , size = int (offsets .sum ())))
156
163
return pa .ListArray .from_arrays (offsets , values )
157
164
158
165
if pa .types .is_struct (type ):
159
- h .assume (len (type ) > 0 ) # TODO issue pa.struct([])
166
+ h .assume (len (type ) > 0 ) # TODO(kszucs): create issue -> pa.struct([])
160
167
names , child_arrays = [], []
161
168
for field in type :
162
169
names .append (field .name )
163
170
child_arrays .append (draw (arrays (field .type , size = size )))
171
+ # field metadata is lost here, because from_arrays doesn't accept
172
+ # a fields argument, only names
164
173
return pa .StructArray .from_arrays (child_arrays , names = names )
165
174
166
175
if (pa .types .is_boolean (type ) or pa .types .is_integer (type ) or
@@ -181,12 +190,54 @@ def arrays(draw, type, size):
181
190
value = st .binary ()
182
191
elif pa .types .is_string (type ):
183
192
value = st .text ()
184
- elif pa .types .is_decimal (type ):
185
- # FIXME (kszucs): properly limit the precision
186
- value = st .decimals (places = type .scale , allow_infinity = False )
187
- type = None # We let arrow infer it from the values
193
+ # elif pa.types.is_decimal(type):
194
+ # # TODO (kszucs): properly limit the precision
195
+ # value = st.decimals(places=type.scale, allow_infinity=False)
196
+ # type = None # We let arrow infer it from the values
188
197
else :
198
+ h .assume (not pa .types .is_decimal (type ))
189
199
raise NotImplementedError (type )
190
200
191
201
values = st .lists (value , min_size = size , max_size = size )
192
202
return pa .array (draw (values ), type = type )
203
+
204
+
205
@st.composite
def chunked_arrays(
    draw, type, min_chunks=0, max_chunks=None, chunk_size=None
):
    """Composite strategy producing ``pa.chunked_array`` values.

    Parameters
    ----------
    draw : callable
        Supplied by ``st.composite``.
    type : pa.DataType or strategy
        Type of every chunk; when a strategy is given a concrete type is
        drawn from it first.
    min_chunks, max_chunks : int, optional
        Bounds on the number of chunks, forwarded to ``st.lists`` as
        ``min_size`` / ``max_size``.
    chunk_size : optional
        Forwarded unchanged as ``size`` to the ``arrays`` strategy that
        generates each chunk.
    """
    data_type = draw(type) if isinstance(type, st.SearchStrategy) else type

    # TODO(kszucs): remove it, field metadata is not kept
    h.assume(not pa.types.is_struct(data_type))

    chunk_lists = st.lists(
        arrays(data_type, size=chunk_size),
        min_size=min_chunks,
        max_size=max_chunks,
    )
    return pa.chunked_array(draw(chunk_lists), type=data_type)
217
+
218
+
219
def columns(type, min_chunks=0, max_chunks=None, chunk_size=None):
    """Return a strategy that builds ``pa.column`` values.

    Each column is built from a name drawn from ``st.text()`` and data
    drawn from ``chunked_arrays``; ``type``, ``min_chunks``, ``max_chunks``
    and ``chunk_size`` are forwarded to ``chunked_arrays`` unchanged.
    """
    data = chunked_arrays(
        type,
        min_chunks=min_chunks,
        max_chunks=max_chunks,
        chunk_size=chunk_size,
    )
    names = st.text()
    return st.builds(pa.column, names, data)
224
+
225
+
226
@st.composite
def record_batches(draw, type, rows=None, max_fields=None):
    """Composite strategy producing ``pa.RecordBatch`` values.

    Parameters
    ----------
    draw : callable
        Supplied by ``st.composite``.
    type : strategy
        Passed to ``schemas`` as the type strategy for the batch's fields.
    rows : int or strategy, optional
        Length of every child array. A strategy is drawn from; ``None``
        falls back to ``_default_array_sizes``.
    max_fields : int, optional
        Forwarded to ``schemas`` to cap the number of fields.

    Raises
    ------
    TypeError
        If ``rows`` is neither ``None``, a strategy, nor an ``int``.
    """
    if rows is None:
        rows = draw(_default_array_sizes)
    elif isinstance(rows, st.SearchStrategy):
        rows = draw(rows)
    elif not isinstance(rows, int):
        raise TypeError('Rows must be an integer')

    schema = draw(schemas(type, max_fields=max_fields))

    # One child array per schema field, all sharing the same row count.
    children = []
    for field in schema:
        children.append(draw(arrays(field.type, size=rows)))

    # The drawn schema itself is handed over through the ``names`` keyword.
    return pa.RecordBatch.from_arrays(children, names=schema)
238
+
239
+
240
# Ready-made strategies built from ``all_types``, one per container kind.
all_arrays = arrays(all_types)
all_chunked_arrays = chunked_arrays(all_types)
all_columns = columns(all_types)
all_record_batches = record_batches(all_types)
0 commit comments