
Commit f8db43c

[Data] Add serialization framework for preprocessors (ray-project#58321)
## Description

This commit introduces a new serialization system for Ray Data preprocessors that improves maintainability, extensibility, and backward compatibility.

Key changes:

1. New serialization infrastructure:
   - Add `serialization_handlers.py` with a factory pattern for format handling
   - Implement `CloudPickleSerializationHandler` (primary format)
   - Support legacy `PickleSerializationHandler` for backward compatibility
   - Add format auto-detection via magic bytes (`CPKL:`)
2. New preprocessor base class:
   - Add `SerializablePreprocessorBase` abstract class
   - Define the serialization interface via abstract methods:
     - `_get_serializable_fields()` / `_set_serializable_fields()`
     - `_get_stats()` / `_set_stats()`
   - Mark `serialize()` and `deserialize()` as `@final` to prevent overrides
3. Preprocessor registration system:
   - Add `version_support.py` with the `@SerializablePreprocessor` decorator
   - Enable versioned serialization with stable identifiers
   - Support class registration and lookup
   - Add `UnknownPreprocessorError` for missing types
4. Migrate preprocessors to the new framework:
   - `SimpleImputer`
   - `OrdinalEncoder`
   - `OneHotEncoder`
   - `MultiHotEncoder`
   - `LabelEncoder`
   - `Categorizer`
   - `StandardScaler`
   - `MinMaxScaler`
   - `MaxAbsScaler`
   - `RobustScaler`
5. Enhanced `Preprocessor` base class:
   - Add `get_input_columns()` and `get_output_columns()` methods (for future use)
   - Add `has_stats()` (for future use)
   - Add type hints to `__getstate__()` and `__setstate__()`
6. Backward compatibility improvements to `Concatenator` for existing functionality:
   - Add a `__setstate__` override in `Concatenator` for the `flatten` field
   - Handle missing fields gracefully during deserialization

The new architecture makes it easier to:
- Add new serialization formats without modifying core logic
- Maintain backward compatibility with existing serialized data
- Handle version migrations for preprocessor schemas
- Register new preprocessors with stable identifiers

---------

Signed-off-by: cem <cem@anyscale.com>
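The handler-plus-magic-bytes design described in the commit message can be sketched roughly as follows. The handler and magic-byte names come from the commit; the method names and registry mechanics are assumptions for illustration, and stdlib `pickle` stands in for `cloudpickle` (a third-party dependency) to keep the sketch self-contained:

```python
import pickle

CLOUDPICKLE_MAGIC = b"CPKL:"


class CloudPickleSerializationHandler:
    """Primary format. The real handler uses cloudpickle.dumps/loads;
    stdlib pickle stands in here to keep the sketch dependency-free."""

    def serialize(self, obj) -> bytes:
        # Prefix the payload with magic bytes so the format is detectable later.
        return CLOUDPICKLE_MAGIC + pickle.dumps(obj)

    def deserialize(self, data: bytes):
        return pickle.loads(data[len(CLOUDPICKLE_MAGIC):])


class PickleSerializationHandler:
    """Legacy format: raw pickle with no magic prefix."""

    def serialize(self, obj) -> bytes:
        return pickle.dumps(obj)

    def deserialize(self, data: bytes):
        return pickle.loads(data)


def get_handler_for(data: bytes):
    """Factory: auto-detect the format by inspecting the leading bytes."""
    if data.startswith(CLOUDPICKLE_MAGIC):
        return CloudPickleSerializationHandler()
    # Anything without the magic prefix is treated as legacy pickle data.
    return PickleSerializationHandler()
```

The point of the factory is that `deserialize()` never needs to know which format produced the bytes: old payloads written before this commit fall through to the legacy handler, while new payloads announce themselves via the `CPKL:` prefix.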
1 parent d50a275 commit f8db43c
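The registration system from `version_support.py` might look roughly like the sketch below. The decorator and exception names come from the commit; the registry structure, parameter names, and attribute names are assumptions:

```python
# Module-level registry mapping stable identifiers to (class, version) pairs.
_REGISTRY = {}


class UnknownPreprocessorError(Exception):
    """Raised when a serialized payload names an identifier that was never registered."""


def SerializablePreprocessor(identifier: str, version: int = 1):
    """Class decorator: register a preprocessor under a stable identifier.

    The identifier, not the class name, is what gets written into the
    serialized payload, so classes can be renamed or moved without
    breaking old data.
    """
    def wrap(cls):
        _REGISTRY[identifier] = (cls, version)
        cls._serialization_id = identifier      # hypothetical attribute names
        cls._serialization_version = version
        return cls
    return wrap


def lookup(identifier: str):
    """Resolve an identifier back to its registered class and version."""
    try:
        return _REGISTRY[identifier]
    except KeyError:
        raise UnknownPreprocessorError(
            f"No preprocessor registered for {identifier!r}"
        ) from None


@SerializablePreprocessor("standard_scaler", version=1)
class StandardScaler:
    pass
```

Deserialization can then dispatch on the stored identifier, and bumping `version` for a class gives the framework a hook for schema migrations.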

File tree

14 files changed: +1886 −28 lines changed


doc/source/conf.py

Lines changed: 3 additions & 0 deletions
@@ -135,6 +135,9 @@
     ("py:class", ".*"),
     # Workaround for https://github.com/sphinx-doc/sphinx/issues/10974
     ("py:obj", "ray\\.data\\.datasource\\.datasink\\.WriteReturnType"),
+    # UnknownPreprocessorError is an internal exception not exported in public API
+    ("py:exc", "UnknownPreprocessorError"),
+    ("py:exc", "ray\\.data\\.preprocessors\\.version_support\\.UnknownPreprocessorError"),
 ]

 # Cache notebook outputs in _build/.jupyter_cache

doc/source/train/user-guides/data-loading-preprocessing.rst

Lines changed: 9 additions & 2 deletions
@@ -502,6 +502,7 @@ You can use this with Ray Train Trainers by applying them on the dataset before

 .. testcode::

+    import base64
     import numpy as np
     from tempfile import TemporaryDirectory

@@ -542,16 +543,22 @@ You can use this with Ray Train Trainers by applying them on the dataset before
         checkpoint=Checkpoint.from_directory(temp_dir),
     )

+    # Serialize the preprocessor. Since serialize() returns bytes,
+    # convert to base64 string for JSON compatibility.
+    serialized_preprocessor = base64.b64encode(scaler.serialize()).decode("ascii")
+
     my_trainer = TorchTrainer(
         train_loop_per_worker,
         scaling_config=ScalingConfig(num_workers=2),
         datasets={"train": dataset},
-        metadata={"preprocessor_pkl": scaler.serialize()},
+        metadata={"preprocessor_pkl": serialized_preprocessor},
     )

     # Get the fitted preprocessor back from the result metadata.
     metadata = my_trainer.fit().checkpoint.get_metadata()
-    print(StandardScaler.deserialize(metadata["preprocessor_pkl"]))
+    # Decode from base64 before deserializing
+    serialized_data = base64.b64decode(metadata["preprocessor_pkl"])
+    print(StandardScaler.deserialize(serialized_data))

This example persists the fitted preprocessor using the ``Trainer(metadata={...})`` constructor argument. This argument takes a dict that is available from ``TrainContext.get_metadata()`` and from ``checkpoint.get_metadata()`` for checkpoints the Trainer saves, which makes it possible to recreate the fitted preprocessor for inference.
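The base64 step in the doc change above matters because ``serialize()`` returns raw bytes, while Train metadata must be JSON-serializable. A minimal standalone illustration of the round trip, with a plain byte string standing in for the preprocessor payload:

```python
import base64
import json

# Stands in for scaler.serialize() output (note the CPKL: magic prefix).
payload = b"CPKL:example-serialized-bytes"

# Encode to an ASCII string so the metadata dict survives JSON serialization.
metadata = {"preprocessor_pkl": base64.b64encode(payload).decode("ascii")}
as_json = json.dumps(metadata)  # json.dumps would raise TypeError on raw bytes

# Later: decode back to the original bytes before calling deserialize().
restored = base64.b64decode(json.loads(as_json)["preprocessor_pkl"])
assert restored == payload
```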
