Skip to content

Commit a8aa1d6

Browse files
authored
feat(core): Support eq_s and hash_s for opaque objects (#74)
Should fix two feature requests in #73
1 parent 81b5daf commit a8aa1d6

File tree

4 files changed

+87
-13
lines changed

4 files changed

+87
-13
lines changed

cpp/structure.cc

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,24 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) {
538538
} else if (lhs_type_index == kMLCFunc || lhs_type_index == kMLCError) {
539539
throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", new_path);
540540
} else if (lhs_type_index == kMLCOpaque) {
541-
std::ostringstream err;
542-
err << "Cannot compare `mlc.Opaque` of type: " << lhs->DynCast<OpaqueObj>()->opaque_type_name;
543-
throw SEqualError(err.str().c_str(), new_path);
541+
std::string func_name = "Opaque.eq_s.";
542+
func_name += lhs->DynCast<OpaqueObj>()->opaque_type_name;
543+
FuncObj *func = Func::GetGlobal(func_name.c_str(), true);
544+
if (func == nullptr) {
545+
std::ostringstream err;
546+
err << "Cannot compare `mlc.Opaque` of type: " << lhs->DynCast<OpaqueObj>()->opaque_type_name << "; Use "
547+
<< "`mlc.Func.register(\"" << func_name << "\")(eq_s_func)` to register a comparison method";
548+
throw SEqualError(err.str().c_str(), new_path);
549+
}
550+
Any result = (*func)(lhs, rhs);
551+
if (result.type_index != kMLCBool) {
552+
std::ostringstream err;
553+
err << "Comparison function `" << func_name << "` must return a boolean value, but got: " << result;
554+
throw SEqualError(err.str().c_str(), new_path);
555+
}
556+
if (result.operator bool() == false) {
557+
MLC_CORE_EQ_S_ERR(lhs, rhs, new_path);
558+
}
544559
} else {
545560
bool visited = false;
546561
MLCTypeInfo *type_info = Lib::GetTypeInfo(lhs_type_index);
@@ -802,9 +817,21 @@ inline uint64_t StructuralHashImpl(Object *obj) {
802817
} else if (type_index == kMLCFunc || type_index == kMLCError) {
803818
throw SEqualError("Cannot compare `mlc.Func` or `mlc.Error`", ObjectPath::Root());
804819
} else if (type_index == kMLCOpaque) {
805-
std::ostringstream err;
806-
err << "Cannot compare `mlc.Opaque` of type: " << obj->DynCast<OpaqueObj>()->opaque_type_name;
807-
throw SEqualError(err.str().c_str(), ObjectPath::Root());
820+
std::string func_name = "Opaque.hash_s.";
821+
func_name += obj->DynCast<OpaqueObj>()->opaque_type_name;
822+
FuncObj *func = Func::GetGlobal(func_name.c_str(), true);
823+
if (func == nullptr) {
824+
MLC_THROW(ValueError) << "Cannot hash `mlc.Opaque` of type: " << obj->DynCast<OpaqueObj>()->opaque_type_name
825+
<< "; Use `mlc.Func.register(\"" << func_name
826+
<< "\")(hash_s_func)` to register a hashing method";
827+
}
828+
Any result = (*func)(obj);
829+
if (result.type_index != kMLCInt) {
830+
MLC_THROW(TypeError) << "Hashing function `" << func_name
831+
<< "` must return an integer value, but got: " << result;
832+
}
833+
int64_t hash_value = result.operator int64_t();
834+
EnqueuePOD(tasks, hash_value);
808835
} else {
809836
MLCTypeInfo *type_info = Lib::GetTypeInfo(type_index);
810837
tasks->emplace_back(Task{obj, type_info, false, bind_free_vars, type_info->type_key_hash});

include/mlc/core/func.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ struct FuncObj : public MLCFunc {
1515
using SafeCall = int32_t(const FuncObj *, int32_t, const AnyView *, Any *);
1616
struct Allocator;
1717

18-
template <typename... Args> MLC_INLINE Any operator()(Args &&...args) const {
18+
template <typename... Args> inline Any operator()(Args &&...args) const {
1919
constexpr size_t N = sizeof...(Args);
2020
AnyViewArray<N> stack_args;
2121
Any ret;

python/mlc/_cython/core.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,10 @@ cdef inline MLCAny _any_py2c(object x, list temporary_storage):
687687
y = (<PyAny>x)._mlc_any
688688
elif isinstance(x, Str):
689689
y = (<PyAny>(x._pyany))._mlc_any
690+
elif isinstance(x, _OPAQUE_TYPES):
691+
x = _pyany_from_opaque(x)
692+
y = (<PyAny>x)._mlc_any
693+
temporary_storage.append(x)
690694
elif isinstance(x, bool):
691695
y = _MLCAnyBool(<bint>x)
692696
elif isinstance(x, Integral):
@@ -713,10 +717,6 @@ cdef inline MLCAny _any_py2c(object x, list temporary_storage):
713717
x = _pyany_from_dlpack(x)
714718
y = (<PyAny>x)._mlc_any
715719
temporary_storage.append(x)
716-
elif isinstance(x, _OPAQUE_TYPES):
717-
x = _pyany_from_opaque(x)
718-
y = (<PyAny>x)._mlc_any
719-
temporary_storage.append(x)
720720
else:
721721
raise TypeError(f"MLC does not recognize type: {type(x)}")
722722
return y

tests/python/test_core_opaque.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,30 @@
1+
from typing import Any
2+
13
import mlc
24
import pytest
35

46

5-
class MyType:
7+
class MyTypeNotRegistered:
68
def __init__(self, a: int) -> None:
79
self.a = a
810

911

10-
class MyTypeNotRegistered:
12+
class MyType:
1113
def __init__(self, a: int) -> None:
1214
self.a = a
1315

16+
def __call__(self, x: int) -> int:
17+
return x + self.a
18+
1419

1520
mlc.Opaque.register(MyType)
1621

1722

23+
@mlc.dataclasses.py_class(structure="bind")
24+
class Wrapper(mlc.dataclasses.PyClass):
25+
field: Any = mlc.dataclasses.field(structure="nobind")
26+
27+
1828
def test_opaque_init() -> None:
1929
a = MyType(a=10)
2030
opaque = mlc.Opaque(a)
@@ -47,3 +57,40 @@ def test_opaque_ffi_error() -> None:
4757
str(e.value)
4858
== "MLC does not recognize type: <class 'test_core_opaque.MyTypeNotRegistered'>"
4959
)
60+
61+
62+
def test_opaque_dataclass() -> None:
63+
a = MyType(a=10)
64+
wrapper = Wrapper(field=a)
65+
assert isinstance(wrapper.field, MyType)
66+
assert wrapper.field.a == 10
67+
68+
69+
@mlc.Func.register("Opaque.eq_s.test_core_opaque.MyType")
70+
def _eq_s_MyType(a: MyType, b: MyType) -> bool:
71+
return isinstance(a, MyType) and isinstance(b, MyType) and a.a == b.a
72+
73+
74+
@mlc.Func.register("Opaque.hash_s.test_core_opaque.MyType")
75+
def _hash_s_MyType(a: MyType) -> int:
76+
assert isinstance(a, MyType)
77+
return hash((MyType, a.a))
78+
79+
80+
def test_opaque_dataclass_eq_s() -> None:
81+
a1 = Wrapper(field=MyType(a=10))
82+
a2 = Wrapper(field=MyType(a=10))
83+
a1.eq_s(a2, assert_mode=True)
84+
85+
86+
def test_opaque_dataclass_eq_s_fail() -> None:
87+
a1 = Wrapper(field=MyType(a=10))
88+
a2 = Wrapper(field=MyType(a=20))
89+
with pytest.raises(ValueError) as exc_info:
90+
a1.eq_s(a2, assert_mode=True)
91+
assert str(exc_info.value).startswith("Structural equality check failed at {root}.field")
92+
93+
94+
def test_opaque_dataclass_hash_s() -> None:
95+
a1 = Wrapper(field=MyType(a=10))
96+
assert isinstance(a1.hash_s(), int)

0 commit comments

Comments
 (0)