Skip to content

Commit a694dd7

Browse files
committed
feat: use temp file for oci auth
This adjusts handling of oci auth passed via environment variable so that we don't subsequently pass the value as a subprocess argument, as this poses a security risk. Instead we write it to a temporary file, which is automatically deleted after use. Signed-off-by: Jon Burdo <[email protected]>
1 parent 3bb5b52 commit a694dd7

File tree

4 files changed

+177
-35
lines changed

4 files changed

+177
-35
lines changed

clients/python/poetry.lock

Lines changed: 22 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

clients/python/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ sphinx-autobuild = ">=2021.3.14,<2025.0.0"
4848
pytest = ">=7.4.2,<9.0.0"
4949
coverage = { extras = ["toml"], version = "^7.3.2" }
5050
pytest-cov = ">=4.1,<7.0"
51+
pytest-mock = ">=3.7.0"
5152
ruff = ">=0.5.2,<0.13.0"
5253
mypy = "^1.7.0"
5354
# atm Ray is only available <3.13, so we will E2E test using Ray in compatible py environments.

clients/python/src/model_registry/utils.py

Lines changed: 110 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
import shutil
1010
import tempfile
1111
import threading
12+
from contextlib import AbstractContextManager, contextmanager
1213
from dataclasses import asdict, dataclass
1314
from pathlib import Path
1415
from subprocess import CalledProcessError
15-
from typing import TYPE_CHECKING, Callable, Protocol, TypeVar
16+
from typing import TYPE_CHECKING, Callable, Protocol, TextIO, TypeVar
1617

1718
from typing_extensions import Literal, overload
1819

