Skip to content

feat(core): Support serialization of opaque objects #75

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: CI
on: [push, pull_request]
env:
CIBW_BUILD_VERBOSITY: 3
CIBW_TEST_REQUIRES: "pytest torch"
CIBW_TEST_REQUIRES: "pytest torch jsonpickle"
CIBW_TEST_COMMAND: "pytest -svv --durations=20 {project}/tests/python/"
CIBW_ENVIRONMENT: "MLC_SHOW_CPP_STACKTRACES=1"
CIBW_REPAIR_WHEEL_COMMAND_LINUX: >
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ repos:
rev: "v1.14.1"
hooks:
- id: mypy
additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "torch"]
additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "torch", "jsonpickle"]
args: [--show-error-codes]
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: "v19.1.6"
Expand Down
4 changes: 2 additions & 2 deletions cpp/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ namespace mlc {
namespace registry {

Any JSONLoads(AnyView json_str);
Any JSONDeserialize(AnyView json_str);
Str JSONSerialize(AnyView source);
Any JSONDeserialize(AnyView json_str, FuncObj *fn_opaque_deserialize);
Str JSONSerialize(AnyView source, FuncObj *fn_opaque_serialize);
bool StructuralEqual(AnyView lhs, AnyView rhs, bool bind_free_vars, bool assert_mode);
int64_t StructuralHash(AnyView root);
Optional<Str> StructuralEqualFailReason(AnyView lhs, AnyView rhs, bool bind_free_vars);
Expand Down
196 changes: 141 additions & 55 deletions cpp/structure.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>

namespace mlc {
Expand Down Expand Up @@ -1300,8 +1301,11 @@ Tensor TensorFromBytes(const uint8_t *data_ptr, int64_t max_size) {

/****************** Serialize / Deserialize ******************/

inline mlc::Str Serialize(Any any) {
inline mlc::Str Serialize(Any any, FuncObj *fn_opaque_serialize) {
using mlc::base::TypeTraits;
// Section 1. Define two lookups
// 1) `type_keys` and `get_json_type_index`, which maps a `type_key` to type index/key to that in JSON
// 2) `opaques` and `get_opaque_index`, which maps an `OpaqueObj` to its index in the list of opaques
std::vector<const char *> type_keys;
auto get_json_type_index = [type_key2index = std::unordered_map<const char *, int32_t>(),
&type_keys](const char *type_key) mutable -> int32_t {
Expand All @@ -1313,8 +1317,27 @@ inline mlc::Str Serialize(Any any) {
type_keys.push_back(type_key);
return type_index;
};
using TObj2Idx = std::unordered_map<Object *, int32_t>;
using TJsonTypeIndex = decltype(get_json_type_index);
using TGetJSONTypeIndex = decltype(get_json_type_index);
struct OpaqueHash {
size_t operator()(const OpaqueObj *opaque) const { return std::hash<void *>{}(opaque->handle); }
};
struct OpaqueEq {
bool operator()(const OpaqueObj *lhs, const OpaqueObj *rhs) const { return lhs->handle == rhs->handle; }
};
UList opaques;
auto get_opaque_index = [opaque2index = std::unordered_map<const OpaqueObj *, int32_t, OpaqueHash, OpaqueEq>(),
&opaques](const OpaqueObj *opaque) mutable -> int32_t {
if (auto it = opaque2index.find(opaque); it != opaque2index.end()) {
return it->second;
}
int32_t type_index = static_cast<int32_t>(opaque2index.size());
opaque2index[opaque] = type_index;
opaques.push_back(opaque);
return type_index;
};
// Section 2. Define `Emitter`, which emits a singleton type of:
// - POD: bool, int, float, string, DLDataType, DLDevice, void*
// - Object that has been known previously
struct Emitter {
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); }
// clang-format off
Expand Down Expand Up @@ -1381,57 +1404,77 @@ inline mlc::Str Serialize(Any any) {
if (!obj) {
MLC_THROW(InternalError) << "This should never happen: null object pointer during EmitObject";
}
int32_t obj_idx = obj2index->at(obj);
int32_t obj_idx = topo_index->at(obj);
if (obj_idx == -1) {
MLC_THROW(InternalError) << "This should never happen: topological ordering violated";
}
(*os) << ", " << obj_idx;
}
std::ostringstream *os;
TJsonTypeIndex *get_json_type_index;
const TObj2Idx *obj2index;
TGetJSONTypeIndex *get_json_type_index;
const std::unordered_map<Object *, int32_t> *topo_index;
};

std::unordered_map<Object *, int32_t> topo_indices;
// Section 3. Define `on_visit` method for topological traversal of the graph
// Inside `values` section, every `on_visit` generates one of the following:
// 1) string `s`: for string literal `s`
// 2) list -> normal case, its layout is:
// - [0] = json_type_index
// - [1...] = each field of the type
// * int: refer to `values[i]`
// * list: TODO: explain
// * str / bool / float / None: literals
std::vector<TensorObj *> tensors;
std::ostringstream os;
auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os, &tensors,
is_first_object = true](Object *object, MLCTypeInfo *type_info) mutable -> void {
auto on_visit =
[get_json_type_index = &get_json_type_index, os = &os, &tensors, &get_opaque_index, is_first_object = true,
topo_indices = std::unordered_map<Object *, int32_t>()](Object *object, MLCTypeInfo *type_info) mutable -> void {
// Step 1. Allocate `topo_index` assigned to the current object
int32_t &topo_index = topo_indices[object];
if (topo_index == 0) {
topo_index = static_cast<int32_t>(topo_indices.size()) - 1;
} else {
MLC_THROW(InternalError) << "This should never happen: object already visited";
}
Emitter emitter{os, get_json_type_index, &topo_indices};
if (is_first_object) {
is_first_object = false;
} else {
os->put(',');
}
// Step 2. Print the current object
// Special case: string
if (StrObj *str = object->as<StrObj>()) {
str->PrintEscape(*os);
return;
}
// [0] = json_type_index
(*os) << '[' << (*get_json_type_index)(type_info->type_key);
// [1...] = each field of the type. A few possible cases:
// 1) list
// 2) dict
// 3) tensor
// 4) opaque
// 5) a normal dataclass
if (UListObj *list = object->as<UListObj>()) {
Emitter emitter{os, get_json_type_index, &topo_indices};
for (Any &any : *list) {
emitter(nullptr, &any);
emitter.EmitAny(&any);
}
} else if (UDictObj *dict = object->as<UDictObj>()) {
Emitter emitter{os, get_json_type_index, &topo_indices};
for (auto &kv : *dict) {
emitter(nullptr, &kv.first);
emitter(nullptr, &kv.second);
emitter.EmitAny(&kv.first);
emitter.EmitAny(&kv.second);
}
} else if (TensorObj *tensor = object->as<TensorObj>()) {
(*os) << ", " << tensors.size();
tensors.push_back(tensor);
} else if (OpaqueObj *opaque = object->as<OpaqueObj>()) {
int32_t opaque_index = get_opaque_index(opaque);
(*os) << ", " << opaque_index;
} else if (object->IsInstance<FuncObj>() || object->IsInstance<ErrorObj>()) {
MLC_THROW(TypeError) << "Unserializable type: " << object->GetTypeKey();
} else if (object->IsInstance<OpaqueObj>()) {
MLC_THROW(TypeError) << "Cannot serialize `mlc.Opaque` of type: "
<< object->DynCast<OpaqueObj>()->opaque_type_name;
} else {
Emitter emitter{os, get_json_type_index, &topo_indices};
VisitFields(object, type_info, emitter);
}
os->put(']');
Expand Down Expand Up @@ -1481,12 +1524,25 @@ inline mlc::Str Serialize(Any any) {
}
os << "]";
}
if (!opaques.empty()) {
os << ", \"opaques\":";
if (!fn_opaque_serialize) {
fn_opaque_serialize = Func::GetGlobal("mlc.Opaque.default.serialize", true);
}
if (!fn_opaque_serialize) {
MLC_THROW(ValueError) << "Cannot find serialization function `mlc.Opaque.default.serialize`. Register it with "
"`mlc.Func.register(\"mlc.Opaque.default.serialize\")(serialize_func)`";
}
Str opaque_repr = (*fn_opaque_serialize)(opaques);
opaque_repr->PrintEscape(os);
}
os << "}";
return os.str();
}

inline Any Deserialize(const char *json_str, int64_t json_str_len) {
inline Any Deserialize(const char *json_str, int64_t json_str_len, FuncObj *fn_opaque_deserialize) {
int32_t json_type_index_tensor = -1;
int32_t json_type_index_opaque = -1;
// Step 0. Parse JSON string
UDict json_obj = JSONLoads(json_str, json_str_len);
// Step 1. type_key => constructors
Expand All @@ -1496,10 +1552,12 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
for (Str type_key : type_keys) {
int32_t type_index = Lib::GetTypeIndex(type_key->data());
FuncObj *func = nullptr;
if (type_index != kMLCTensor) {
func = Lib::_init(type_index);
} else {
if (type_index == kMLCTensor) {
json_type_index_tensor = static_cast<int32_t>(constructors.size());
} else if (type_index == kMLCOpaque) {
json_type_index_opaque = static_cast<int32_t>(constructors.size());
} else {
func = Lib::_init(type_index);
}
constructors.push_back(func);
}
Expand All @@ -1522,45 +1580,71 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
json_obj->erase("tensors");
std::reverse(tensors.begin(), tensors.end());
}
// Step 3. Translate JSON object to objects
// Step 3. Handle opaque objects
UList opaques;
if (json_obj.count("opaques")) {
if (!fn_opaque_deserialize) {
fn_opaque_deserialize = Func::GetGlobal("mlc.Opaque.default.deserialize", true);
}
if (!fn_opaque_deserialize) {
MLC_THROW(ValueError)
<< "Cannot find deserialization function `mlc.Opaque.default.deserialize`. Register it with "
"`mlc.Func.register(\"mlc.Opaque.default.deserialize\")(deserialize_func)`";
}
opaques = (*fn_opaque_deserialize)(json_obj->at("opaques")).operator UList();
}
// Step 4. Translate JSON object to objects
UList values = json_obj->at("values");
for (int64_t i = 0; i < values->size(); ++i) {
Any obj = values[i];
if (obj.type_index == kMLCList) {
UList list = obj.operator UList();
int32_t json_type_index = list[0];
Any &value = values[i];
// every `value` is
// 1) integer `i` -> refer to `values[i]`
// 2) string `s` -> string literal `s`
// 3) list -> normal case, its layout is:
// - [0] = json_type_index
// - [1...] = each field of the type
// * int: refer to `values[i]`
// * list: TODO: explain
// * str / bool / float / None: literals
// TODO: how about kMLCBool, kMLCFloat, kMLCNone?
if (UListObj *list = value.as<UListObj>()) {
// Layout of the list:
int32_t json_type_index = (*list)[0];
if (json_type_index == json_type_index_tensor) {
values[i] = tensors[list[1].operator int32_t()];
continue;
}
for (int64_t j = 1; j < list.size(); ++j) {
Any arg = list[j];
if (arg.type_index == kMLCInt) {
int64_t k = arg;
if (k < i) {
list[j] = values[k];
int32_t idx = (*list)[1];
value = tensors[idx];
} else if (json_type_index == json_type_index_opaque) {
int32_t idx = (*list)[1];
value = opaques[idx];
} else {
for (int64_t j = 1; j < list->size(); ++j) {
Any arg = (*list)[j];
if (arg.type_index == kMLCInt) {
int64_t k = arg;
if (k < i) {
(*list)[j] = values[k];
} else {
MLC_THROW(ValueError) << "Invalid reference when parsing type `" << type_keys[json_type_index]
<< "`: referring #" << k << " at #" << i << ". v = " << value;
}
} else if (arg.type_index == kMLCList) {
(*list)[j] = invoke_init(arg.operator UList());
} else if (arg.type_index == kMLCStr || arg.type_index == kMLCBool || arg.type_index == kMLCFloat ||
arg.type_index == kMLCNone) {
// Do nothing
} else {
MLC_THROW(ValueError) << "Invalid reference when parsing type `" << type_keys[json_type_index]
<< "`: referring #" << k << " at #" << i << ". v = " << obj;
MLC_THROW(ValueError) << "Unexpected value: " << arg;
}
} else if (arg.type_index == kMLCList) {
list[j] = invoke_init(arg.operator UList());
} else if (arg.type_index == kMLCStr || arg.type_index == kMLCBool || arg.type_index == kMLCFloat ||
arg.type_index == kMLCNone) {
// Do nothing
} else {
MLC_THROW(ValueError) << "Unexpected value: " << arg;
}
value = invoke_init(UList(list));
}
values[i] = invoke_init(list);
} else if (obj.type_index == kMLCInt) {
int32_t k = obj;
values[i] = values[k];
} else if (obj.type_index == kMLCStr) {
} else if (value.type_index == kMLCInt) {
int32_t k = value;
value = values[k];
} else if (value.type_index == kMLCStr) {
// Do nothing
// TODO: how about kMLCBool, kMLCFloat, kMLCNone?
} else {
MLC_THROW(ValueError) << "Unexpected value: " << obj;
MLC_THROW(ValueError) << "Unexpected value: " << value;
}
}
return values->back();
Expand Down Expand Up @@ -1617,16 +1701,18 @@ Any JSONLoads(AnyView json_str) {
}
}

Any JSONDeserialize(AnyView json_str) {
Any JSONDeserialize(AnyView json_str, FuncObj *fn_opaque_deserialize) {
if (json_str.type_index == kMLCRawStr) {
return ::mlc::Deserialize(json_str.operator const char *(), -1);
return ::mlc::Deserialize(json_str.operator const char *(), -1, fn_opaque_deserialize);
} else {
StrObj *js = json_str.operator StrObj *();
return ::mlc::Deserialize(js->data(), js->size());
return ::mlc::Deserialize(js->data(), js->size(), fn_opaque_deserialize);
}
}

Str JSONSerialize(AnyView source) { return ::mlc::Serialize(source); }
Str JSONSerialize(AnyView source, FuncObj *fn_opaque_serialize) {
return ::mlc::Serialize(source, fn_opaque_serialize);
}

Str TensorToBytes(const TensorObj *src) {
return ::mlc::TensorToBytes(&src->tensor); //
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ authors = [{ name = "MLC Authors", email = "[email protected]" }]
"mlc.config" = "mlc.config:main"

[project.optional-dependencies]
tests = ['pytest', 'torch']
tests = ['pytest', 'torch', 'jsonpickle']
dev = [
"cython>=3.1",
"pre-commit",
Expand All @@ -35,6 +35,7 @@ dev = [
"ruff",
"mypy",
"torch",
"jsonpickle",
]

[build-system]
Expand Down
16 changes: 8 additions & 8 deletions python/mlc/_cython/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -360,25 +360,25 @@ cdef class PyAny:
return (base.new_object, (type(self),), self.__getstate__())

def __getstate__(self):
return {"mlc_json": func_call(_SERIALIZE, (self,))}
return {"mlc_json": func_call(_SERIALIZE, (self, None))}

def __setstate__(self, state):
cdef PyAny ret = func_call(_DESERIALIZE, (state["mlc_json"], ))
cdef PyAny ret = func_call(_DESERIALIZE, (state["mlc_json"], None))
cdef MLCAny tmp = self._mlc_any
self._mlc_any = ret._mlc_any
ret._mlc_any = tmp

def _mlc_json(self):
return func_call(_SERIALIZE, (self,))
def _mlc_json(self, fn_opaque_serialize):
return func_call(_SERIALIZE, (self, fn_opaque_serialize))

def _mlc_swap(self, PyAny other):
cdef MLCAny tmp = self._mlc_any
self._mlc_any = other._mlc_any
other._mlc_any = tmp

@staticmethod
def _mlc_from_json(mlc_json):
return func_call(_DESERIALIZE, (mlc_json,))
def _mlc_from_json(mlc_json, fn_opaque_deserialize):
return func_call(_DESERIALIZE, (mlc_json, fn_opaque_deserialize))

@staticmethod
def _mlc_eq_s(PyAny lhs, PyAny rhs, bint bind_free_vars, bint assert_mode) -> bool:
Expand Down Expand Up @@ -1442,8 +1442,8 @@ cpdef void func_init(PyAny self, object callable):
self._mlc_any = ret._mlc_any
ret._mlc_any = _MLCAnyNone()

cpdef void opaque_init(PyAny self, object callable):
cdef PyAny ret = _pyany_from_opaque(callable)
cpdef void opaque_init(PyAny self, object opaque):
cdef PyAny ret = _pyany_from_opaque(opaque)
self._mlc_any = ret._mlc_any
ret._mlc_any = _MLCAnyNone()

Expand Down
Loading
Loading