Skip to content

Commit 0db8b95

Browse files
trajeplDavitGrigoryan132
authored andcommitted
📽️ Olive StrEnumBase IntEnumBase (microsoft#1290)
## Describe your changes python/cpython#100458 Quarot require python>=3.11 where the mixin usage of (str, Enum) did not work. This PR is used to create olive strEnum based on python version. ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. - [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR. ## (Optional) Issue link
1 parent 5668cbf commit 0db8b95

File tree

21 files changed

+82
-64
lines changed

21 files changed

+82
-64
lines changed

examples/utils/generator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5-
from enum import Enum
65
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
76

87
import numpy as np
@@ -11,13 +10,15 @@
1110
from onnxruntime import InferenceSession, OrtValue, SessionOptions
1211
from transformers import PreTrainedTokenizer
1312

13+
from olive.common.utils import StrEnumBase
14+
1415
if TYPE_CHECKING:
1516
from kv_cache_utils import Cache, IOBoundCache
1617
from numpy.typing import NDArray
1718
from onnx import ValueInfoProto
1819

1920

20-
class AdapterMode(Enum):
21+
class AdapterMode(StrEnumBase):
2122
"""Enum for adapter modes."""
2223

2324
inputs = "inputs"

olive/auto_optimizer/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
import logging
77
from copy import deepcopy
8-
from enum import Enum
98
from typing import List, Optional
109

1110
from olive.auto_optimizer.regulate_mixins import RegulatePassConfigMixin
1211
from olive.common.config_utils import ConfigBase
1312
from olive.common.pydantic_v1 import validator
13+
from olive.common.utils import StrEnumBase
1414
from olive.data.config import DataConfig
1515
from olive.evaluator.olive_evaluator import OliveEvaluatorConfig
1616
from olive.hardware.accelerator import AcceleratorSpec
@@ -19,7 +19,7 @@
1919
logger = logging.getLogger(__name__)
2020

2121

22-
class Precision(str, Enum):
22+
class Precision(StrEnumBase):
2323
FP32 = "fp32"
2424
FP16 = "fp16"
2525
INT8 = "int8"

olive/common/config_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import inspect
66
import json
77
import logging
8-
from enum import Enum
98
from functools import partial
109
from pathlib import Path
1110
from types import FunctionType, MethodType
@@ -14,7 +13,7 @@
1413
import yaml
1514

1615
from olive.common.pydantic_v1 import BaseModel, create_model, root_validator, validator
17-
from olive.common.utils import hash_function, hash_object
16+
from olive.common.utils import StrEnumBase, hash_function, hash_object
1817

1918
logger = logging.getLogger(__name__)
2019

@@ -212,7 +211,7 @@ def gather_nested_field(cls, values):
212211
return values
213212

214213

215-
class CaseInsensitiveEnum(str, Enum):
214+
class CaseInsensitiveEnum(StrEnumBase):
216215
"""StrEnum class that is insensitive to the case of the input string.
217216
218217
Note: Only insensitive when creating the enum object like `CaseInsensitiveEnum("value")`.

olive/common/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5-
from enum import Enum
5+
from olive.common.utils import StrEnumBase
66

77

8-
class OS(str, Enum):
8+
class OS(StrEnumBase):
99
WINDOWS = "Windows"
1010
LINUX = "Linux"
1111

olive/common/utils.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,39 @@
1313
import shlex
1414
import shutil
1515
import subprocess
16+
import sys
1617
import tempfile
1718
import time
1819
from pathlib import Path
1920
from typing import Dict, List, Optional, Tuple, Union
2021

21-
from olive.common.constants import OS
22-
2322
logger = logging.getLogger(__name__)
2423

2524

25+
if sys.version_info >= (3, 11):
26+
from enum import IntEnum, StrEnum
27+
28+
class StrEnumBase(StrEnum):
29+
pass
30+
31+
class IntEnumBase(IntEnum):
32+
pass
33+
34+
else:
35+
from enum import Enum
36+
37+
class StrEnumBase(str, Enum):
38+
pass
39+
40+
class IntEnumBase(int, Enum):
41+
pass
42+
43+
2644
def run_subprocess(cmd, env=None, cwd=None, check=False):
2745
logger.debug("Running command: %s", cmd)
2846

2947
assert isinstance(cmd, (str, list)), f"cmd must be a string or a list, got {type(cmd)}."
30-
windows = platform.system() == OS.WINDOWS
48+
windows = platform.system() == "Windows"
3149
if isinstance(cmd, str):
3250
# In posix model, the cmd string will be handled with specific posix rules.
3351
# https://docs.python.org/3.8/library/shlex.html#parsing-rules

olive/constants.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5-
from enum import Enum
5+
from olive.common.utils import StrEnumBase
66

77

8-
class Framework(str, Enum):
8+
class Framework(StrEnumBase):
99
"""Framework of the model."""
1010

1111
ONNX = "ONNX"
@@ -16,7 +16,7 @@ class Framework(str, Enum):
1616
OPENVINO = "OpenVINO"
1717

1818

19-
class ModelFileFormat(str, Enum):
19+
class ModelFileFormat(StrEnumBase):
2020
"""Given a framework, there might be 1 or more on-disk model file format(s), model save/Load logic may differ."""
2121

2222
ONNX = "ONNX"

olive/data/component/text_generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55

6-
from enum import Enum
76
from pathlib import Path
87
from random import Random
98
from typing import Callable, Dict, List, Optional, Union
@@ -14,11 +13,12 @@
1413
from olive.common.config_utils import ConfigBase, validate_config, validate_object
1514
from olive.common.pydantic_v1 import validator
1615
from olive.common.user_module_loader import UserModuleLoader
16+
from olive.common.utils import StrEnumBase
1717
from olive.data.component.dataset import BaseDataset
1818
from olive.data.constants import IGNORE_INDEX
1919

2020

21-
class TextGenStrategy(str, Enum):
21+
class TextGenStrategy(StrEnumBase):
2222
"""Strategy for tokenizing a dataset."""
2323

2424
LINE_BY_LINE = "line-by-line" # each line is a sequence, in order of appearance

olive/data/constants.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55

6-
from enum import Enum
6+
from olive.common.utils import StrEnumBase
77

88
# index for targets that should be ignored when computing metrics
99
IGNORE_INDEX = -100
1010

1111

12-
class DataComponentType(Enum):
12+
class DataComponentType(StrEnumBase):
1313
"""enumerate for the different types of data components."""
1414

1515
# dataset component type: to load data into memory
@@ -22,13 +22,13 @@ class DataComponentType(Enum):
2222
DATALOADER = "dataloader"
2323

2424

25-
class DataContainerType(Enum):
25+
class DataContainerType(StrEnumBase):
2626
"""enumerate for the different types of data containers."""
2727

2828
DATA_CONTAINER = "data_container"
2929

3030

31-
class DefaultDataComponent(Enum):
31+
class DefaultDataComponent(StrEnumBase):
3232
"""enumerate for the default data components."""
3333

3434
LOAD_DATASET = "default_load_dataset"
@@ -37,7 +37,7 @@ class DefaultDataComponent(Enum):
3737
DATALOADER = "default_dataloader"
3838

3939

40-
class DefaultDataContainer(Enum):
40+
class DefaultDataContainer(StrEnumBase):
4141
"""enumerate for the default data containers."""
4242

4343
DATA_CONTAINER = "DataContainer"

olive/engine/packaging/packaging_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5-
from enum import Enum
65
from typing import Optional, Union
76

87
from olive.common.config_utils import CaseInsensitiveEnum, ConfigBase, NestedConfig, validate_config
98
from olive.common.constants import BASE_IMAGE
109
from olive.common.pydantic_v1 import validator
10+
from olive.common.utils import StrEnumBase
1111

1212

1313
class PackagingType(CaseInsensitiveEnum):
@@ -43,7 +43,7 @@ class DockerfilePackagingConfig(CommonPackagingConfig):
4343
requirements_file: Optional[str] = None
4444

4545

46-
class InferencingServerType(str, Enum):
46+
class InferencingServerType(StrEnumBase):
4747
AzureMLOnline = "AzureMLOnline"
4848
AzureMLBatch = "AzureMLBatch"
4949

@@ -54,7 +54,7 @@ class InferenceServerConfig(ConfigBase):
5454
scoring_script: str
5555

5656

57-
class AzureMLModelModeType(str, Enum):
57+
class AzureMLModelModeType(StrEnumBase):
5858
download = "download"
5959
copy = "copy"
6060

olive/evaluator/metric.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,27 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55
import logging
6-
from enum import Enum
76
from typing import Any, Dict, List, Optional, Union
87

98
from olive.common.config_utils import ConfigBase, validate_config
109
from olive.common.pydantic_v1 import validator
10+
from olive.common.utils import StrEnumBase
1111
from olive.data.config import DataConfig
1212
from olive.evaluator.accuracy import AccuracyBase
1313
from olive.evaluator.metric_config import LatencyMetricConfig, MetricGoal, ThroughputMetricConfig, get_user_config_class
1414

1515
logger = logging.getLogger(__name__)
1616

1717

18-
class MetricType(str, Enum):
18+
class MetricType(StrEnumBase):
1919
# TODO(trajep): support throughput
2020
ACCURACY = "accuracy"
2121
LATENCY = "latency"
2222
THROUGHPUT = "throughput"
2323
CUSTOM = "custom"
2424

2525

26-
class AccuracySubType(str, Enum):
26+
class AccuracySubType(StrEnumBase):
2727
ACCURACY_SCORE = "accuracy_score"
2828
F1_SCORE = "f1_score"
2929
PRECISION = "precision"
@@ -32,7 +32,7 @@ class AccuracySubType(str, Enum):
3232
PERPLEXITY = "perplexity"
3333

3434

35-
class LatencySubType(str, Enum):
35+
class LatencySubType(StrEnumBase):
3636
# unit: millisecond
3737
AVG = "avg"
3838
MAX = "max"
@@ -45,7 +45,7 @@ class LatencySubType(str, Enum):
4545
P999 = "p999"
4646

4747

48-
class ThroughputSubType(str, Enum):
48+
class ThroughputSubType(StrEnumBase):
4949
# unit: token per second, tps
5050
AVG = "avg"
5151
MAX = "max"

0 commit comments

Comments
 (0)