Skip to content

Commit 27b6ac7

Browse files
committed
strategies for chunked_arrays, columns, record batches; test the strategies themselves
1 parent dbe1491 commit 27b6ac7

File tree

2 files changed

+112
-12
lines changed

2 files changed

+112
-12
lines changed

python/pyarrow/tests/strategies.py

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,9 @@ def nested_complex_types(inner_strategy=primitive_types):
123123
return st.recursive(inner_strategy, complex_types)
124124

125125

126-
def schemas(type_strategy=primitive_types):
127-
return st.builds(pa.schema, st.lists(fields(type_strategy)))
126+
def schemas(type_strategy=primitive_types, max_fields=None):
    """Build a strategy that generates pyarrow schemas.

    Parameters
    ----------
    type_strategy : strategy
        Forwarded to ``fields`` to produce the fields of the schema.
    max_fields : int, optional
        Passed through unchanged as ``max_size`` of the generated list of
        fields.
    """
    return st.builds(
        pa.schema,
        st.lists(fields(type_strategy), max_size=max_fields),
    )
128129

129130

130131
complex_schemas = schemas(complex_types())
@@ -135,32 +136,40 @@ def schemas(type_strategy=primitive_types):
135136
all_schemas = schemas(all_types)
136137

137138

139+
# Size strategy used when a caller passes no explicit size: `arrays` draws its
# length from it and `record_batches` draws its row count from it.
_default_array_sizes = st.integers(min_value=0, max_value=20)
140+
141+
138142
@st.composite
139-
def arrays(draw, type, size):
143+
def arrays(draw, type, size=None):
140144
if isinstance(type, st.SearchStrategy):
141145
type = draw(type)
146+
elif not isinstance(type, pa.DataType):
147+
raise TypeError('Type must be a pyarrow DataType')
148+
142149
if isinstance(size, st.SearchStrategy):
143150
size = draw(size)
144-
145-
if not isinstance(type, pa.DataType):
146-
raise TypeError('Type must be a pyarrow DataType')
147-
if not isinstance(size, int):
151+
elif size is None:
152+
size = draw(_default_array_sizes)
153+
elif not isinstance(size, int):
148154
raise TypeError('Size must be an integer')
149155

150156
shape = (size,)
151157

152158
if pa.types.is_list(type):
159+
# TODO(kszucs) limit the depth
153160
offsets = draw(npst.arrays(np.uint8(), shape=shape)).cumsum() // 20
154161
offsets = np.insert(offsets, 0, 0, axis=0) # prepend with zero
155162
values = draw(arrays(type.value_type, size=int(offsets.sum())))
156163
return pa.ListArray.from_arrays(offsets, values)
157164

158165
if pa.types.is_struct(type):
159-
h.assume(len(type) > 0) # TODO issue pa.struct([])
166+
h.assume(len(type) > 0) # TODO(kszucs): create issue -> pa.struct([])
160167
names, child_arrays = [], []
161168
for field in type:
162169
names.append(field.name)
163170
child_arrays.append(draw(arrays(field.type, size=size)))
171+
# fields' metadata are lost here, because from_arrays doesn't accept
172+
# a fields argument, only names
164173
return pa.StructArray.from_arrays(child_arrays, names=names)
165174

