Commit fd51c9c

Add batched tests (#73)
1 parent 5af1294

File tree

2 files changed (+32 -10 lines changed)


stac_geoparquet/arrow/_api.py

Lines changed: 4 additions & 0 deletions
@@ -146,6 +146,10 @@ def stac_table_to_ndjson(
 ) -> None:
     """Write STAC Arrow to a newline-delimited JSON file.
 
+    !!! note
+        This function _appends_ to the JSON file at `dest`; it does not overwrite any
+        existing data.
+
     Args:
         table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow
             RecordBatchReader, or any other Arrow stream object exposed through the
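The note added above documents append semantics for `stac_table_to_ndjson`. A minimal sketch of what that implies (the file names `items.json` and `out.ndjson` are hypothetical; `items.json` is assumed to hold a JSON array of STAC Items):

import json

from stac_geoparquet.arrow import parse_stac_items_to_arrow, stac_table_to_ndjson

with open("items.json") as f:  # hypothetical input file
    items = json.load(f)

table = parse_stac_items_to_arrow(items).read_all()

# Because stac_table_to_ndjson appends, writing twice to the same
# destination leaves every item in the file twice.
stac_table_to_ndjson(table, "out.ndjson")
stac_table_to_ndjson(table, "out.ndjson")

with open("out.ndjson") as f:
    assert sum(1 for _ in f) == 2 * len(items)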

tests/test_arrow.py

Lines changed: 28 additions & 10 deletions
@@ -1,3 +1,4 @@
+import itertools
 import json
 from io import BytesIO
 from pathlib import Path
@@ -7,6 +8,7 @@
 import pytest
 
 from stac_geoparquet.arrow import (
+    DEFAULT_JSON_CHUNK_SIZE,
     parse_stac_items_to_arrow,
     parse_stac_ndjson_to_arrow,
     stac_table_to_items,
@@ -33,38 +35,54 @@
     "us-census",
 ]
 
+CHUNK_SIZES = [2, DEFAULT_JSON_CHUNK_SIZE]
 
-@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS)
-def test_round_trip_read_write(collection_id: str):
+
+@pytest.mark.parametrize(
+    "collection_id,chunk_size", itertools.product(TEST_COLLECTIONS, CHUNK_SIZES)
+)
+def test_round_trip_read_write(collection_id: str, chunk_size: int):
     with open(HERE / "data" / f"{collection_id}-pc.json") as f:
         items = json.load(f)
 
-    table = pa.Table.from_batches(parse_stac_items_to_arrow(items))
+    table = parse_stac_items_to_arrow(items, chunk_size=chunk_size).read_all()
     items_result = list(stac_table_to_items(table))
 
     for result, expected in zip(items_result, items):
         assert_json_value_equal(result, expected, precision=0)
 
 
-@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS)
-def test_round_trip_write_read_ndjson(collection_id: str, tmp_path: Path):
+@pytest.mark.parametrize(
+    "collection_id,chunk_size", itertools.product(TEST_COLLECTIONS, CHUNK_SIZES)
+)
+def test_round_trip_write_read_ndjson(
+    collection_id: str, chunk_size: int, tmp_path: Path
+):
     # First load into a STAC-GeoParquet table
     path = HERE / "data" / f"{collection_id}-pc.json"
-    table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(path))
+    table = parse_stac_ndjson_to_arrow(path, chunk_size=chunk_size).read_all()
 
     # Then write to disk
     stac_table_to_ndjson(table, tmp_path / "tmp.ndjson")
 
-    # Then read back and assert tables match
-    table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(tmp_path / "tmp.ndjson"))
+    with open(path) as f:
+        orig_json = json.load(f)
+
+    rt_json = []
+    with open(tmp_path / "tmp.ndjson") as f:
+        for line in f:
+            rt_json.append(json.loads(line))
+
+    # Then read back and assert JSON data matches
+    assert_json_value_equal(orig_json, rt_json, precision=0)
 
 
 def test_table_contains_geoarrow_metadata():
     collection_id = "naip"
     with open(HERE / "data" / f"{collection_id}-pc.json") as f:
         items = json.load(f)
 
-    table = pa.Table.from_batches(parse_stac_items_to_arrow(items))
+    table = parse_stac_items_to_arrow(items).read_all()
     field_meta = table.schema.field("geometry").metadata
     assert field_meta[b"ARROW:extension:name"] == b"geoarrow.wkb"
     assert json.loads(field_meta[b"ARROW:extension:metadata"])["crs"]["id"] == {
@@ -107,7 +125,7 @@ def test_to_parquet_two_geometry_columns():
     with open(HERE / "data" / "3dep-lidar-copc-pc.json") as f:
         items = json.load(f)
 
-    table = pa.Table.from_batches(parse_stac_items_to_arrow(items))
+    table = parse_stac_items_to_arrow(items).read_all()
     with BytesIO() as bio:
         to_parquet(table, bio)
         bio.seek(0)
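These test changes track the API exercised here: `parse_stac_items_to_arrow` and `parse_stac_ndjson_to_arrow` return an Arrow record-batch stream that yields items in chunks of `chunk_size`, so callers either iterate it batch by batch or call `.read_all()` for a single table. A minimal sketch under that reading (`items.ndjson` is a hypothetical newline-delimited JSON file):

from stac_geoparquet.arrow import DEFAULT_JSON_CHUNK_SIZE, parse_stac_ndjson_to_arrow

# A small chunk_size forces several batches, mirroring the tests'
# CHUNK_SIZES = [2, DEFAULT_JSON_CHUNK_SIZE] parametrization.
reader = parse_stac_ndjson_to_arrow("items.ndjson", chunk_size=2)

num_items = 0
for batch in reader:  # each RecordBatch holds at most chunk_size items
    num_items += batch.num_rows

# Or materialize everything at once, as the tests do:
table = parse_stac_ndjson_to_arrow(
    "items.ndjson", chunk_size=DEFAULT_JSON_CHUNK_SIZE
).read_all()

Parametrizing with `itertools.product(TEST_COLLECTIONS, CHUNK_SIZES)` runs each round-trip test once per (collection_id, chunk_size) pair, covering both the many-small-batches case and the default chunking for every collection.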