@@ -223,15 +224,17 @@ def _backend_specific_params(
223224
# Determine backend
224225
if backend == "skopeo":
225226
prefix = "--src" if type == "pull" else "--dest"
227+
auth_suffix = "authfile"
226228
elif backend == "oras":
227229
prefix = "--from" if type == "pull" else "--to"
230+
auth_suffix = "registry-config"
231+
else:
232+
msg = f"invalid backend: {backend!r}"
233+
raise ValueError(msg)
228234

229235
# Actual param specifications
230-
if username := kwargs.pop("username", None):
231-
kwargs[f"{prefix}-username"] = username
232-
233-
if password := kwargs.pop("password", None):
234-
kwargs[f"{prefix}-password"] = password
236+
if authfile := kwargs.pop("authfile", None):
237+
kwargs[f"{prefix}-{auth_suffix}"] = authfile
235238

236239
return kwargs
237240

@@ -344,16 +347,14 @@ def save_to_oci_registry( # noqa: C901 ( complex args >8 )
344347
raise StoreError(msg) from e
345348

346349
# Check for OCI Auth Env and a default
347-
auth: str = None
350+
auth: str | None = None
348351
if oci_auth_env_var:
349-
env_value = _validate_env_var(oci_auth_env_var)
350-
auth = _extract_auth_json(env_value)
351-
else:
352-
try:
353-
env_value = _validate_env_var(".dockerconfigjson")
354-
auth = _extract_auth_json(env_value)
355-
except ValueError:
356-
pass
352+
auth = _validate_env_var(oci_auth_env_var)
353+
elif ".dockerconfigjson" in os.environ:
354+
auth = os.environ[".dockerconfigjson"] # noqa: SIM112
355+
356+
elif oci_username and oci_password:
357+
auth = json.dumps(create_auth_object(oci_ref, oci_username, oci_password))
357358

358359
# If a custom backend is provided, use it, else fetch the backend out of the registry
359360
if custom_oci_backend:
@@ -374,30 +375,48 @@ def save_to_oci_registry( # noqa: C901 ( complex args >8 )
374375
dest_dir_cleanup = True
375376
local_image_path = Path(dest_dir)
376377

377-
# Set params
378378
params = {}
379-
380-
# User/pass
381-
if auth:
382-
usr_pass = auth.split(":")
383-
params["username"] = usr_pass[0]
384-
params["password"] = usr_pass[-1]
385-
elif oci_username and oci_password:
386-
params["username"] = oci_username
387-
params["password"] = oci_password
388-
389-
backend_def.pull(base_image, local_image_path, **params)
390-
# Extract the absolute path from the files found in the path
391-
files = [file[0] for file in _get_files_from_path(model_files_path)]
392-
oci_layers_on_top(local_image_path, files, modelcard)
393-
backend_def.push(local_image_path, oci_ref, **params)
379+
with temp_auth_file(auth) as auth_file:
380+
if auth_file is not None:
381+
params["authfile"] = auth_file.name
382+
backend_def.pull(base_image, local_image_path, **params)
383+
# Extract the absolute path from the files found in the path
384+
files = [file[0] for file in _get_files_from_path(model_files_path)]
385+
oci_layers_on_top(local_image_path, files, modelcard)
386+
backend_def.push(local_image_path, oci_ref, **params)
394387

395388
# Return the OCI URI
396389
if dest_dir_cleanup:
397390
shutil.rmtree(dest_dir)
398391
return f"oci://{oci_ref}"
399392

400393

394+
@overload
395+
def temp_auth_file(auth: str) -> AbstractContextManager[TextIO]:
396+
...
397+
398+
399+
@overload
400+
def temp_auth_file(auth: None) -> AbstractContextManager[None]:
401+
...
402+
403+
404+
@contextmanager
405+
def temp_auth_file(auth: str | None) -> AbstractContextManager[TextIO | None]:
406+
"""Create a temporary auth file with optional auth data.
407+
408+
If auth is None, yields None. Otherwise creates a temporary JSON file
409+
containing the auth string and yields the file handle.
410+
"""
411+
if auth is None:
412+
yield None
413+
else:
414+
with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8", suffix=".json") as temp_auth_file:
415+
temp_auth_file.write(auth)
416+
temp_auth_file.flush()
417+
yield temp_auth_file
418+
419+
401420
def _s3_creds(
402421
endpoint_url: str | None = None,
403422
access_key_id: str | None = None,
@@ -636,6 +655,66 @@ def _extract_auth_json(auth_data: str) -> str:
636655
raise ValueError(invalid_json_msg) from e
637656

638657

658+
def get_auth_reference(image_path: str) -> str:
659+
"""Parses an arbitrary container image path to extract a valid reference.
660+
661+
for use as a key in a container registry auth.json file.
662+
663+
Examples:
664+
'quay.io/my-org/my-registry:1.0.0' -> 'quay.io/my-org/my-registry'
665+
'my-private-registry:5000/team/app:latest' -> 'my-private-registry:5000/team/app'
666+
'ubuntu' -> 'docker.io'
667+
'ubuntu:22.04' -> 'docker.io'
668+
'my-user/my-app' -> 'docker.io'
669+
'my-user/my-app:v2' -> 'docker.io'
670+
'quay.io/my-org/my-registry@sha256:f1b3f5a2d...' -> 'quay.io/my-org/my-registry'
671+
'localhost:5000/my-local-image' -> 'localhost:5000/my-local-image'
672+
'localhost:5000/my-local-image:test-tag' -> 'localhost:5000/my-local-image'
673+
"""
674+
repo_path = image_path
675+
676+
# Remove digest if it exists
677+
if "@" in repo_path:
678+
repo_path = repo_path.split("@", 1)[0]
679+
680+
# Separate the tag from the repository path.
681+
# The tag is what comes after the last colon, but only if that colon
682+
# is not part of the hostname/port. A colon indicates a tag if it
683+
# appears after the last slash in the path.
684+
last_colon = repo_path.rfind(":")
685+
last_slash = repo_path.rfind("/")
686+
687+
if last_colon > last_slash:
688+
# This is a tag, not a port, so we strip it.
689+
repo_path = repo_path[:last_colon]
690+
691+
# Handle default Docker Hub images (e.g., 'ubuntu', 'user/repo').
692+
# The hostname is the part of the path before the first slash.
693+
first_slash_index = repo_path.find("/")
694+
hostname = repo_path
695+
if first_slash_index != -1:
696+
hostname = repo_path[:first_slash_index]
697+
698+
# If the hostname part doesn't contain a '.' (like quay.io) or a ':' (like localhost:5000),
699+
# it's a short name for an image on Docker Hub.
700+
if "." not in hostname and ":" not in hostname:
701+
return "docker.io"
702+
703+
# For all other images, the full repository path is the reference.
704+
return repo_path
705+
706+
707+
def create_auth_object(oci_ref: str, username: str, password: str) -> dict[str: dict[str, dict[str, str]]]:
708+
"""Create an auth object for container registry authentication.
709+
710+
This object can be encoded as json with json.dumps() producing the
711+
contents for valid authfile.
712+
"""
713+
auth_ref = get_auth_reference(oci_ref)
714+
auth_value = base64.b64encode(f"{username}:{password}".encode()).decode("utf-8")
715+
return {"auths": {auth_ref: {"auth": auth_value}}}
716+
717+
639718
def rand_suffix(size: int = 8) -> str:
640719
"""Generate a random suffix.
641720

clients/python/tests/test_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import json
12
import os
3+
from contextlib import contextmanager
4+
from pathlib import Path
25

36
import pytest
47

@@ -8,6 +11,7 @@
811
_get_files_from_path,
912
s3_uri_from,
1013
save_to_oci_registry,
14+
temp_auth_file,
1115
)
1216

1317

@@ -124,6 +128,46 @@ def test_save_to_oci_registry_with_custom_backend(
124128
assert uri == f"oci://{oci_ref}"
125129

126130

131+
def test_save_to_oci_registry_with_username_password(mocker, tmp_path):
132+
username = "user32"
133+
password = "zi3327"
134+
model_files_path = tmp_path / "model-files"
135+
model_files_path.mkdir()
136+
(model_files_path / "model.bin").touch()
137+
dest_dir = tmp_path / "dest"
138+
139+
temp_auth_file_info = {}
140+
141+
@contextmanager
142+
def temp_auth_file_wrapper(auth):
143+
with temp_auth_file(auth) as f:
144+
temp_auth_file_info["path"] = f.name
145+
temp_auth_file_info["contents"] = Path(f.name).read_text()
146+
yield f
147+
148+
mock_skopeo_pull = mocker.patch("olot.backend.skopeo.skopeo_pull")
149+
mock_skopeo_push = mocker.patch("olot.backend.skopeo.skopeo_push")
150+
mocker.patch("olot.basics.oci_layers_on_top")
151+
mocker.patch("model_registry.utils.temp_auth_file", side_effect=temp_auth_file_wrapper)
152+
153+
save_to_oci_registry(
154+
base_image="busybox",
155+
oci_ref="quay.io/example/example:latest",
156+
model_files_path=model_files_path,
157+
dest_dir=dest_dir,
158+
backend="skopeo",
159+
oci_username=username,
160+
oci_password=password,
161+
)
162+
163+
assert mock_skopeo_pull.call_args.args == ("busybox", dest_dir, ["--src-authfile", mocker.ANY])
164+
assert mock_skopeo_pull.call_args.kwargs == {}
165+
assert mock_skopeo_push.call_args.args == (dest_dir, "quay.io/example/example:latest", ["--dest-authfile", mocker.ANY])
166+
assert mock_skopeo_push.call_args.kwargs == {}
167+
assert json.loads(temp_auth_file_info["contents"]) == {"auths": {"quay.io/example/example": {"auth": "dXNlcjMyOnppMzMyNw=="}}}
168+
assert not Path(temp_auth_file_info["path"]).exists()
169+
170+
127171
@pytest.mark.e2e(type="oci")
128172
def test_save_to_oci_registry_auth_params(
129173
get_temp_dir_with_models,

0 commit comments

Comments
 (0)