166175
if (pa.types.is_boolean(type) or pa.types.is_integer(type) or
@@ -181,12 +190,54 @@ def arrays(draw, type, size):
181190
value = st.binary()
182191
elif pa.types.is_string(type):
183192
value = st.text()
184-
elif pa.types.is_decimal(type):
185-
# FIXME(kszucs): properly limit the precision
186-
value = st.decimals(places=type.scale, allow_infinity=False)
187-
type = None # We let arrow infer it from the values
193+
# elif pa.types.is_decimal(type):
194+
# # TODO(kszucs): properly limit the precision
195+
# value = st.decimals(places=type.scale, allow_infinity=False)
196+
# type = None # We let arrow infer it from the values
188197
else:
198+
h.assume(not pa.types.is_decimal(type))
189199
raise NotImplementedError(type)
190200

191201
values = st.lists(value, min_size=size, max_size=size)
192202
return pa.array(draw(values), type=type)
203+
204+
205+
@st.composite
def chunked_arrays(draw, type, min_chunks=0, max_chunks=None, chunk_size=None):
    """Generate a pyarrow chunked array whose chunks all share one type.

    Parameters
    ----------
    draw : callable
        Supplied by ``st.composite``; draws a value from a strategy.
    type : pyarrow.DataType or strategy
        The type of every chunk, or a strategy from which it is drawn.
    min_chunks, max_chunks : int, optional
        Forwarded as ``min_size`` / ``max_size`` of the list of chunks.
    chunk_size : int or strategy, optional
        Forwarded as ``size`` to ``arrays`` for each chunk.

    Raises
    ------
    TypeError
        If ``type`` is neither a strategy nor a pyarrow DataType.
    """
    if isinstance(type, st.SearchStrategy):
        type = draw(type)
    elif not isinstance(type, pa.DataType):
        # Same validation and message as `arrays`, so an invalid type fails
        # here with a clear error instead of inside the struct check below.
        raise TypeError('Type must be a pyarrow DataType')

    # TODO(kszucs): remove it, field metadata is not kept
    h.assume(not pa.types.is_struct(type))

    chunk = arrays(type, size=chunk_size)
    chunks = st.lists(chunk, min_size=min_chunks, max_size=max_chunks)

    return pa.chunked_array(draw(chunks), type=type)
217+
218+
219+
def columns(type, min_chunks=0, max_chunks=None, chunk_size=None):
    """Build a strategy that generates pyarrow columns.

    The column data comes from ``chunked_arrays`` (all keyword arguments are
    forwarded to it) and the column name is drawn from ``st.text()``.
    """
    data = chunked_arrays(
        type,
        chunk_size=chunk_size,
        min_chunks=min_chunks,
        max_chunks=max_chunks,
    )
    names = st.text()
    return st.builds(pa.column, names, data)
224+
225+
226+
@st.composite
def record_batches(draw, type, rows=None, max_fields=None):
    """Generate a pyarrow record batch with a drawn schema.

    Parameters
    ----------
    draw : callable
        Supplied by ``st.composite``; draws a value from a strategy.
    type : strategy
        Forwarded to ``schemas`` as the strategy of the field types.
    rows : int or strategy, optional
        Length of every child array. When omitted it is drawn from
        ``_default_array_sizes``.
    max_fields : int, optional
        Forwarded to ``schemas``.

    Raises
    ------
    TypeError
        If ``rows`` is given but is neither a strategy nor an integer.
    """
    if rows is None:
        rows = draw(_default_array_sizes)
    elif isinstance(rows, st.SearchStrategy):
        rows = draw(rows)
    elif not isinstance(rows, int):
        raise TypeError('Rows must be an integer')

    schema = draw(schemas(type, max_fields=max_fields))

    # One array per schema field, all of the same length, drawn in field order.
    children = []
    for field in schema:
        children.append(draw(arrays(field.type, size=rows)))

    # The schema object itself is handed over through the `names` keyword.
    return pa.RecordBatch.from_arrays(children, names=schema)
238+
239+
240+
# Ready-made strategies over `all_types`; each factory is called with its
# default size/chunk arguments.
all_arrays = arrays(all_types)
241+
all_chunked_arrays = chunked_arrays(all_types)
242+
all_columns = columns(all_types)
243+
all_record_batches = record_batches(all_types)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import hypothesis as h
2+
import hypothesis.strategies as st
3+
4+
import pyarrow as pa
5+
import pyarrow.tests.strategies as past
6+
7+
8+
@h.given(past.all_types)
def test_types(ty):
    """Every value drawn from ``all_types`` is a pyarrow DataType."""
    assert isinstance(ty, pa.lib.DataType)
11+
12+
13+
@h.given(past.all_fields)
def test_fields(field):
    """Every value drawn from ``all_fields`` is a pyarrow Field."""
    assert isinstance(field, pa.lib.Field)
16+
17+
18+
@h.given(past.all_schemas)
def test_schemas(schema):
    """Every value drawn from ``all_schemas`` is a pyarrow Schema."""
    assert isinstance(schema, pa.lib.Schema)
21+
22+
23+
@h.given(past.all_arrays)
def test_arrays(array):
    """Every value drawn from ``all_arrays`` is a pyarrow Array."""
    assert isinstance(array, pa.lib.Array)
26+
27+
28+
@h.given(past.all_chunked_arrays)
def test_chunked_arrays(chunked_array):
    """Every value drawn from ``all_chunked_arrays`` is a ChunkedArray."""
    assert isinstance(chunked_array, pa.lib.ChunkedArray)
31+
32+
33+
@h.given(past.all_columns)
def test_columns(column):
    """Every value drawn from ``all_columns`` is a pyarrow Column."""
    assert isinstance(column, pa.lib.Column)
36+
37+
38+
@h.given(past.all_record_batches)
def test_record_batches(record_batch):
    """Every value drawn from ``all_record_batches`` is a RecordBatch."""
    # The parameter was misspelled `record_bath`; hypothesis supplies the
    # value, so the rename does not affect any caller.
    assert isinstance(record_batch, pa.lib.RecordBatch)
41+
42+
43+
############################################################
44+
45+
46+
@h.given(st.text(), past.all_arrays | past.all_chunked_arrays)
def test_column_factory(name, arr):
    """``pa.column`` builds a Column from either an array or a chunked array."""
    assert isinstance(pa.column(name, arr), pa.Column)

0 commit comments

Comments
 (0)