Skip to content

Commit 7b564ad

Browse files
committed
Bug fixes (#1596)
1 parent cae3f81 commit 7b564ad

File tree

3 files changed

+69
-18
lines changed

3 files changed

+69
-18
lines changed

src/sagemaker/modules/local_core/local_container.py

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
import base64
1717
import os
1818
import re
19+
import shutil
1920
import subprocess
21+
from tempfile import TemporaryDirectory
2022
from typing import Any, Dict, List, Optional
2123
from pydantic import BaseModel, ConfigDict
2224

23-
from sagemaker.local.data import LocalFileDataSource, S3DataSource
2425
from sagemaker.local.image import (
2526
_Volume,
2627
_aws_credentials,
@@ -34,7 +35,7 @@
3435
from sagemaker.modules import logger
3536
from sagemaker.modules.configs import Channel
3637
from sagemaker.session import Session
37-
from sagemaker.utils import ECR_URI_PATTERN, create_tar_file, _module_import_error
38+
from sagemaker.utils import ECR_URI_PATTERN, create_tar_file, _module_import_error, download_folder
3839
from sagemaker_core.main.utils import Unassigned
3940
from sagemaker_core.shapes import DataSource
4041

@@ -420,8 +421,9 @@ def _generate_compose_command(self, wait: bool):
420421
Args:
421422
wait (bool): Whether to wait for the docker command result.
422423
"""
423-
command = [
424-
"docker-compose",
424+
_compose_cmd_prefix = self._get_compose_cmd_prefix()
425+
426+
command = _compose_cmd_prefix + [
425427
"-f",
426428
os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME),
427429
"up",
@@ -502,8 +504,8 @@ def _prepare_training_volumes(
502504
channel_dir = os.path.join(data_dir, channel_name)
503505
os.makedirs(channel_dir, exist_ok=True)
504506

505-
data_source_instance = self._get_data_source_instance(channel.data_source)
506-
volumes.append(_Volume(data_source_instance.get_root_dir(), channel=channel_name).map)
507+
data_source_local_path = self._get_data_source_local_path(channel.data_source)
508+
volumes.append(_Volume(data_source_local_path, channel=channel_name).map)
507509

508510
# If there is a training script directory and it is a local directory,
509511
# mount it to the container.
@@ -518,23 +520,68 @@ def _prepare_training_volumes(
518520

519521
return volumes
520522

521-
def _get_data_source_instance(self, data_source: DataSource):
522-
"""Return an Instance of :class:`sagemaker.local.data.DataSource`.
523-
524-
The instance can handle the provided data_source URI.
523+
def _get_data_source_local_path(self, data_source: DataSource) -> str:
    """Return a local data path for the given data source.

    If the data source is from S3, the data is downloaded into a newly
    created, uniquely named directory under ``self.container_root`` and the
    path of that directory is returned. This method does not delete the
    directory afterwards.
    If the data source is a local file, its absolute path is returned.

    Args:
        data_source (DataSource): a data source of local file or s3

    Returns:
        str: The local path of the data.
    """
    # Imported locally so this method does not depend on a module-level
    # import of ``mkdtemp`` being present.
    from tempfile import mkdtemp

    if data_source.s3_data_source != Unassigned():
        uri = data_source.s3_data_source.s3_uri
        parsed_uri = urlparse(uri)
        # ``TemporaryDirectory(...).name`` would drop the only reference to the
        # TemporaryDirectory object immediately, letting its finalizer delete
        # the directory before anything is downloaded into it. ``mkdtemp``
        # creates a directory that persists until it is removed explicitly,
        # which is what a path that is later mounted as a volume needs.
        local_dir = mkdtemp(dir=self.container_root)
        download_folder(parsed_uri.netloc, parsed_uri.path, local_dir, self.sagemaker_session)
        return local_dir

    return os.path.abspath(data_source.file_system_data_source.directory_path)
544+
545+
def _get_compose_cmd_prefix(self) -> List[str]:
    """Gets the Docker Compose command.

    The method initially looks for 'docker compose' v2
    executable, if not found looks for 'docker-compose' executable.

    Returns:
        List[str]: Docker Compose executable split into list.

    Raises:
        ImportError: If Docker Compose executable was not found.
    """
    output = None
    try:
        output = subprocess.check_output(
            ["docker", "compose", "version"],
            stderr=subprocess.DEVNULL,
            encoding="UTF-8",
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: the ``docker`` CLI exists but the ``compose``
        # subcommand failed (e.g. plugin not installed).
        # OSError (incl. FileNotFoundError): the ``docker`` executable itself
        # could not be started. Either way, fall through to the standalone
        # ``docker-compose`` check instead of crashing here.
        logger.info(
            "'Docker Compose' is not installed. "
            "Proceeding to check for 'docker-compose' CLI."
        )

    # Only accept the plugin when it reports a v2 version string.
    if output and "v2" in output.strip():
        logger.info("'Docker Compose' found using Docker CLI.")
        return ["docker", "compose"]

    if shutil.which("docker-compose") is not None:
        logger.info("'Docker Compose' found using Docker Compose CLI.")
        return ["docker-compose"]

    raise ImportError(
        "Docker Compose is not installed. "
        "Local Mode features will not work without docker compose. "
        "For more information on how to install 'docker compose', please, see "
        "https://docs.docker.com/compose/install/"
    )

src/sagemaker/modules/testing_notebooks/local_model_trainer.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
"model_trainer = ModelTrainer(\n",
5353
" training_image=hugging_face_image,\n",
5454
" source_code=source_code,\n",
55-
" training_input_mode=Mode.LOCAL_CONTAINER,\n",
55+
" training_mode=Mode.LOCAL_CONTAINER,\n",
5656
")"
5757
]
5858
},

src/sagemaker/modules/train/model_trainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,12 @@ def train(
385385
container_entrypoint = None
386386
container_arguments = None
387387
if self.source_code:
388-
389-
drivers_dir = TemporaryDirectory()
388+
if self.training_mode == Mode.LOCAL_CONTAINER:
389+
drivers_dir = TemporaryDirectory(
390+
prefix=os.path.join(self.local_container_root + "/")
391+
)
392+
else:
393+
drivers_dir = TemporaryDirectory()
390394
shutil.copytree(SM_DRIVERS_LOCAL_PATH, drivers_dir.name, dirs_exist_ok=True)
391395

392396
# If source code is provided, create a channel for the source code

0 commit comments

Comments
 (0)