Skip to content

Commit f061e6e

Browse files
committed
hypothesis array strategy
1 parent 0696eb5 commit f061e6e

File tree

2 files changed

+84
-7
lines changed

2 files changed

+84
-7
lines changed

python/pyarrow/tests/strategies.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,14 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import pyarrow as pa
18+
import pytz
19+
import hypothesis as h
1920
import hypothesis.strategies as st
21+
import hypothesis.extra.numpy as npst
22+
import hypothesis.extra.pytz as tzst
23+
import numpy as np
24+
25+
import pyarrow as pa
2026

2127

2228
# TODO(kszucs): alphanum_text, surrogate_text
@@ -69,12 +75,11 @@
6975
pa.time64('us'),
7076
pa.time64('ns')
7177
])
72-
timestamp_types = st.sampled_from([
73-
pa.timestamp('s'),
74-
pa.timestamp('ms'),
75-
pa.timestamp('us'),
76-
pa.timestamp('ns')
77-
])
78+
timestamp_types = st.builds(
79+
pa.timestamp,
80+
unit=st.sampled_from(['s', 'ms', 'us', 'ns']),
81+
tz=tzst.timezones()
82+
)
7883
temporal_types = st.one_of(date_types, time_types, timestamp_types)
7984

8085
primitive_types = st.one_of(
@@ -128,3 +133,60 @@ def schemas(type_strategy=primitive_types):
128133
all_types = st.one_of(primitive_types, complex_types(), nested_complex_types())
129134
all_fields = fields(all_types)
130135
all_schemas = schemas(all_types)
136+
137+
138+
@st.composite
def arrays(draw, type, size):
    """Hypothesis strategy generating a pyarrow array of a given type and size.

    Parameters
    ----------
    draw : callable
        Draw function injected by ``st.composite``; callers do not pass it.
    type : pyarrow.DataType or hypothesis SearchStrategy
        The array type, or a strategy from which the type is drawn.
    size : int or hypothesis SearchStrategy
        The number of elements, or a strategy from which it is drawn.

    Returns
    -------
    A pyarrow array holding ``size`` elements.

    Raises
    ------
    TypeError
        If the resolved ``type`` is not a ``pa.DataType`` or the resolved
        ``size`` is not an ``int``.
    NotImplementedError
        If no generation rule exists for ``type``.
    """
    # Both arguments may be given as strategies; resolve them to concrete
    # values before validating.
    if isinstance(type, st.SearchStrategy):
        type = draw(type)
    if isinstance(size, st.SearchStrategy):
        size = draw(size)

    if not isinstance(type, pa.DataType):
        raise TypeError('Type must be a pyarrow DataType')
    if not isinstance(size, int):
        raise TypeError('Size must be an integer')

    shape = (size,)

    if pa.types.is_list(type):
        # Cumulative sum of non-negative values gives a non-decreasing
        # offsets sequence; the floor division by 20 presumably keeps the
        # individual lists short.
        offsets = draw(npst.arrays(np.uint8(), shape=shape)).cumsum() // 20
        offsets = np.insert(offsets, 0, 0, axis=0)  # prepend with zero
        # The final offset is the number of child values referenced by the
        # list array, so the child array needs exactly that many elements
        # (summing all offsets would over-allocate the child array).
        values = draw(arrays(type.value_type, size=int(offsets[-1])))
        return pa.ListArray.from_arrays(offsets, values)

    if pa.types.is_struct(type):
        h.assume(len(type) > 0)  # TODO issue pa.struct([])
        names, child_arrays = [], []
        for field in type:
            names.append(field.name)
            # Every child array shares the length of the struct array.
            child_arrays.append(draw(arrays(field.type, size=size)))
        return pa.StructArray.from_arrays(child_arrays, names=names)

    if (pa.types.is_boolean(type) or pa.types.is_integer(type) or
            pa.types.is_floating(type)):
        # Primitive types are generated directly as numpy arrays.
        values = npst.arrays(type.to_pandas_dtype(), shape=shape)
        return pa.array(draw(values), type=type)

    # Remaining types are built from a list of python objects drawn from an
    # element strategy chosen below.
    if pa.types.is_null(type):
        value = st.none()
    elif pa.types.is_time(type):
        value = st.times()
    elif pa.types.is_date(type):
        value = st.dates()
    elif pa.types.is_timestamp(type):
        tz = pytz.timezone(type.tz) if type.tz is not None else None
        value = st.datetimes(timezones=st.just(tz))
    elif pa.types.is_binary(type):
        value = st.binary()
    elif pa.types.is_string(type):
        value = st.text()
    elif pa.types.is_decimal(type):
        # FIXME(kszucs): properly limit the precision
        value = st.decimals(places=type.scale, allow_infinity=False)
        type = None  # We let arrow infer it from the values
    else:
        raise NotImplementedError(type)

    # min_size == max_size pins the list length to exactly ``size``.
    values = st.lists(value, min_size=size, max_size=size)
    return pa.array(draw(values), type=type)

python/pyarrow/tests/test_array.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import collections
2020
import datetime
21+
import hypothesis as h
22+
import hypothesis.strategies as st
2123
import pickle
2224
import pytest
2325
import struct
@@ -32,6 +34,7 @@
3234
pickle5 = None
3335

3436
import pyarrow as pa
37+
import pyarrow.tests.strategies as past
3538
from pyarrow.pandas_compat import get_logical_type
3639

3740

@@ -802,6 +805,18 @@ def test_array_pickle(data, typ):
802805
assert array.equals(result)
803806

804807

808+
@h.given(
    past.arrays(
        past.all_types,
        size=st.integers(min_value=0, max_value=10)
    )
)
def test_pickling(arr):
    """Check that a generated array equals its pickle round trip."""
    roundtripped = pickle.loads(pickle.dumps(arr))
    assert arr.equals(roundtripped)
818+
819+
805820
@pickle_test_parametrize
806821
def test_array_pickle5(data, typ):
807822
# Test zero-copy pickling with protocol 5 (PEP 574)

0 commit comments

Comments
 (0)