Skip to content

Commit 258d5f1

Browse files
lucianopazbrandonwillard
authored andcommitted
Patch ldflags and libs in GCC_compile under windows
1 parent b83398e commit 258d5f1

File tree

2 files changed

+108
-4
lines changed

2 files changed

+108
-4
lines changed

aesara/link/c/cmodule.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import stat
1515
import subprocess
1616
import sys
17+
import sysconfig
1718
import tempfile
1819
import textwrap
1920
import time
@@ -1678,7 +1679,7 @@ def std_lib_dirs_and_libs() -> Optional[Tuple[List[str], ...]]:
16781679
# Obtain the library name from the Python version instead of the
16791680
# installation directory, in case the user defined a custom
16801681
# installation directory.
1681-
python_version = distutils.sysconfig.get_python_version()
1682+
python_version = sysconfig.get_python_version()
16821683
libname = "python" + python_version.replace(".", "")
16831684
# Also add directory containing the Python library to the library
16841685
# directories.
@@ -2381,7 +2382,13 @@ def try_compile_tmp(
23812382
comp_args=True,
23822383
):
23832384
return cls._try_compile_tmp(
2384-
src_code, tmp_prefix, flags, try_run, output, config.cxx, comp_args
2385+
src_code,
2386+
tmp_prefix,
2387+
cls.patch_ldflags(flags),
2388+
try_run,
2389+
output,
2390+
config.cxx,
2391+
comp_args,
23852392
)
23862393

23872394
@classmethod
@@ -2395,9 +2402,58 @@ def try_flags(
23952402
comp_args=True,
23962403
):
23972404
return cls._try_flags(
2398-
flag_list, preamble, body, try_run, output, config.cxx, comp_args
2405+
cls.patch_ldflags(flag_list),
2406+
preamble,
2407+
body,
2408+
try_run,
2409+
output,
2410+
config.cxx,
2411+
comp_args,
23992412
)
24002413

2414+
@staticmethod
2415+
def patch_ldflags(flag_list: List[str]) -> List[str]:
2416+
lib_dirs = [flag[2:].lstrip() for flag in flag_list if flag.startswith("-L")]
2417+
flag_idxs: List[int] = []
2418+
libs: List[str] = []
2419+
for i, flag in enumerate(flag_list):
2420+
if flag.startswith("-l"):
2421+
flag_idxs.append(i)
2422+
libs.append(flag[2:].lstrip())
2423+
if not libs:
2424+
return flag_list
2425+
libs = GCC_compiler.linking_patch(lib_dirs, libs)
2426+
for flag_idx, lib in zip(flag_idxs, libs):
2427+
flag_list[flag_idx] = lib
2428+
return flag_list
2429+
2430+
@staticmethod
2431+
def linking_patch(lib_dirs: List[str], libs: List[str]) -> List[str]:
2432+
if sys.platform != "win32":
2433+
return [f"-l{l}" for l in libs]
2434+
2435+
def sort_key(lib): # type: ignore
2436+
name, *numbers, extension = lib.split(".")
2437+
return (extension == "dll", tuple(map(int, numbers)))
2438+
2439+
patched_lib_ldflags = []
2440+
for lib in libs:
2441+
ldflag = f"-l{lib}"
2442+
for lib_dir in lib_dirs:
2443+
lib_dir = lib_dir.strip('"')
2444+
windows_styled_libs = [
2445+
fname
2446+
for fname in os.listdir(lib_dir)
2447+
if not (os.path.isdir(os.path.join(lib_dir, fname)))
2448+
and fname.split(".")[0] == lib
2449+
and fname.split(".")[-1] in ["dll", "lib"]
2450+
]
2451+
if windows_styled_libs:
2452+
selected_lib = sorted(windows_styled_libs, key=sort_key)[-1]
2453+
ldflag = f'"{os.path.join(lib_dir, selected_lib)}"'
2454+
patched_lib_ldflags.append(ldflag)
2455+
return patched_lib_ldflags
2456+
24012457
@staticmethod
24022458
def compile_str(
24032459
module_name,
@@ -2509,7 +2565,7 @@ def compile_str(
25092565
cmd.append("-fvisibility=hidden")
25102566
cmd.extend(["-o", f"{path_wrapper}{lib_filename}{path_wrapper}"])
25112567
cmd.append(f"{path_wrapper}{cppfilename}{path_wrapper}")
2512-
cmd.extend([f"-l{l}" for l in libs])
2568+
cmd.extend(GCC_compiler.linking_patch(lib_dirs, libs))
25132569
# print >> sys.stderr, 'COMPILING W CMD', cmd
25142570
_logger.debug(f"Running cmd: {' '.join(cmd)}")
25152571

tests/link/c/test_cmodule.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
deterministic based on the input type and the op.
66
"""
77
import logging
8+
import os
89
import tempfile
910
from unittest.mock import patch
1011

@@ -91,3 +92,50 @@ def test_default_blas_ldflags(sys_mock, try_blas_flag_mock, caplog):
9192
default_blas_ldflags()
9293

9394
assert caplog.text == ""
95+
96+
97+
@patch(
98+
"os.listdir", return_value=["mkl_core.1.dll", "mkl_rt.1.0.dll", "mkl_rt.1.1.lib"]
99+
)
100+
@patch("sys.platform", "win32")
101+
def test_patch_ldflags(listdir_mock):
102+
mkl_path = "some_path"
103+
flag_list = ["-lm", "-lopenblas", f"-L {mkl_path}", "-l mkl_core", "-lmkl_rt"]
104+
assert GCC_compiler.patch_ldflags(flag_list) == [
105+
"-lm",
106+
"-lopenblas",
107+
f"-L {mkl_path}",
108+
'"' + os.path.join(mkl_path, "mkl_core.1.dll") + '"',
109+
'"' + os.path.join(mkl_path, "mkl_rt.1.0.dll") + '"',
110+
]
111+
112+
113+
@patch(
114+
"os.listdir",
115+
return_value=[
116+
"libopenblas.so",
117+
"libm.a",
118+
"mkl_core.1.dll",
119+
"mkl_rt.1.0.dll",
120+
"mkl_rt.1.1.dll",
121+
],
122+
)
123+
@pytest.mark.parametrize("platform", ["win32", "linux", "darwin"])
124+
def test_linking_patch(listdir_mock, platform):
125+
libs = ["openblas", "m", "mkl_core", "mkl_rt"]
126+
lib_dirs = ['"mock_dir"']
127+
with patch("sys.platform", platform):
128+
if platform == "win32":
129+
assert GCC_compiler.linking_patch(lib_dirs, libs) == [
130+
"-lopenblas",
131+
"-lm",
132+
'"' + os.path.join(lib_dirs[0].strip('"'), "mkl_core.1.dll") + '"',
133+
'"' + os.path.join(lib_dirs[0].strip('"'), "mkl_rt.1.1.dll") + '"',
134+
]
135+
else:
136+
GCC_compiler.linking_patch(lib_dirs, libs) == [
137+
"-lopenblas",
138+
"-lm",
139+
"-lmkl_core",
140+
"-lmkl_rt",
141+
]

0 commit comments

Comments
 (0)