Skip to content

Commit d41c25f

Browse files
authored
Merge pull request #2196 from Shaikh-Ubaid/fix_ccall_for_cpython
Support ccall() for symengine and other libs
2 parents d559090 + f6eab9c commit d41c25f

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

integration_tests/symbolics_07.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
from lpython import ccall
1+
from lpython import ccall, CPtr
2+
import os
23

3-
@ccall(header="symengine/cwrapper.h")
4+
@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib")
45
def basic_new_heap() -> CPtr:
56
pass
67

7-
@ccall(header="symengine/cwrapper.h")
8+
@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib")
89
def basic_const_pi(x: CPtr) -> None:
910
pass
1011

11-
@ccall(header="symengine/cwrapper.h")
12+
@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib")
1213
def basic_str(x: CPtr) -> str:
1314
pass
1415

src/runtime/lpython/lpython.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ class CTypes:
335335
A wrapper class for interfacing C via ctypes.
336336
"""
337337

338-
def __init__(self, f):
338+
def __init__(self, f, py_mod = None, py_mod_path = None):
339339
def get_rtlib_dir():
340340
current_dir = os.path.dirname(os.path.abspath(__file__))
341341
return os.path.join(current_dir, "..")
@@ -349,17 +349,20 @@ def get_lib_name(name):
349349
else:
350350
raise NotImplementedError("Platform not implemented")
351351
def get_crtlib_path():
352-
py_mod = os.environ.get("LPYTHON_PY_MOD_NAME", "")
352+
nonlocal py_mod, py_mod_path
353+
if py_mod is None:
354+
py_mod = os.environ.get("LPYTHON_PY_MOD_NAME", "")
353355
if py_mod == "":
354356
return os.path.join(get_rtlib_dir(),
355357
get_lib_name("lpython_runtime"))
356358
else:
357-
py_mod_path = os.environ["LPYTHON_PY_MOD_PATH"]
359+
if py_mod_path is None:
360+
py_mod_path = os.environ["LPYTHON_PY_MOD_PATH"]
358361
return os.path.join(py_mod_path, get_lib_name(py_mod))
359362
self.name = f.__name__
360363
self.args = f.__code__.co_varnames
361364
self.annotations = f.__annotations__
362-
if "LPYTHON_PY_MOD_NAME" in os.environ:
365+
if ("LPYTHON_PY_MOD_NAME" in os.environ) or (py_mod is not None):
363366
crtlib = get_crtlib_path()
364367
self.library = ctypes.CDLL(crtlib)
365368
self.cf = self.library[self.name]
@@ -388,7 +391,10 @@ def __call__(self, *args, **kwargs):
388391
new_args.append(arg.ctypes.data_as(ctypes.POINTER(convert_numpy_dtype_to_ctype(arg.dtype))))
389392
else:
390393
new_args.append(arg)
391-
return self.cf(*new_args)
394+
res = self.cf(*new_args)
395+
if self.cf.restype == ctypes.c_char_p:
396+
res = res.decode("utf-8")
397+
return res
392398

393399
def convert_to_ctypes_Union(f):
394400
fields = []
@@ -465,10 +471,14 @@ def __init__(self, *args):
465471

466472
return ctypes_Structure
467473

468-
def ccall(f):
469-
if isclass(f) and issubclass(f, Union):
470-
return f
471-
return CTypes(f)
474+
def ccall(f=None, header=None, c_shared_lib=None, c_shared_lib_path=None):
475+
def wrap(func):
476+
if not isclass(func) or not issubclass(func, Union):
477+
func = CTypes(func, c_shared_lib, c_shared_lib_path)
478+
return func
479+
if f:
480+
return wrap(f)
481+
return wrap
472482

473483
def pythoncall(*args, **kwargs):
474484
def inner(fn):

0 commit comments

Comments
 (0)