Skip to content

Commit b5b13aa

Browse files
committed
Experimental support for pip dependencies
In the absence of an external interface to pip's resolver (see e.g. pypa/pip#7819), this uses Poetry's resolution logic to convert pip requirements from environment.yaml to either transitive dependencies (in the case of env output) or direct references (in the case of explicit output). In explicit mode these are emitted as comment lines that `conda-lock install` can unpack and pass to `pip install` inside of the target environment.
1 parent a9724ae commit b5b13aa

File tree

6 files changed

+480
-17
lines changed

6 files changed

+480
-17
lines changed

conda_lock/conda_lock.py

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
from conda_lock.common import read_file, read_json, write_file
4141
from conda_lock.errors import PlatformValidationError
42+
from conda_lock.pypi_solver import PipRequirement, solve_pypi
4243
from conda_lock.src_parser import LockSpecification
4344
from conda_lock.src_parser.environment_yaml import parse_environment_file
4445
from conda_lock.src_parser.meta_yaml import parse_meta_yaml_file
@@ -264,15 +265,20 @@ def do_conda_install(conda: PathLike, prefix: str, name: str, file: str) -> None
264265
*([] if kind == "env" else ["--yes"]),
265266
]
266267

268+
common_args = []
267269
if prefix:
268-
args.append("--prefix")
269-
args.append(prefix)
270+
common_args.append("--prefix")
271+
common_args.append(prefix)
270272
if name:
271-
args.append("--name")
272-
args.append(name)
273+
common_args.append("--name")
274+
common_args.append(name)
273275
conda_flags = os.environ.get("CONDA_FLAGS")
274276
if conda_flags:
275-
args.extend(shlex.split(conda_flags))
277+
common_args.extend(shlex.split(conda_flags))
278+
279+
args.extend(common_args)
280+
281+
assert len(common_args) == 2
276282

277283
logging.debug("$MAMBA_ROOT_PREFIX: %s", os.environ.get("MAMBA_ROOT_PREFIX"))
278284

@@ -297,6 +303,47 @@ def do_conda_install(conda: PathLike, prefix: str, name: str, file: str) -> None
297303
)
298304
sys.exit(1)
299305

306+
if kind == "explicit":
307+
with open(file) as explicit_env:
308+
pip_requirements = [
309+
line.split("# pip ")[1]
310+
for line in explicit_env
311+
if line.startswith("# pip ")
312+
]
313+
if not pip_requirements:
314+
return
315+
316+
with tempfile.NamedTemporaryFile() as tf:
317+
write_file("\n".join(pip_requirements), tf.name)
318+
pip_proc = subprocess.run(
319+
[
320+
str(conda),
321+
"run",
322+
]
323+
+ common_args
324+
+ [
325+
"pip",
326+
"install",
327+
"--no-deps",
328+
"-r",
329+
tf.name,
330+
]
331+
)
332+
333+
if pip_proc.stdout:
334+
for line in pip_proc.stdout.decode().split("\n"):
335+
logging.info(line)
336+
337+
if pip_proc.stderr:
338+
for line in pip_proc.stderr.decode().split("\n"):
339+
logging.error(line.rstrip())
340+
341+
if pip_proc.returncode != 0:
342+
print(
343+
f"Could not perform pip install using {file} lock file into {name or prefix}"
344+
)
345+
sys.exit(1)
346+
300347

301348
def search_for_md5s(
302349
conda: PathLike, package_specs: List[dict], platform: str, channels: Sequence[str]
@@ -539,12 +586,39 @@ def create_lockfile_from_spec(
539586
)
540587
logging.debug("dry_run_install:\n%s", dry_run_install)
541588

589+
if spec.pip_specs:
590+
python_version: Optional[str] = None
591+
locked_packages = []
592+
for package in (
593+
dry_run_install["actions"]["FETCH"] + dry_run_install["actions"]["LINK"]
594+
):
595+
if package["name"] == "python":
596+
python_version = package["version"]
597+
elif not package["name"].startswith("__"):
598+
locked_packages.append((package["name"], package["version"]))
599+
if python_version is None:
600+
raise ValueError("Got pip specs without Python")
601+
pip = solve_pypi(
602+
spec.pip_specs,
603+
conda_installed=locked_packages,
604+
python_version=python_version,
605+
platform=spec.platform,
606+
)
607+
else:
608+
pip = []
609+
542610
lockfile_contents = [
543611
"# Generated by conda-lock.",
544612
f"# platform: {spec.platform}",
545613
f"# input_hash: {spec.input_hash()}\n",
546614
]
547615

616+
def format_pip_requirement(spec: PipRequirement) -> str:
617+
if "url" in spec:
618+
return f'{spec["name"]} @ {spec["url"]}'
619+
else:
620+
return f'{spec["name"]} === {spec["version"]}'
621+
548622
if kind == "env":
549623
link_actions = dry_run_install["actions"]["LINK"]
550624
lockfile_contents.extend(
@@ -560,6 +634,10 @@ def create_lockfile_from_spec(
560634
),
561635
]
562636
)
637+
if pip:
638+
lockfile_contents.extend(
639+
[" - pip:", *(f" - {format_pip_requirement(pkg)}" for pkg in pip)]
640+
)
563641
elif kind == "explicit":
564642
lockfile_contents.append("@EXPLICIT\n")
565643

