Skip to content

Commit d14d99e

Browse files
authored
Fix infinity in JSON serialized files (#42959)
* Handle inifinity and NaNs in JSON serialization * Docs * Tests
1 parent d54d78f commit d14d99e

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed

src/transformers/configuration_utils.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import copy
1818
import json
19+
import math
1920
import os
2021
import warnings
2122
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
@@ -50,6 +51,9 @@
5051
# type hinting: specifying the type of config class that inherits from PreTrainedConfig
5152
SpecificPreTrainedConfigType = TypeVar("SpecificPreTrainedConfigType", bound="PreTrainedConfig")
5253

54+
_FLOAT_TAG_KEY = "__float__"
55+
_FLOAT_TAG_VALUES = {"Infinity": float("inf"), "-Infinity": float("-inf"), "NaN": float("nan")}
56+
5357

5458
class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin):
5559
# no-format
@@ -812,7 +816,56 @@ def from_json_file(
812816
def _dict_from_json_file(cls, json_file: str | os.PathLike):
813817
with open(json_file, encoding="utf-8") as reader:
814818
text = reader.read()
815-
return json.loads(text)
819+
config_dict = json.loads(text)
820+
821+
return cls._decode_special_floats(config_dict)
822+
823+
@classmethod
824+
def _encode_special_floats(cls, obj: Any) -> Any:
825+
"""
826+
Iterates over the passed object and encode specific floats that cannot be JSON-serialized. Python's JSON
827+
engine saves floats like `Infinity` (+/-) or `NaN` which are not compatible with other JSON engines.
828+
829+
It serializes floats like `Infinity` as an object: `{'__float__': Infinity}`.
830+
"""
831+
if isinstance(obj, float):
832+
if math.isnan(obj):
833+
return {_FLOAT_TAG_KEY: "NaN"}
834+
if obj == float("inf"):
835+
return {_FLOAT_TAG_KEY: "Infinity"}
836+
if obj == float("-inf"):
837+
return {_FLOAT_TAG_KEY: "-Infinity"}
838+
return obj
839+
840+
if isinstance(obj, dict):
841+
return {k: cls._encode_special_floats(v) for k, v in obj.items()}
842+
843+
if isinstance(obj, (list, tuple)):
844+
return [cls._encode_special_floats(v) for v in obj]
845+
846+
return obj
847+
848+
@classmethod
849+
def _decode_special_floats(cls, obj: Any) -> Any:
850+
"""
851+
Iterates over the passed object and decode specific floats that cannot be JSON-serialized. Python's JSON
852+
engine saves floats like `Infinity` (+/-) or `NaN` which are not compatible with other JSON engines.
853+
854+
This method deserializes objects like `{'__float__': Infinity}` to their float values like `Infinity`.
855+
"""
856+
if isinstance(obj, dict):
857+
if set(obj.keys()) == {_FLOAT_TAG_KEY} and isinstance(obj[_FLOAT_TAG_KEY], str):
858+
tag = obj[_FLOAT_TAG_KEY]
859+
if tag in _FLOAT_TAG_VALUES:
860+
return _FLOAT_TAG_VALUES[tag]
861+
return obj
862+
863+
return {k: cls._decode_special_floats(v) for k, v in obj.items()}
864+
865+
if isinstance(obj, list):
866+
return [cls._decode_special_floats(v) for v in obj]
867+
868+
return obj
816869

817870
def __eq__(self, other):
818871
return isinstance(other, PreTrainedConfig) and (self.__dict__ == other.__dict__)
@@ -932,6 +985,10 @@ def to_json_string(self, use_diff: bool = True) -> str:
932985
config_dict = self.to_diff_dict()
933986
else:
934987
config_dict = self.to_dict()
988+
989+
# Handle +/-Infinity and NaNs
990+
config_dict = self._encode_special_floats(config_dict)
991+
935992
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
936993

937994
def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True):

tests/utils/test_configuration_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,3 +329,44 @@ def test_bc_torch_dtype(self):
329329

330330
config = PreTrainedConfig.from_pretrained(tmpdirname, torch_dtype="float32")
331331
self.assertEqual(config.dtype, "float32")
332+
333+
def test_unserializable_json_is_encoded(self):
334+
class NewConfig(PreTrainedConfig):
335+
def __init__(
336+
self,
337+
inf_positive: float = float("inf"),
338+
inf_negative: float = float("-inf"),
339+
nan: float = float("nan"),
340+
**kwargs,
341+
):
342+
self.inf_positive = inf_positive
343+
self.inf_negative = inf_negative
344+
self.nan = nan
345+
346+
super().__init__(**kwargs)
347+
348+
new_config = NewConfig()
349+
350+
# All floats should remain as floats when being accessed in the config
351+
self.assertIsInstance(new_config.inf_positive, float)
352+
self.assertIsInstance(new_config.inf_negative, float)
353+
self.assertIsInstance(new_config.nan, float)
354+
355+
with tempfile.TemporaryDirectory() as tmpdirname:
356+
new_config.save_pretrained(tmpdirname)
357+
config_file = Path(tmpdirname) / "config.json"
358+
config_contents = json.loads(config_file.read_text())
359+
new_config_instance = NewConfig.from_pretrained(tmpdirname)
360+
361+
# In the serialized JSON file, the non-JSON compatible floats should be updated
362+
self.assertDictEqual(config_contents["inf_positive"], {"__float__": "Infinity"})
363+
self.assertDictEqual(config_contents["inf_negative"], {"__float__": "-Infinity"})
364+
self.assertDictEqual(config_contents["nan"], {"__float__": "NaN"})
365+
366+
with tempfile.TemporaryDirectory() as tmpdirname:
367+
new_config.save_pretrained(tmpdirname)
368+
369+
# When reloading the config, it should have correct float values
370+
self.assertIsInstance(new_config_instance.inf_positive, float)
371+
self.assertIsInstance(new_config_instance.inf_negative, float)
372+
self.assertIsInstance(new_config_instance.nan, float)

0 commit comments

Comments
 (0)