diff --git a/python/mlc/_cython/core.pyx b/python/mlc/_cython/core.pyx index ca3060f..ad41ed7 100644 --- a/python/mlc/_cython/core.pyx +++ b/python/mlc/_cython/core.pyx @@ -6,8 +6,17 @@ from libcpp.vector cimport vector from libc.stdint cimport int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t from libc.stdlib cimport malloc, free from numbers import Integral, Number -from cpython cimport Py_DECREF, Py_INCREF, PyCapsule_IsValid, PyCapsule_GetPointer, PyCapsule_SetName, PyCapsule_New from . import base +from cpython.pycapsule cimport ( + PyCapsule_IsValid, + PyCapsule_GetPointer, + PyCapsule_SetName, + PyCapsule_New, +) + +cdef extern from "Python.h": + void Py_IncRef(object) + void Py_DecRef(object) Ptr = base.Ptr PyCode_NewEmpty = ctypes.pythonapi.PyCode_NewEmpty @@ -417,21 +426,17 @@ cdef class PyAny: raise e.with_traceback(None) return _any_c2py_no_inc_ref(c_ret) -cdef class Str(str): - cdef MLCAny _mlc_any - __slots__ = () - - def __cinit__(self): - self._mlc_any = _MLCAnyNone() - def __init__(self, value): - cdef str value_unicode = self - cdef bytes value_c = str_py2c(value_unicode) - self._mlc_any = _MLCAnyRawStr(value_c) - _check_error(_C_AnyInplaceViewToOwned(&self._mlc_any)) +class Str(str): + __slots__ = ("_pyany",) - def __dealloc__(self): - _check_error(_C_AnyDecRef(&self._mlc_any)) + def __new__(cls, value: str): + cdef PyAny pyany = PyAny() + self = super().__new__(cls, value) + self._pyany = pyany + pyany._mlc_any = _MLCAnyRawStr(str_py2c(value)) + _check_error(_C_AnyInplaceViewToOwned(&pyany._mlc_any)) + return self def __reduce__(self): return (Str, (str(self),)) @@ -541,7 +546,7 @@ cdef inline object _any_c2py_no_inc_ref(const MLCAny x): cdef int32_t type_index = x.type_index cdef MLCStr* mlc_str = NULL cdef PyAny any_ret - cdef Str str_ret + cdef object str_ret if type_index == kMLCNone: return None elif type_index == kMLCBool: @@ -556,8 +561,10 @@ cdef inline object _any_c2py_no_inc_ref(const MLCAny x): return str_c2py(x.v.v_str) elif type_index == kMLCStr: mlc_str = (x.v.v_obj) + any_ret = PyAny() + any_ret._mlc_any = x str_ret = Str.__new__(Str, str_c2py(mlc_str.data[:mlc_str.length])) - str_ret._mlc_any = x + str_ret._pyany = any_ret return str_ret elif type_index == kMLCOpaque: return (((x.v.v_obj)).handle) @@ -572,7 +579,7 @@ cdef inline object _any_c2py_inc_ref(MLCAny x): cdef int32_t type_index = x.type_index cdef MLCStr* mlc_str = NULL cdef PyAny any_ret - cdef Str str_ret + cdef object str_ret if type_index == kMLCNone: return None elif type_index == kMLCBool: @@ -587,8 +594,10 @@ cdef inline object _any_c2py_inc_ref(MLCAny x): return str_c2py(x.v.v_str) elif type_index == kMLCStr: mlc_str = (x.v.v_obj) + any_ret = PyAny() + any_ret._mlc_any = x str_ret = Str.__new__(Str, str_c2py(mlc_str.data[:mlc_str.length])) - str_ret._mlc_any = x + str_ret._pyany = any_ret _check_error(_C_AnyIncRef(&x)) return str_ret elif type_index == kMLCOpaque: @@ -624,11 +633,11 @@ cdef inline PyAny _pyany_from_opaque(object x): args[0] = _MLCAnyPtr((x)) args[1] = _MLCAnyPtr((_pyobj_deleter)) args[2] = _MLCAnyRawStr(type_name) - Py_INCREF(x) + Py_IncRef(x) try: _func_call_impl_with_c_args(_OPAQUE_INIT, 3, args, &ret._mlc_any) except: # no-cython-lint - Py_DECREF(x) + Py_DecRef(x) raise return ret @@ -677,7 +686,7 @@ cdef inline MLCAny _any_py2c(object x, list temporary_storage): elif isinstance(x, PyAny): y = (x)._mlc_any elif isinstance(x, Str): - y = (x)._mlc_any + y = ((x._pyany))._mlc_any elif isinstance(x, bool): y = _MLCAnyBool(x) elif isinstance(x, Integral): @@ -729,7 +738,7 @@ cdef inline MLCAny _any_py2c_dict(tuple x, list temporary_storage): cdef void _pyobj_deleter(void* handle) noexcept nogil: with gil: try: - Py_DECREF((handle)) + Py_DecRef((handle)) except Exception as exception: # TODO(@junrushao): Will need to handle exceptions more gracefully print(f"Error in _pyobj_deleter: {exception}") @@ -767,7 +776,7 @@ cdef inline int32_t _func_safe_call_impl( cdef inline PyAny _pyany_from_func(object py_func): cdef PyAny ret = PyAny() - Py_INCREF(py_func) + Py_IncRef(py_func) _check_error(_C_FuncCreate((py_func), _pyobj_deleter, _func_safe_call, &ret._mlc_any)) return ret