|
16 | 16 |
|
17 | 17 | import copy |
18 | 18 | import json |
| 19 | +import math |
19 | 20 | import os |
20 | 21 | import warnings |
21 | 22 | from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union |
|
# Type variable for type hints that refer to a specific config class, i.e. any class that inherits from
# `PreTrainedConfig`.
SpecificPreTrainedConfigType = TypeVar("SpecificPreTrainedConfigType", bound="PreTrainedConfig")
52 | 53 |
|
# Key of the single-entry dict that stands in for a non-finite float in serialized configs.
_FLOAT_TAG_KEY = "__float__"
# Maps each string tag stored under `_FLOAT_TAG_KEY` to the float value it represents.
_FLOAT_TAG_VALUES = {
    "Infinity": float("inf"),
    "-Infinity": float("-inf"),
    "NaN": float("nan"),
}
| 56 | + |
53 | 57 |
|
54 | 58 | class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin): |
55 | 59 | # no-format |
@@ -812,7 +816,56 @@ def from_json_file( |
def _dict_from_json_file(cls, json_file: str | os.PathLike):
    """
    Load the JSON file at `json_file` and return its parsed content.

    The file is read as UTF-8 and parsed with `json.loads`; the parsed object is then passed through
    `cls._decode_special_floats` so that tagged special floats are turned back into float values.
    """
    with open(json_file, encoding="utf-8") as handle:
        raw = handle.read()

    return cls._decode_special_floats(json.loads(raw))
| 822 | + |
@classmethod
def _encode_special_floats(cls, obj: Any) -> Any:
    """
    Recursively replace non-finite floats in `obj` with tagged, JSON-friendly objects.

    Python's JSON engine saves floats like `Infinity` (+/-) or `NaN` as bare tokens, which are not compatible
    with other JSON engines. Each such float is therefore replaced by a single-key dict holding a string tag,
    e.g. `{"__float__": "Infinity"}`, `{"__float__": "-Infinity"}` or `{"__float__": "NaN"}`.

    Dicts are rebuilt with encoded values, lists and tuples are both returned as lists of encoded items, and
    every other object (finite floats included) is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: cls._encode_special_floats(value) for key, value in obj.items()}

    if isinstance(obj, (list, tuple)):
        return [cls._encode_special_floats(item) for item in obj]

    if not isinstance(obj, float):
        return obj

    # NaN never compares equal to anything, so it has to be detected with `math.isnan`.
    if math.isnan(obj):
        tag = "NaN"
    elif obj == float("inf"):
        tag = "Infinity"
    elif obj == float("-inf"):
        tag = "-Infinity"
    else:
        # Finite floats are already valid JSON numbers.
        return obj

    return {_FLOAT_TAG_KEY: tag}
| 847 | + |
@classmethod
def _decode_special_floats(cls, obj: Any) -> Any:
    """
    Recursively turn tagged special-float objects in `obj` back into float values.

    A dict whose only key is `"__float__"` and whose value is one of the strings `"Infinity"`, `"-Infinity"`
    or `"NaN"` is replaced by the matching float. A single-key `"__float__"` dict holding any other string is
    returned as is.

    All other dicts are rebuilt with decoded values, lists are rebuilt with decoded items, and any other
    object is returned unchanged.
    """
    if isinstance(obj, list):
        return [cls._decode_special_floats(item) for item in obj]

    if not isinstance(obj, dict):
        return obj

    has_only_tag_key = len(obj) == 1 and _FLOAT_TAG_KEY in obj
    if has_only_tag_key and isinstance(obj[_FLOAT_TAG_KEY], str):
        # Unknown tags fall back to the untouched dict.
        return _FLOAT_TAG_VALUES.get(obj[_FLOAT_TAG_KEY], obj)

    return {key: cls._decode_special_floats(value) for key, value in obj.items()}
816 | 869 |
|
def __eq__(self, other):
    """Return `True` when `other` is a `PreTrainedConfig` whose `__dict__` equals this instance's `__dict__`."""
    if not isinstance(other, PreTrainedConfig):
        return False
    return self.__dict__ == other.__dict__
@@ -932,6 +985,10 @@ def to_json_string(self, use_diff: bool = True) -> str: |
932 | 985 | config_dict = self.to_diff_dict() |
933 | 986 | else: |
934 | 987 | config_dict = self.to_dict() |
| 988 | + |
| 989 | + # Handle +/-Infinity and NaNs |
| 990 | + config_dict = self._encode_special_floats(config_dict) |
| 991 | + |
935 | 992 | return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" |
936 | 993 |
|
937 | 994 | def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True): |
|
0 commit comments