Skip to content

Handle windows conda library directories in default blas #517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pytensor/link/c/cmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2744,7 +2744,9 @@ def get_cxx_library_dirs():
[pathlib.Path(p).resolve() for p in line[len("libraries: =") :].split(":")]
for line in stdout.decode(sys.stdout.encoding).splitlines()
if line.startswith("libraries: =")
][0]
]
if len(maybe_lib_dirs) > 0:
maybe_lib_dirs = maybe_lib_dirs[0]
return [str(d) for d in maybe_lib_dirs if d.exists() and d.is_dir()]

def check_libs(
Expand Down Expand Up @@ -2793,6 +2795,13 @@ def check_libs(

cxx_library_dirs = get_cxx_library_dirs()
searched_library_dirs = cxx_library_dirs + _std_lib_dirs
if sys.platform == "win32":
# Conda on Windows saves MKL libraries under CONDA_PREFIX\Library\bin
# From the conda manual (https://docs.conda.io/projects/conda-build/en/stable/user-guide/environment-variables.html)
# it seems like conda could also save some libraries into the CONDA_PREFIX\Library\lib
# directory. We will include both in our searched library dirs
searched_library_dirs.append(os.path.join(sys.prefix, "Library", "bin"))
searched_library_dirs.append(os.path.join(sys.prefix, "Library", "lib"))
all_libs = [
l
for path in [
Expand Down
45 changes: 45 additions & 0 deletions tests/link/c/test_cmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,51 @@ def test_default_blas_ldflags_no_cxx():
assert default_blas_ldflags() == ""


@pytest.fixture()
def windows_conda_libs(blas_libs):
libtemplate = "{lib}.dll"
libraries = []
with tempfile.TemporaryDirectory() as d:
subdir = os.path.join(d, "Library", "bin")
os.makedirs(subdir, exist_ok=True)
flags = f'-L"{subdir}"'
for lib in blas_libs:
lib_path = os.path.join(subdir, libtemplate.format(lib=lib))
with open(lib_path, "wb") as f:
f.write(b"1")
libraries.append(lib_path)
flags += f" -l{lib}"
if "gomp" in blas_libs and "mkl_gnu_thread" not in blas_libs:
flags += " -fopenmp"
if len(blas_libs) == 0:
flags = ""
yield d, flags


@patch("pytensor.link.c.cmodule.std_lib_dirs", return_value=[])
@patch("pytensor.link.c.cmodule.check_mkl_openmp", return_value=None)
def test_default_blas_ldflags_conda_windows(
mock_std_lib_dirs, mock_check_mkl_openmp, windows_conda_libs
):
mock_sys_prefix, expected_blas_ldflags = windows_conda_libs
mock_process = MagicMock()
mock_process.communicate = lambda *args, **kwargs: (b"", b"")
mock_process.returncode = 0
with patch("sys.platform", "win32"):
with patch("sys.prefix", mock_sys_prefix):
with patch(
"pytensor.link.c.cmodule.subprocess_Popen", return_value=mock_process
):
with patch.object(
pytensor.link.c.cmodule.GCC_compiler,
"try_compile_tmp",
return_value=(True, True),
):
assert set(default_blas_ldflags().split(" ")) == set(
expected_blas_ldflags.split(" ")
)


@patch(
"os.listdir", return_value=["mkl_core.1.dll", "mkl_rt.1.0.dll", "mkl_rt.1.1.lib"]
)
Expand Down
13 changes: 7 additions & 6 deletions tests/link/c/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ def get_hash(modname, seed=None):
cmd_line,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stderr=subprocess.PIPE,
env=env,
)
out, err = p.communicate()
return out, err
return out, err, p.returncode


def test_ExternalCOp_c_code_cache_version():
Expand All @@ -222,10 +222,11 @@ def test_ExternalCOp_c_code_cache_version():
tmp.seek(0)
# modname = os.path.splitext(tmp.name)[0]
modname = tmp.name
out_1, err = get_hash(modname, seed=428)
assert err is None
out_2, err = get_hash(modname, seed=3849)
assert err is None
out_1, err1, returncode1 = get_hash(modname, seed=428)
out_2, err2, returncode2 = get_hash(modname, seed=3849)
assert returncode1 == 0
assert returncode2 == 0
assert err1 == err2

hash_1, msg, _ = out_1.decode().split("\n")
assert msg == "__success__"
Expand Down
31 changes: 17 additions & 14 deletions tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,29 +272,32 @@ def test_debugprint():
print_view_map=True,
)
s = s.getvalue()
Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv"
exp_res = dedent(
r"""
Composite{(i2 + (i0 - i1))} 4
├─ ExpandDims{axis=0} v={0: [0]} 3
│ └─ CGemv{inplace} d={0: [0]} 2
│ ├─ AllocEmpty{dtype='float64'} 1
│ │ └─ Shape_i{0} 0
│ │ └─ B
│ ├─ 1.0
│ ├─ B
│ ├─ <Vector(float64, shape=(?,))>
│ └─ 0.0
├─ D
└─ A
├─ ExpandDims{axis=0} v={0: [0]} 3
"""
f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2"
r"""
│ ├─ AllocEmpty{dtype='float64'} 1
│ │ └─ Shape_i{0} 0
│ │ └─ B
│ ├─ 1.0
│ ├─ B
│ ├─ <Vector(float64, shape=(?,))>
│ └─ 0.0
├─ D
└─ A

Inner graphs:

Composite{(i2 + (i0 - i1))}
← add 'o0'
← add 'o0'
├─ i2
└─ sub
├─ i0
└─ i1
├─ i0
└─ i1
"""
).lstrip()

Expand Down