Skip to content

feat(core): Support eq_s and hash_s for opaque objects #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions cpp/structure.cc
Original file line number Diff line number Diff line change
Expand Up @@ -538,9 +538,24 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) {
} else if (lhs_type_index == kMLCFunc || lhs_type_index == kMLCError) {
throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", new_path);
} else if (lhs_type_index == kMLCOpaque) {
std::ostringstream err;
err << "Cannot compare `mlc.Opaque` of type: " << lhs->DynCast<OpaqueObj>()->opaque_type_name;
throw SEqualError(err.str().c_str(), new_path);
std::string func_name = "Opaque.eq_s.";
func_name += lhs->DynCast<OpaqueObj>()->opaque_type_name;
FuncObj *func = Func::GetGlobal(func_name.c_str(), true);
if (func == nullptr) {
std::ostringstream err;
err << "Cannot compare `mlc.Opaque` of type: " << lhs->DynCast<OpaqueObj>()->opaque_type_name << "; Use "
<< "`mlc.Func.register(\"" << func_name << "\")(eq_s_func)` to register a comparison method";
throw SEqualError(err.str().c_str(), new_path);
}
Any result = (*func)(lhs, rhs);
if (result.type_index != kMLCBool) {
std::ostringstream err;
err << "Comparison function `" << func_name << "` must return a boolean value, but got: " << result;
throw SEqualError(err.str().c_str(), new_path);
}
if (result.operator bool() == false) {
MLC_CORE_EQ_S_ERR(lhs, rhs, new_path);
}
} else {
bool visited = false;
MLCTypeInfo *type_info = Lib::GetTypeInfo(lhs_type_index);
Expand Down Expand Up @@ -802,9 +817,21 @@ inline uint64_t StructuralHashImpl(Object *obj) {
} else if (type_index == kMLCFunc || type_index == kMLCError) {
throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", ObjectPath::Root());
} else if (type_index == kMLCOpaque) {
std::ostringstream err;
err << "Cannot compare `mlc.Opaque` of type: " << obj->DynCast<OpaqueObj>()->opaque_type_name;
throw SEqualError(err.str().c_str(), ObjectPath::Root());
std::string func_name = "Opaque.hash_s.";
func_name += obj->DynCast<OpaqueObj>()->opaque_type_name;
FuncObj *func = Func::GetGlobal(func_name.c_str(), true);
if (func == nullptr) {
MLC_THROW(ValueError) << "Cannot hash `mlc.Opaque` of type: " << obj->DynCast<OpaqueObj>()->opaque_type_name
<< "; Use `mlc.Func.register(\"" << func_name
<< "\")(hash_s_func)` to register a hashing method";
}
Any result = (*func)(obj);
if (result.type_index != kMLCInt) {
MLC_THROW(TypeError) << "Hashing function `" << func_name
<< "` must return an integer value, but got: " << result;
}
int64_t hash_value = result.operator int64_t();
EnqueuePOD(tasks, hash_value);
} else {
MLCTypeInfo *type_info = Lib::GetTypeInfo(type_index);
tasks->emplace_back(Task{obj, type_info, false, bind_free_vars, type_info->type_key_hash});
Expand Down
2 changes: 1 addition & 1 deletion include/mlc/core/func.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct FuncObj : public MLCFunc {
using SafeCall = int32_t(const FuncObj *, int32_t, const AnyView *, Any *);
struct Allocator;

template <typename... Args> MLC_INLINE Any operator()(Args &&...args) const {
template <typename... Args> inline Any operator()(Args &&...args) const {
constexpr size_t N = sizeof...(Args);
AnyViewArray<N> stack_args;
Any ret;
Expand Down
8 changes: 4 additions & 4 deletions python/mlc/_cython/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,10 @@ cdef inline MLCAny _any_py2c(object x, list temporary_storage):
y = (<PyAny>x)._mlc_any
elif isinstance(x, Str):
y = (<PyAny>(x._pyany))._mlc_any
elif isinstance(x, _OPAQUE_TYPES):
x = _pyany_from_opaque(x)
y = (<PyAny>x)._mlc_any
temporary_storage.append(x)
elif isinstance(x, bool):
y = _MLCAnyBool(<bint>x)
elif isinstance(x, Integral):
Expand All @@ -713,10 +717,6 @@ cdef inline MLCAny _any_py2c(object x, list temporary_storage):
x = _pyany_from_dlpack(x)
y = (<PyAny>x)._mlc_any
temporary_storage.append(x)
elif isinstance(x, _OPAQUE_TYPES):
x = _pyany_from_opaque(x)
y = (<PyAny>x)._mlc_any
temporary_storage.append(x)
else:
raise TypeError(f"MLC does not recognize type: {type(x)}")
return y
Expand Down
51 changes: 49 additions & 2 deletions tests/python/test_core_opaque.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
from typing import Any

import mlc
import pytest


class MyType:
class MyTypeNotRegistered:
def __init__(self, a: int) -> None:
self.a = a


class MyTypeNotRegistered:
class MyType:
def __init__(self, a: int) -> None:
self.a = a

def __call__(self, x: int) -> int:
return x + self.a


mlc.Opaque.register(MyType)


@mlc.dataclasses.py_class(structure="bind")
class Wrapper(mlc.dataclasses.PyClass):
field: Any = mlc.dataclasses.field(structure="nobind")


def test_opaque_init() -> None:
a = MyType(a=10)
opaque = mlc.Opaque(a)
Expand Down Expand Up @@ -47,3 +57,40 @@ def test_opaque_ffi_error() -> None:
str(e.value)
== "MLC does not recognize type: <class 'test_core_opaque.MyTypeNotRegistered'>"
)


def test_opaque_dataclass() -> None:
a = MyType(a=10)
wrapper = Wrapper(field=a)
assert isinstance(wrapper.field, MyType)
assert wrapper.field.a == 10


@mlc.Func.register("Opaque.eq_s.test_core_opaque.MyType")
def _eq_s_MyType(a: MyType, b: MyType) -> bool:
return isinstance(a, MyType) and isinstance(b, MyType) and a.a == b.a


@mlc.Func.register("Opaque.hash_s.test_core_opaque.MyType")
def _hash_s_MyType(a: MyType) -> int:
assert isinstance(a, MyType)
return hash((MyType, a.a))


def test_opaque_dataclass_eq_s() -> None:
a1 = Wrapper(field=MyType(a=10))
a2 = Wrapper(field=MyType(a=10))
a1.eq_s(a2, assert_mode=True)


def test_opaque_dataclass_eq_s_fail() -> None:
a1 = Wrapper(field=MyType(a=10))
a2 = Wrapper(field=MyType(a=20))
with pytest.raises(ValueError) as exc_info:
a1.eq_s(a2, assert_mode=True)
assert str(exc_info.value).startswith("Structural equality check failed at {root}.field")


def test_opaque_dataclass_hash_s() -> None:
a1 = Wrapper(field=MyType(a=10))
assert isinstance(a1.hash_s(), int)
Loading