9
9
import shutil
10
10
import tempfile
11
11
import threading
12
+ from contextlib import AbstractContextManager , contextmanager
12
13
from dataclasses import asdict , dataclass
13
14
from pathlib import Path
14
15
from subprocess import CalledProcessError
15
- from typing import TYPE_CHECKING , Callable , Protocol , TypeVar
16
+ from typing import TYPE_CHECKING , Callable , Protocol , TextIO , TypeVar
16
17
17
18
from typing_extensions import Literal , overload
18
19
@@ -223,15 +224,17 @@ def _backend_specific_params(
223
224
# Determine backend
224
225
if backend == "skopeo" :
225
226
prefix = "--src" if type == "pull" else "--dest"
227
+ auth_suffix = "authfile"
226
228
elif backend == "oras" :
227
229
prefix = "--from" if type == "pull" else "--to"
230
+ auth_suffix = "registry-config"
231
+ else :
232
+ msg = f"invalid backend: { backend !r} "
233
+ raise ValueError (msg )
228
234
229
235
# Actual param specifications
230
- if username := kwargs .pop ("username" , None ):
231
- kwargs [f"{ prefix } -username" ] = username
232
-
233
- if password := kwargs .pop ("password" , None ):
234
- kwargs [f"{ prefix } -password" ] = password
236
+ if authfile := kwargs .pop ("authfile" , None ):
237
+ kwargs [f"{ prefix } -{ auth_suffix } " ] = authfile
235
238
236
239
return kwargs
237
240
@@ -344,16 +347,14 @@ def save_to_oci_registry( # noqa: C901 ( complex args >8 )
344
347
raise StoreError (msg ) from e
345
348
346
349
# Check for OCI Auth Env and a default
347
- auth : str = None
350
+ auth : str | None = None
348
351
if oci_auth_env_var :
349
- env_value = _validate_env_var (oci_auth_env_var )
350
- auth = _extract_auth_json (env_value )
351
- else :
352
- try :
353
- env_value = _validate_env_var (".dockerconfigjson" )
354
- auth = _extract_auth_json (env_value )
355
- except ValueError :
356
- pass
352
+ auth = _validate_env_var (oci_auth_env_var )
353
+ elif ".dockerconfigjson" in os .environ :
354
+ auth = os .environ [".dockerconfigjson" ] # noqa: SIM112
355
+
356
+ elif oci_username and oci_password :
357
+ auth = json .dumps (create_auth_object (oci_ref , oci_username , oci_password ))
357
358
358
359
# If a custom backend is provided, use it, else fetch the backend out of the registry
359
360
if custom_oci_backend :
@@ -374,30 +375,48 @@ def save_to_oci_registry( # noqa: C901 ( complex args >8 )
374
375
dest_dir_cleanup = True
375
376
local_image_path = Path (dest_dir )
376
377
377
- # Set params
378
378
params = {}
379
-
380
- # User/pass
381
- if auth :
382
- usr_pass = auth .split (":" )
383
- params ["username" ] = usr_pass [0 ]
384
- params ["password" ] = usr_pass [- 1 ]
385
- elif oci_username and oci_password :
386
- params ["username" ] = oci_username
387
- params ["password" ] = oci_password
388
-
389
- backend_def .pull (base_image , local_image_path , ** params )
390
- # Extract the absolute path from the files found in the path
391
- files = [file [0 ] for file in _get_files_from_path (model_files_path )]
392
- oci_layers_on_top (local_image_path , files , modelcard )
393
- backend_def .push (local_image_path , oci_ref , ** params )
379
+ with temp_auth_file (auth ) as auth_file :
380
+ if auth_file is not None :
381
+ params ["authfile" ] = auth_file .name
382
+ backend_def .pull (base_image , local_image_path , ** params )
383
+ # Extract the absolute path from the files found in the path
384
+ files = [file [0 ] for file in _get_files_from_path (model_files_path )]
385
+ oci_layers_on_top (local_image_path , files , modelcard )
386
+ backend_def .push (local_image_path , oci_ref , ** params )
394
387
395
388
# Return the OCI URI
396
389
if dest_dir_cleanup :
397
390
shutil .rmtree (dest_dir )
398
391
return f"oci://{ oci_ref } "
399
392
400
393
394
+ @overload
395
+ def temp_auth_file (auth : str ) -> AbstractContextManager [TextIO ]:
396
+ ...
397
+
398
+
399
+ @overload
400
+ def temp_auth_file (auth : None ) -> AbstractContextManager [None ]:
401
+ ...
402
+
403
+
404
+ @contextmanager
405
+ def temp_auth_file (auth : str | None ) -> AbstractContextManager [TextIO | None ]:
406
+ """Create a temporary auth file with optional auth data.
407
+
408
+ If auth is None, yields None. Otherwise creates a temporary JSON file
409
+ containing the auth string and yields the file handle.
410
+ """
411
+ if auth is None :
412
+ yield None
413
+ else :
414
+ with tempfile .NamedTemporaryFile (mode = "w+" , encoding = "utf-8" , suffix = ".json" ) as temp_auth_file :
415
+ temp_auth_file .write (auth )
416
+ temp_auth_file .flush ()
417
+ yield temp_auth_file
418
+
419
+
401
420
def _s3_creds (
402
421
endpoint_url : str | None = None ,
403
422
access_key_id : str | None = None ,
@@ -636,6 +655,66 @@ def _extract_auth_json(auth_data: str) -> str:
636
655
raise ValueError (invalid_json_msg ) from e
637
656
638
657
658
+ def get_auth_reference (image_path : str ) -> str :
659
+ """Parses an arbitrary container image path to extract a valid reference.
660
+
661
+ for use as a key in a container registry auth.json file.
662
+
663
+ Examples:
664
+ 'quay.io/my-org/my-registry:1.0.0' -> 'quay.io/my-org/my-registry'
665
+ 'my-private-registry:5000/team/app:latest' -> 'my-private-registry:5000/team/app'
666
+ 'ubuntu' -> 'docker.io'
667
+ 'ubuntu:22.04' -> 'docker.io'
668
+ 'my-user/my-app' -> 'docker.io'
669
+ 'my-user/my-app:v2' -> 'docker.io'
670
+ 'quay.io/my-org/my-registry@sha256:f1b3f5a2d...' -> 'quay.io/my-org/my-registry'
671
+ 'localhost:5000/my-local-image' -> 'localhost:5000/my-local-image'
672
+ 'localhost:5000/my-local-image:test-tag' -> 'localhost:5000/my-local-image'
673
+ """
674
+ repo_path = image_path
675
+
676
+ # Remove digest if it exists
677
+ if "@" in repo_path :
678
+ repo_path = repo_path .split ("@" , 1 )[0 ]
679
+
680
+ # Separate the tag from the repository path.
681
+ # The tag is what comes after the last colon, but only if that colon
682
+ # is not part of the hostname/port. A colon indicates a tag if it
683
+ # appears after the last slash in the path.
684
+ last_colon = repo_path .rfind (":" )
685
+ last_slash = repo_path .rfind ("/" )
686
+
687
+ if last_colon > last_slash :
688
+ # This is a tag, not a port, so we strip it.
689
+ repo_path = repo_path [:last_colon ]
690
+
691
+ # Handle default Docker Hub images (e.g., 'ubuntu', 'user/repo').
692
+ # The hostname is the part of the path before the first slash.
693
+ first_slash_index = repo_path .find ("/" )
694
+ hostname = repo_path
695
+ if first_slash_index != - 1 :
696
+ hostname = repo_path [:first_slash_index ]
697
+
698
+ # If the hostname part doesn't contain a '.' (like quay.io) or a ':' (like localhost:5000),
699
+ # it's a short name for an image on Docker Hub.
700
+ if "." not in hostname and ":" not in hostname :
701
+ return "docker.io"
702
+
703
+ # For all other images, the full repository path is the reference.
704
+ return repo_path
705
+
706
+
707
+ def create_auth_object (oci_ref : str , username : str , password : str ) -> dict [str : dict [str , dict [str , str ]]]:
708
+ """Create an auth object for container registry authentication.
709
+
710
+ This object can be encoded as json with json.dumps() producing the
711
+ contents for valid authfile.
712
+ """
713
+ auth_ref = get_auth_reference (oci_ref )
714
+ auth_value = base64 .b64encode (f"{ username } :{ password } " .encode ()).decode ("utf-8" )
715
+ return {"auths" : {auth_ref : {"auth" : auth_value }}}
716
+
717
+
639
718
def rand_suffix (size : int = 8 ) -> str :
640
719
"""Generate a random suffix.
641
720
0 commit comments