4 changes: 2 additions & 2 deletions .github/workflows/continuous-integration.yml
@@ -23,12 +23,12 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Sync
run: uv sync --all-extras
run: scripts/install
- name: Pre-commit
run: uv run pre-commit run --all-files
- name: Lint
run: scripts/lint
- name: Test
run: uv run pytest tests -v
run: scripts/test
- name: Check docs
run: uv run mkdocs build --strict
10 changes: 7 additions & 3 deletions README.md
@@ -30,13 +30,17 @@ Install via `pip` or `conda`:

## Development

Get [uv](https://docs.astral.sh/uv/getting-started/installation/), then:
1. Create a Python virtual environment
1. Get [uv](https://docs.astral.sh/uv/getting-started/installation/):
1. Get [cargo](https://doc.rust-lang.org/cargo/getting-started/installation.html):

Then run:

```shell
git clone [email protected]:stac-utils/stac-geoparquet.git
cd stac-geoparquet
uv sync
scripts/install
uv run pre-commit install
uv run pytest
scripts/test
scripts/lint
```
5 changes: 5 additions & 0 deletions scripts/install
@@ -0,0 +1,5 @@
#!/usr/bin/env sh

set -e

uv sync --all-extras
5 changes: 5 additions & 0 deletions scripts/test
@@ -0,0 +1,5 @@
#!/usr/bin/env sh

set -e

uv run pytest tests -v
72 changes: 55 additions & 17 deletions stac_geoparquet/pgstac_reader.py
@@ -7,10 +7,11 @@
import itertools
import logging
import textwrap
from typing import Any
from typing import Any, Literal

import dateutil.tz
import fsspec
import orjson
import pandas as pd
import pyarrow.fs
import pypgstac.db
@@ -23,6 +24,7 @@

logger = logging.getLogger(__name__)

EXPORT_FORMAT = Literal["geoparquet", "ndjson"]

def _pairwise(
iterable: collections.abc.Iterable,
@@ -148,32 +150,31 @@ def export_partition(
storage_options: dict[str, Any] | None = None,
rewrite: bool = False,
skip_empty_partitions: bool = False,
format: EXPORT_FORMAT = "geoparquet",
) -> str | None:
storage_options = storage_options or {}
az_fs = fsspec.filesystem(output_protocol, **storage_options)
if az_fs.exists(output_path) and not rewrite:
logger.debug("Path %s already exists.", output_path)
return output_path

db = pypgstac.db.PgstacDB(conninfo)
with db:
assert db.connection is not None
db.connection.execute("set statement_timeout = 300000;")
# logger.debug("Reading base item")
# TODO: proper escaping
base_item = db.query_one(
f"select * from collection_base_item('{self.collection_id}');"
)
records = list(db.query(query))
base_item, records = _enumerate_db_items(self.collection_id, conninfo, query)

if skip_empty_partitions and len(records) == 0:
logger.debug("No records found for query %s.", query)
return None

items = self.make_pgstac_items(records, base_item) # type: ignore[arg-type]
df = to_geodataframe(items)
filesystem = pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(az_fs))
df.to_parquet(output_path, index=False, filesystem=filesystem)

logger.debug("Exporting %d items as %s to %s", len(items), format, output_path)
if format == "geoparquet":
df = to_geodataframe(items)
filesystem = pyarrow.fs.PyFileSystem(pyarrow.fs.FSSpecHandler(az_fs))
df.to_parquet(output_path, index=False, filesystem=filesystem)
elif format == "ndjson":
_write_ndjson(output_path, az_fs, items)
else:
raise ValueError(f"Unsupported export format: {format}")
return output_path

def export_partition_for_endpoints(
@@ -187,6 +188,7 @@ def export_partition_for_endpoints(
total: int | None = None,
rewrite: bool = False,
skip_empty_partitions: bool = False,
format: EXPORT_FORMAT = "geoparquet",
) -> str | None:
"""
Export results for a pair of endpoints.
@@ -205,7 +207,9 @@
+ f"and datetime >= '{a.isoformat()}' and datetime < '{b.isoformat()}'"
)

partition_path = _build_output_path(output_path, part_number, total, a, b)
partition_path = _build_output_path(
output_path, part_number, total, a, b, format=format
)
return self.export_partition(
conninfo,
query,
@@ -214,6 +218,7 @@
storage_options=storage_options,
rewrite=rewrite,
skip_empty_partitions=skip_empty_partitions,
format=format,
)

def export_collection(
@@ -224,6 +229,7 @@ def export_collection(
storage_options: dict[str, Any],
rewrite: bool = False,
skip_empty_partitions: bool = False,
format: EXPORT_FORMAT = "geoparquet",
) -> list[str | None]:
base_query = textwrap.dedent(
f"""\
@@ -246,6 +252,7 @@ def export_collection(
output_path,
storage_options=storage_options,
rewrite=rewrite,
format=format,
)
]

@@ -269,6 +276,7 @@ def export_collection(
skip_empty_partitions=skip_empty_partitions,
part_number=i,
total=total,
format=format,
)
)

@@ -340,20 +348,50 @@ def _build_output_path(
total: int | None,
start_datetime: datetime.datetime,
end_datetime: datetime.datetime,
format: EXPORT_FORMAT = "geoparquet",
) -> str:
a, b = start_datetime, end_datetime
base_output_path = base_output_path.rstrip("/")
file_extensions = {
"geoparquet": "parquet",
"ndjson": "ndjson",
}

if part_number is not None and total is not None:
output_path = (
f"{base_output_path}/part-{part_number:0{len(str(total * 10))}}_"
f"{a.isoformat()}_{b.isoformat()}.parquet"
f"{a.isoformat()}_{b.isoformat()}.{file_extensions[format]}"
)
else:
token = hashlib.md5(
"".join([a.isoformat(), b.isoformat()]).encode()
).hexdigest()
output_path = (
f"{base_output_path}/part-{token}_{a.isoformat()}_{b.isoformat()}.parquet"
f"{base_output_path}/part-{token}_{a.isoformat()}_{b.isoformat()}.{file_extensions[format]}"
)
return output_path

def _enumerate_db_items(
collection_id: str,
conninfo: str,
query: str) -> tuple[Any, list[Any]]:
db = pypgstac.db.PgstacDB(conninfo)
with db:
assert db.connection is not None
db.connection.execute("set statement_timeout = 300000;")
# logger.debug("Reading base item")
# TODO: proper escaping
base_item = db.query_one(
f"select * from collection_base_item('{collection_id}');"
)
records = list(db.query(query))
return base_item, records

def _write_ndjson(
output_path: str,
fs: fsspec.AbstractFileSystem,
items: list[dict]) -> None:
with fs.open(output_path, "wb") as f:
for item in items:
f.write(orjson.dumps(item))
f.write(b"\n")
53 changes: 43 additions & 10 deletions tests/test_pgstac_reader.py
@@ -4,6 +4,7 @@
import sys

import dateutil
import fsspec
import pandas as pd
import pystac
import pytest
@@ -14,6 +15,19 @@

HERE = pathlib.Path(__file__).parent

@pytest.fixture
def sentinel2_collection_config() -> stac_geoparquet.pgstac_reader.CollectionConfig:
return stac_geoparquet.pgstac_reader.CollectionConfig(
collection_id="sentinel-2-l2a",
partition_frequency=None,
stac_api="https://planetarycomputer.microsoft.com/api/stac/v1",
should_inject_dynamic_properties=True,
render_config="assets=visual&asset_bidx=visual%7C1%2C2%2C3&nodata=0&format=png",
)

@pytest.fixture
def sentinel2_record():
return json.loads(HERE.joinpath("record_sentinel2_l2a.json").read_text())

@pytest.mark.vcr
@pytest.mark.skipif(
@@ -124,20 +138,15 @@ def test_naip_item():
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="vcr tests require python3.10 or higher"
)
def test_sentinel2_l2a():
record = json.loads(HERE.joinpath("record_sentinel2_l2a.json").read_text())
def test_sentinel2_l2a(
sentinel2_collection_config: stac_geoparquet.pgstac_reader.CollectionConfig,
sentinel2_record) -> None:
record = sentinel2_record
base_item = json.loads(HERE.joinpath("base_sentinel2_l2a.json").read_text())
record[3] = dateutil.parser.parse(record[3])
record[4] = dateutil.parser.parse(record[4])

config = stac_geoparquet.pgstac_reader.CollectionConfig(
collection_id="sentinel-2-l2a",
partition_frequency=None,
stac_api="https://planetarycomputer.microsoft.com/api/stac/v1",
should_inject_dynamic_properties=True,
render_config="assets=visual&asset_bidx=visual%7C1%2C2%2C3&nodata=0&format=png",
)
result = pystac.read_dict(config.make_pgstac_items([record], base_item)[0])
result = pystac.read_dict(sentinel2_collection_config.make_pgstac_items([record], base_item)[0])
expected = pystac.read_file(
"https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2A_MSIL2A_20150704T101006_R022_T35XQA_20210411T133707" # noqa: E501
)
@@ -199,3 +208,27 @@ def test_build_output_path(part_number, total, start_datetime, end_datetime, exp
base_output_path, part_number, total, start_datetime, end_datetime
)
assert result == expected

def test_write_ndjson(
tmp_path,
sentinel2_collection_config: stac_geoparquet.pgstac_reader.CollectionConfig,
sentinel2_record) -> None:
record = sentinel2_record
base_item = json.loads(HERE.joinpath("base_sentinel2_l2a.json").read_text())

items = sentinel2_collection_config.make_pgstac_items(
[record, record], base_item)
fs = fsspec.filesystem("file")
stac_geoparquet.pgstac_reader._write_ndjson(
tmp_path / "test.ndjson",
fs,
items
)
# check that the file has 2 lines
with fs.open(tmp_path / "test.ndjson") as f:
lines = f.readlines()
assert len(lines) == 2
# check that the first line is a valid json
json.loads(lines[0])
# check that the second line is a valid json
json.loads(lines[1])
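As a consumer-side complement to this test, a small sketch (with a hypothetical local path) that reads an ndjson export back, one JSON-encoded STAC item per line, matching what `_write_ndjson` emits:

```python
import orjson

# Hypothetical path to an export produced with format="ndjson".
with open("part-001_items.ndjson", "rb") as f:
    items = [orjson.loads(line) for line in f if line.strip()]

print(f"read {len(items)} items")
```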