Skip to content

Commit 93c718e

Browse files
committed
Fixing linting and adding eo bands to pass test
1 parent 6470e6a commit 93c718e

File tree

5 files changed

+60
-11
lines changed

5 files changed

+60
-11
lines changed

examples/collection.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@
7575
{
7676
"href": "item_multi_io.json",
7777
"rel": "item"
78+
},
79+
{
80+
"href": "torchgeo.json",
81+
"rel": "item"
7882
}
7983
]
8084
}

examples/torchgeo.json

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,32 @@
113113
"type": "application/octet-stream; application=pytorch",
114114
"title": "U-Net (efficientnet-b3) weights trained on FTW dataset",
115115
"description": "A U-Net segmentation model with efficientnet-b3 encoder trained on FTW imagery with 2-class labels. Weights are non-commercial.",
116+
"eo:bands": [
117+
{
118+
"name": "B4"
119+
},
120+
{
121+
"name": "B3"
122+
},
123+
{
124+
"name": "B2"
125+
},
126+
{
127+
"name": "B8A"
128+
},
129+
{
130+
"name": "B4"
131+
},
132+
{
133+
"name": "B3"
134+
},
135+
{
136+
"name": "B2"
137+
},
138+
{
139+
"name": "B8A"
140+
}
141+
],
116142
"roles": [
117143
"mlm:model",
118144
"mlm:weights",
@@ -147,6 +173,7 @@
147173
58.21
148174
],
149175
"stac_extensions": [
150-
"https://crim-ca.github.io/mlm-extension/v1.3.0/schema.json"
176+
"https://crim-ca.github.io/mlm-extension/v1.3.0/schema.json",
177+
"https://stac-extensions.github.io/eo/v1.1.0/schema.json"
151178
]
152179
}

stac_model/examples.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
from dateutil.parser import parse as parse_dt
66
from pystac.extensions.eo import Band, EOExtension
77
from pystac.extensions.file import FileExtension
8+
from torchgeo.models import Unet_Weights, unet
89

910
from stac_model.base import ProcessingExpression
1011
from stac_model.input import InputStructure, MLMStatistic, ModelInput
1112
from stac_model.output import MLMClassification, ModelOutput, ModelResult
1213
from stac_model.schema import ItemMLModelExtension, MLModelExtension, MLModelProperties
1314

14-
from torchgeo.models import unet, Unet_Weights
15-
1615

1716
def eurosat_resnet() -> ItemMLModelExtension:
1817
input_struct = InputStructure(
@@ -242,7 +241,9 @@ def eurosat_resnet() -> ItemMLModelExtension:
242241

243242
def unet_mlm() -> ItemMLModelExtension:
244243
"""
245-
Example of a UNet model using pytorchgeo SENTINEL2_2CLASS_NC_FTW default weights with Machine Learning Model Extension metadata.
244+
Example of a UNet model using PyTorchGeo SENTINEL2_2CLASS_NC_FTW default weights.
245+
246+
Returns an ItemMLModelExtension with Machine Learning Model Extension metadata.
246247
"""
247248
weights = Unet_Weights.SENTINEL2_2CLASS_NC_FTW
248249
model = unet(weights=weights)

stac_model/schema.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717

1818
import pystac
19+
import torch.nn as nn
1920
from pydantic import ConfigDict, Field
2021
from pydantic.fields import FieldInfo
2122
from pystac.extensions.base import (
@@ -29,8 +30,6 @@
2930
from stac_model.output import ModelOutput
3031
from stac_model.runtime import Runtime
3132

32-
import torch.nn as nn
33-
3433
T = TypeVar(
3534
"T",
3635
pystac.Collection,

stac_model/utils.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
from typing import Optional, cast
2+
3+
import torch.nn as nn
14
from pystac import Asset, Item, Link
5+
from pystac.extensions.eo import Band, EOExtension
6+
27
from stac_model.input import InputStructure, ModelInput
38
from stac_model.output import MLMClassification, ModelOutput, ModelResult
49
from stac_model.schema import ItemMLModelExtension, MLModelExtension, MLModelProperties
510

6-
import torch.nn as nn
7-
from typing import Optional
8-
911

1012
def get_input_channels(state_dict: dict) -> int:
1113
"""
@@ -131,7 +133,10 @@ def from_torch(
131133

132134
# Model weights asset
133135
assets["model"] = Asset(
134-
title=f"{meta.get('model', 'Model')} ({meta.get('encoder', '')}) weights trained on {meta.get('dataset', 'dataset')} dataset",
136+
title=(
137+
f"{meta.get('model', 'Model')} ({meta.get('encoder', '')}) weights "
138+
f"trained on {meta.get('dataset', 'dataset')} dataset"
139+
),
135140
description=(
136141
f"A {meta.get('model', 'Model')} segmentation model with {meta.get('encoder', '')} encoder "
137142
f"trained on {meta.get('dataset', 'dataset')} imagery with {meta.get('num_classes', '?')}-class labels. "
@@ -194,4 +199,17 @@ def from_torch(
194199
ext = MLModelExtension.ext(item, add_if_missing=True)
195200
ext.apply(mlm_props)
196201

197-
return ItemMLModelExtension(item)
202+
eo_model_asset = cast(
203+
EOExtension[Asset],
204+
EOExtension.ext(assets["model"], add_if_missing=True),
205+
)
206+
eo_bands = []
207+
for name in bands:
208+
band = Band({})
209+
band.apply(name=name)
210+
eo_bands.append(band)
211+
eo_model_asset.apply(bands=eo_bands)
212+
213+
item_mlm = MLModelExtension.ext(item, add_if_missing=True)
214+
item_mlm.apply(mlm_props.model_dump(by_alias=True, exclude_unset=True, exclude_defaults=True))
215+
return item_mlm

0 commit comments

Comments
 (0)