Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions src/accelerate/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
from typing import Any, Dict, List, Optional, Union

import yaml
from packaging import version

from .logging import get_logger
from .state import PartialState
from .utils import (
LoggerType,
compare_versions,
is_aim_available,
is_clearml_available,
is_comet_ml_available,
Expand Down Expand Up @@ -402,11 +404,16 @@ class CometMLTracker(GeneralTracker):

API keys must be stored in a Comet config file.

Note:
For `comet_ml` versions < 3.41.0, additional keyword arguments are passed to `comet_ml.Experiment` instead:
https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment/#comet_ml.Experiment.__init__

Args:
run_name (`str`):
The name of the experiment run.
**kwargs (additional keyword arguments, *optional*):
Additional key word arguments passed along to the `Experiment.__init__` method.
Additional key word arguments passed along to the `comet_ml.start` method:
https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/start/
"""

name = "comet_ml"
Expand All @@ -417,9 +424,15 @@ def __init__(self, run_name: str, **kwargs):
super().__init__()
self.run_name = run_name

from comet_ml import Experiment
import comet_ml

comet_version = version.parse(comet_ml.__version__)
if compare_versions(comet_version, ">=", "3.41.0"):
self.writer = comet_ml.start(project_name=run_name, **kwargs)
else:
logger.info("Update `comet_ml` (>=3.41.0) for experiment reuse and offline support.")
self.writer = comet_ml.Experiment(project_name=run_name, **kwargs)

self.writer = Experiment(project_name=run_name, **kwargs)
logger.debug(f"Initialized CometML project {self.run_name}")
logger.debug(
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
Expand All @@ -440,7 +453,7 @@ def store_init_configuration(self, values: dict):
`str`, `float`, `int`, or `None`.
"""
self.writer.log_parameters(values)
logger.debug("Stored initial configuration hyperparameters to CometML")
logger.debug("Stored initial configuration hyperparameters to Comet")

@on_main_process
def log(self, values: dict, step: Optional[int] = None, **kwargs):
Expand All @@ -466,15 +479,15 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs):
self.writer.log_other(k, v, **kwargs)
elif isinstance(v, dict):
self.writer.log_metrics(v, step=step, **kwargs)
logger.debug("Successfully logged to CometML")
logger.debug("Successfully logged to Comet")

@on_main_process
def finish(self):
"""
Closes `comet-ml` writer
Flush `comet-ml` writer
"""
self.writer.end()
logger.debug("CometML run closed")
logger.debug("Comet run flushed")


class AimTracker(GeneralTracker):
Expand Down
19 changes: 7 additions & 12 deletions tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@


if is_comet_ml_available():
from comet_ml import OfflineExperiment
from comet_ml import ExperimentConfig

if is_tensorboard_available():
import struct
Expand Down Expand Up @@ -307,16 +307,7 @@ def test_log_artifacts(self):
)


# Comet has a special `OfflineExperiment` we need to use for testing
def offline_init(self, run_name: str, tmpdir: str):
self.run_name = run_name
self.writer = OfflineExperiment(project_name=run_name, offline_directory=tmpdir)
logger.info(f"Initialized offline CometML project {self.run_name}")
logger.info("Make sure to log any initial configurations with `self.store_init_configuration` before training!")


@require_comet_ml
@mock.patch.object(CometMLTracker, "__init__", offline_init)
class CometMLTest(unittest.TestCase):
@staticmethod
def get_value_from_key(log_list, key: str, is_param: bool = False):
Expand All @@ -337,7 +328,9 @@ def get_value_from_key(log_list, key: str, is_param: bool = False):

def test_init_trackers(self):
with tempfile.TemporaryDirectory() as d:
tracker = CometMLTracker("test_project_with_config", d)
tracker = CometMLTracker(
"test_project_with_config", online=False, experiment_config=ExperimentConfig(offline_directory=d)
)
accelerator = Accelerator(log_with=tracker)
config = {"num_iterations": 12, "learning_rate": 1e-2, "some_boolean": False, "some_string": "some_value"}
accelerator.init_trackers(None, config)
Expand All @@ -355,7 +348,9 @@ def test_init_trackers(self):

def test_log(self):
with tempfile.TemporaryDirectory() as d:
tracker = CometMLTracker("test_project_with_config", d)
tracker = CometMLTracker(
"test_project_with_config", online=False, experiment_config=ExperimentConfig(offline_directory=d)
)
accelerator = Accelerator(log_with=tracker)
accelerator.init_trackers(None)
values = {"total_loss": 0.1, "iteration": 1, "my_text": "some_value"}
Expand Down