diff --git a/integration_tests/symbolics_07.py b/integration_tests/symbolics_07.py index 52b3a3902a..f6d9ea947f 100644 --- a/integration_tests/symbolics_07.py +++ b/integration_tests/symbolics_07.py @@ -1,14 +1,15 @@ -from lpython import ccall +from lpython import ccall, CPtr +import os -@ccall(header="symengine/cwrapper.h") +@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib") def basic_new_heap() -> CPtr: pass -@ccall(header="symengine/cwrapper.h") +@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib") def basic_const_pi(x: CPtr) -> None: pass -@ccall(header="symengine/cwrapper.h") +@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib") def basic_str(x: CPtr) -> str: pass diff --git a/src/runtime/lpython/lpython.py b/src/runtime/lpython/lpython.py index 499083cdbc..4ec0faad4a 100644 --- a/src/runtime/lpython/lpython.py +++ b/src/runtime/lpython/lpython.py @@ -335,7 +335,7 @@ class CTypes: A wrapper class for interfacing C via ctypes. """ - def __init__(self, f): + def __init__(self, f, py_mod = None, py_mod_path = None): def get_rtlib_dir(): current_dir = os.path.dirname(os.path.abspath(__file__)) return os.path.join(current_dir, "..") @@ -349,17 +349,20 @@ def get_lib_name(name): else: raise NotImplementedError("Platform not implemented") def get_crtlib_path(): - py_mod = os.environ.get("LPYTHON_PY_MOD_NAME", "") + nonlocal py_mod, py_mod_path + if py_mod is None: + py_mod = os.environ.get("LPYTHON_PY_MOD_NAME", "") if py_mod == "": return os.path.join(get_rtlib_dir(), get_lib_name("lpython_runtime")) else: - py_mod_path = os.environ["LPYTHON_PY_MOD_PATH"] + if py_mod_path is None: + py_mod_path = os.environ["LPYTHON_PY_MOD_PATH"] return os.path.join(py_mod_path, get_lib_name(py_mod)) self.name = f.__name__ self.args = f.__code__.co_varnames self.annotations = f.__annotations__ - if "LPYTHON_PY_MOD_NAME" in os.environ: + if ("LPYTHON_PY_MOD_NAME" in os.environ) or (py_mod is not None): crtlib = get_crtlib_path() self.library = ctypes.CDLL(crtlib) self.cf = self.library[self.name] @@ -388,7 +391,10 @@ def __call__(self, *args, **kwargs): new_args.append(arg.ctypes.data_as(ctypes.POINTER(convert_numpy_dtype_to_ctype(arg.dtype)))) else: new_args.append(arg) - return self.cf(*new_args) + res = self.cf(*new_args) + if self.cf.restype == ctypes.c_char_p: + res = res.decode("utf-8") + return res def convert_to_ctypes_Union(f): fields = [] @@ -465,10 +471,14 @@ def __init__(self, *args): return ctypes_Structure -def ccall(f): - if isclass(f) and issubclass(f, Union): - return f - return CTypes(f) +def ccall(f=None, header=None, c_shared_lib=None, c_shared_lib_path=None): + def wrap(func): + if not isclass(func) or not issubclass(func, Union): + func = CTypes(func, c_shared_lib, c_shared_lib_path) + return func + if f: + return wrap(f) + return wrap def pythoncall(*args, **kwargs): def inner(fn):