@@ -611,6 +689,18 @@ def sanitize_lockfile_line(line):
611689
return line
612690

613691
lockfile_contents = [sanitize_lockfile_line(line) for line in lockfile_contents]
692+
693+
# emit an explicit requirements.txt, prefixed with '# pip '
694+
for pkg in pip:
695+
lines = [format_pip_requirement(pkg)] + [
696+
f" --hash={hash}" for hash in pkg["hashes"]
697+
]
698+
lockfile_contents.extend(
699+
[
700+
f"# pip {line}"
701+
for line in [line + " \\" for line in lines[:-1]] + [lines[-1]]
702+
]
703+
)
614704
else:
615705
raise ValueError(f"Unrecognised lock kind {kind}.")
616706

@@ -670,6 +760,12 @@ def aggregate_lock_specs(lock_specs: List[LockSpecification]) -> LockSpecificati
670760
set(chain.from_iterable([lock_spec.specs for lock_spec in lock_specs]))
671761
)
672762

763+
pip_specs = list(
764+
set(
765+
chain.from_iterable([lock_spec.pip_specs or [] for lock_spec in lock_specs])
766+
)
767+
)
768+
673769
# pick the first non-empty channel
674770
channels: List[str] = next(
675771
(lock_spec.channels for lock_spec in lock_specs if lock_spec.channels), []
@@ -680,7 +776,9 @@ def aggregate_lock_specs(lock_specs: List[LockSpecification]) -> LockSpecificati
680776
(lock_spec.platform for lock_spec in lock_specs if lock_spec.platform), ""
681777
)
682778

683-
return LockSpecification(specs=specs, channels=channels, platform=platform)
779+
return LockSpecification(
780+
specs=specs, channels=channels, platform=platform, pip_specs=pip_specs
781+
)
684782

685783

