Skip to content

Commit 63df203

Browse files
committed
Using enum class for task
1 parent 965c99a commit 63df203

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

stac_model/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pystac import Asset, Item, Link
55
from pystac.extensions.eo import Band, EOExtension
66

7+
from stac_model.base import TaskEnum
78
from stac_model.input import InputStructure, ModelInput
89
from stac_model.output import MLMClassification, ModelOutput, ModelResult
910
from stac_model.schema import ItemMLModelExtension, MLModelExtension, MLModelProperties
@@ -44,7 +45,7 @@ def from_torch(
4445
) -> ItemMLModelExtension:
4546
total_params = sum(p.numel() for p in model.parameters())
4647
arch = f"{model.__class__.__module__}.{model.__class__.__name__}"
47-
task = {"classification"}
48+
task = {TaskEnum.CLASSIFICATION}
4849

4950
# Extra metadata only found in weights of torchgeo models
5051
has_meta = weights is not None and hasattr(weights, "meta")
@@ -76,7 +77,6 @@ def from_torch(
7677
bands=bands,
7778
input=input_struct,
7879
resize_type=None,
79-
value_scaling=None,
8080
pre_processing_function=None,
8181
)
8282

0 commit comments

Comments
 (0)