|
15 | 15 | # specific language governing permissions and limitations
|
16 | 16 | # under the License.
|
17 | 17 |
|
18 |
| -import pyarrow as pa |
| 18 | +import pytz |
| 19 | +import hypothesis as h |
19 | 20 | import hypothesis.strategies as st
|
| 21 | +import hypothesis.extra.numpy as npst |
| 22 | +import hypothesis.extra.pytz as tzst |
| 23 | +import numpy as np |
| 24 | + |
| 25 | +import pyarrow as pa |
20 | 26 |
|
21 | 27 |
|
22 | 28 | # TODO(kszucs): alphanum_text, surrogate_text
|
|
69 | 75 | pa.time64('us'),
|
70 | 76 | pa.time64('ns')
|
71 | 77 | ])
|
72 |
| -timestamp_types = st.sampled_from([ |
73 |
| - pa.timestamp('s'), |
74 |
| - pa.timestamp('ms'), |
75 |
| - pa.timestamp('us'), |
76 |
| - pa.timestamp('ns') |
77 |
| -]) |
# Timestamp types over every unit, either naive (tz=None) or tz-aware with a
# time zone drawn from hypothesis.extra.pytz. Naive types are included so the
# `type.tz is None` path of timestamp handling keeps being exercised.
timestamp_types = st.builds(
    pa.timestamp,
    unit=st.sampled_from(['s', 'ms', 'us', 'ns']),
    tz=st.one_of(st.none(), tzst.timezones())
)
# Any date, time or timestamp type.
temporal_types = st.one_of(date_types, time_types, timestamp_types)
|
79 | 84 |
|
80 | 85 | primitive_types = st.one_of(
|
@@ -128,3 +133,60 @@ def schemas(type_strategy=primitive_types):
|
# Catch-all strategies: every primitive, complex and nested complex type,
# together with the fields and schemas built from those types.
all_types = st.one_of(
    primitive_types,
    complex_types(),
    nested_complex_types(),
)
all_fields = fields(all_types)
all_schemas = schemas(all_types)
|
| 136 | + |
| 137 | + |
@st.composite
def arrays(draw, type, size):
    """Draw a pyarrow array of the requested type and length.

    Parameters
    ----------
    draw : callable
        Draw function injected by ``st.composite``.
    type : pyarrow.DataType or SearchStrategy
        Type of the generated array, or a strategy that produces one.
    size : int or SearchStrategy
        Length of the generated array, or a strategy that produces one.

    Returns
    -------
    pyarrow.Array
        An array with ``size`` elements.

    Raises
    ------
    TypeError
        If the resolved type is not a ``pa.DataType`` or the resolved size
        is not an ``int``.
    NotImplementedError
        If no value strategy is defined for the type.
    """
    if isinstance(type, st.SearchStrategy):
        type = draw(type)
    if isinstance(size, st.SearchStrategy):
        size = draw(size)

    if not isinstance(type, pa.DataType):
        raise TypeError('Type must be a pyarrow DataType')
    if not isinstance(size, int):
        raise TypeError('Size must be an integer')

    shape = (size,)

    if pa.types.is_list(type):
        # Accumulating random uint8 increments and scaling them down yields
        # non-decreasing offsets.
        offsets = draw(npst.arrays(np.uint8(), shape=shape)).cumsum() // 20
        offsets = np.insert(offsets, 0, 0, axis=0)  # prepend with zero
        # The final offset is the number of child values the lists reference;
        # summing all offsets would draw far more values than are ever used.
        values = draw(arrays(type.value_type, size=int(offsets[-1])))
        return pa.ListArray.from_arrays(offsets, values)

    if pa.types.is_struct(type):
        h.assume(len(type) > 0)  # TODO issue pa.struct([])
        names, child_arrays = [], []
        for field in type:
            names.append(field.name)
            child_arrays.append(draw(arrays(field.type, size=size)))
        return pa.StructArray.from_arrays(child_arrays, names=names)

    if (pa.types.is_boolean(type) or pa.types.is_integer(type) or
            pa.types.is_floating(type)):
        # Numeric data is drawn as a numpy array of the matching dtype.
        values = npst.arrays(type.to_pandas_dtype(), shape=shape)
        return pa.array(draw(values), type=type)

    # Type handed to pa.array; None lets arrow infer it from the values.
    array_type = type

    if pa.types.is_null(type):
        value = st.none()
    elif pa.types.is_time(type):
        value = st.times()
    elif pa.types.is_date(type):
        value = st.dates()
    elif pa.types.is_timestamp(type):
        tz = pytz.timezone(type.tz) if type.tz is not None else None
        value = st.datetimes(timezones=st.just(tz))
    elif pa.types.is_binary(type):
        value = st.binary()
    elif pa.types.is_string(type):
        value = st.text()
    elif pa.types.is_decimal(type):
        # FIXME(kszucs): properly limit the precision
        value = st.decimals(places=type.scale, allow_infinity=False)
        array_type = None  # We let arrow infer it from the values
    else:
        raise NotImplementedError(type)

    values = st.lists(value, min_size=size, max_size=size)
    return pa.array(draw(values), type=array_type)
0 commit comments