686784
def _ensureconda(

conda_lock/pypi_solver.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import re
2+
import sys
3+
4+
from pathlib import Path
5+
from typing import Optional, TypedDict
6+
from urllib.parse import urldefrag
7+
8+
from clikit.api.io.flags import VERY_VERBOSE
9+
from clikit.io import ConsoleIO
10+
from packaging.tags import compatible_tags, cpython_tags
11+
from poetry.core.packages import Dependency, Package, ProjectPackage, URLDependency
12+
from poetry.installation.chooser import Chooser
13+
from poetry.installation.operations import Install
14+
from poetry.installation.operations.uninstall import Uninstall
15+
from poetry.puzzle import Solver
16+
from poetry.repositories.pool import Pool
17+
from poetry.repositories.pypi_repository import PyPiRepository
18+
from poetry.repositories.repository import Repository
19+
from poetry.utils.env import Env
20+
21+
from conda_lock.src_parser.pyproject_toml import get_lookup as get_forward_lookup
22+
23+
24+
class PlatformEnv(Env):
25+
def __init__(self, python_version, platform):
26+
super().__init__(path=Path(sys.prefix))
27+
if platform == "linux-64":
28+
# FIXME: in principle these depend on the glibc in the conda env
29+
self._platforms = ["manylinux_2_17_x86_64", "manylinux2014_x86_64"]
30+
else:
31+
raise ValueError(f"Unsupported platform '{platform}'")
32+
self._python_version = tuple(map(int, python_version.split(".")))
33+
34+
def get_supported_tags(self):
35+
"""
36+
Mimic the output of packaging.tags.sys_tags() on the given platform
37+
"""
38+
return list(
39+
cpython_tags(python_version=self._python_version, platforms=self._platforms)
40+
) + list(
41+
compatible_tags(
42+
python_version=self._python_version, platforms=self._platforms
43+
)
44+
)
45+
46+
47+
class PipRequirement(TypedDict):
48+
name: str
49+
version: Optional[str]
50+
url: str
51+
hashes: list[str]
52+
53+
54+
REQUIREMENT_PATTERN = re.compile(
55+
r"""
56+
^
57+
(?P<name>[a-zA-Z0-9_-]+) # package name
58+
(?:\[(?P<extras>(?:\s?[a-zA-Z0-9_-]+(?:\s?\,\s?)?)+)\])? # extras
59+
(?:
60+
(?: # a direct reference
61+
\s?@\s?(?P<url>.*)
62+
)
63+
|
64+
(?: # one or more PEP440 version specifiers
65+
\s?(?P<constraint>
66+
(?:\s?
67+
(?:
68+
(?:=|[><~=!])?=
69+
|
70+
[<>]
71+
)
72+
\s?
73+
(?:
74+
[A-Za-z0-9\.-_\*]+ # a version tuple, e.g. x.y.z
75+
(?:-[A-Za-z]+(?:\.[0-9]+)?)? # a post-release tag, e.g. -alpha.2
76+
(?:\s?\,\s?)?
77+
)
78+
)+
79+
)
80+
)
81+
)?
82+
$
83+
""",
84+
re.VERBOSE,
85+
)
86+
87+
88+
def parse_pip_requirement(requirement: str) -> Optional[dict[str, str]]:
89+
match = REQUIREMENT_PATTERN.match(requirement)
90+
if not match:
91+
return None
92+
return match.groupdict()
93+
94+
95+
def get_dependency(requirement: str) -> Dependency:
96+
parsed = parse_pip_requirement(requirement)
97+
if parsed is None:
98+
raise ValueError(f"Unknown pip requirement '{requirement}'")
99+
extras = re.split(r"\s?\,\s?", parsed["extras"]) if parsed["extras"] else None
100+
if parsed["url"]:
101+
return URLDependency(name=parsed["name"], url=parsed["url"], extras=extras)
102+
else:
103+
return Dependency(
104+
name=parsed["name"], constraint=parsed["constraint"] or "*", extras=extras
105+
)
106+
107+
108+
PYPI_LOOKUP: Optional[dict] = None
109+
110+
111+
def get_lookup() -> dict:
112+
global PYPI_LOOKUP
113+
if PYPI_LOOKUP is None:
114+
PYPI_LOOKUP = {
115+
record["conda_name"]: record for record in get_forward_lookup().values()
116+
}
117+
return PYPI_LOOKUP
118+
119+
120+
def normalize_conda_name(name: str):
121+
return get_lookup().get(name, {"pypi_name": name})["pypi_name"]
122+
123+
124+
def solve_pypi(
125+
dependencies: list[str],
126+
conda_installed: list[tuple[str, str]],
127+
python_version: str,
128+
platform: str,
129+
verbose: bool = False,
130+
) -> list[PipRequirement]:
131+
dummy_package = ProjectPackage("_dummy_package_", "0.0.0")
132+
dummy_package.python_versions = f"=={python_version}"
133+
for spec in dependencies:
134+
dummy_package.add_dependency(get_dependency(spec))
135+
136+
pypi = PyPiRepository()
137+
pool = Pool(repositories=[pypi])
138+
139+
installed = Repository()
140+
locked = Repository()
141+
142+
python_packages = dict()
143+
for name, version in conda_installed:
144+
pypi_name = normalize_conda_name(name)
145+
# Prefer the Python package when its name collides with the Conda package
146+
# for the underlying library, e.g. python-xxhash (pypi: xxhash) over xxhash
147+
# (pypi: no equivalent)
148+
if pypi_name not in python_packages or pypi_name != name:
149+
python_packages[pypi_name] = version
150+
for name, version in python_packages.items():
151+
for repo in (locked, installed):
152+
repo.add_package(Package(name=name, version=version))
153+
154+
io = ConsoleIO()
155+
if verbose:
156+
io.set_verbosity(VERY_VERBOSE)
157+
s = Solver(
158+
dummy_package,
159+
pool=pool,
160+
installed=installed,
161+
locked=locked,
162+
io=io,
163+
)
164+
result = s.solve(use_latest=dependencies)
165+
166+
chooser = Chooser(pool, env=PlatformEnv(python_version, platform))
167+
168+
# Extract distributions from Poetry package plan, ignoring uninstalls
169+
# (usually: conda package with no pypi equivalent) and skipped ops
170+
# (already installed)
171+
requirements: list[PipRequirement] = []
172+
for op in result:
173+
if not isinstance(op, Uninstall) and not op.skipped:
174+
# Take direct references verbatim
175+
if op.package.source_type == "url":
176+
url, fragment = urldefrag(op.package.source_url)
177+
requirements.append(
178+
{
179+
"name": op.package.name,
180+
"version": None,
181+
"url": url,
182+
"hashes": [fragment.replace("=", ":")],
183+
}
184+
)
185+
# Choose the most specific distribution for the target
186+
else:
187+
link = chooser.choose_for(op.package)
188+
requirements.append(
189+
{
190+
"name": op.package.name,
191+
"version": str(op.package.version),
192+
"url": link.url_without_fragment,
193+
"hashes": [f"{link.hash_name}:{link.hash}"],
194+
}
195+
)
196+
197+
return requirements

conda_lock/src_parser/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,21 @@ def __init__(
1212
specs: List[str],
1313
channels: List[str],
1414
platform: str,
15+
pip_specs: Optional[List[str]] = None,
1516
virtual_package_repo: Optional[FakeRepoData] = None,
1617
):
1718
self.specs = specs
1819
self.channels = channels
1920
self.platform = platform
21+
self.pip_specs = pip_specs
2022
self.virtual_package_repo = virtual_package_repo
2123

2224
def input_hash(self) -> str:
2325
data: dict = {
2426
"channels": self.channels,
2527
"platform": self.platform,
2628
"specs": sorted(self.specs),
29+
"pip_specs": sorted(self.pip_specs or []),
2730
}
2831
if self.virtual_package_repo is not None:
2932
vpr_data = self.virtual_package_repo.all_repodata

0 commit comments

Comments
 (0)