Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/accelerate/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ class CometMLTracker(GeneralTracker):
run_name (`str`):
The name of the experiment run.
**kwargs (additional keyword arguments, *optional*):
Additional key word arguments passed along to the `Experiment.__init__` method.
Additional key word arguments passed along to the `comet_ml.start` method.
"""

name = "comet_ml"
Expand All @@ -417,9 +417,9 @@ def __init__(self, run_name: str, **kwargs):
super().__init__()
self.run_name = run_name

from comet_ml import Experiment
from comet_ml import start
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way we should do this is via an import check. Do we know if it's 3.47.6 that introduced this? Then based on the version either use Experiment/end or start/flush

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We released the last fixes to comet_ml.start in the version 3.41.0 back in April 2024 (https://pypi.org/project/comet-ml/3.41.0/). I have considered using Experiment/start based on the package version and while the new API is 99% compatible, there is new accepted by start (like the experiment name or tags). That could make documenting the callback slightly more complicated.
I would personally prefer to set the minimum version of comet to 3.41.0 as it would make it clearer for end-user what parameters they can pass. But it's only a small preference, I'm also fine with having a different behavior based on the comet version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

< 1 year is a bit too soon to just drop IMO. Let's please do both. Especially as 50% of your pypi downloads still are < 3.4.1


self.writer = Experiment(project_name=run_name, **kwargs)
self.writer = start(project_name=run_name, **kwargs)
logger.debug(f"Initialized CometML project {self.run_name}")
logger.debug(
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
Expand All @@ -440,7 +440,7 @@ def store_init_configuration(self, values: dict):
`str`, `float`, `int`, or `None`.
"""
self.writer.log_parameters(values)
logger.debug("Stored initial configuration hyperparameters to CometML")
logger.debug("Stored initial configuration hyperparameters to Comet")

@on_main_process
def log(self, values: dict, step: Optional[int] = None, **kwargs):
Expand All @@ -466,15 +466,15 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs):
self.writer.log_other(k, v, **kwargs)
elif isinstance(v, dict):
self.writer.log_metrics(v, step=step, **kwargs)
logger.debug("Successfully logged to CometML")
logger.debug("Successfully logged to Comet")

@on_main_process
def finish(self):
"""
Closes `comet-ml` writer
Flush `comet-ml` writer
"""
self.writer.end()
logger.debug("CometML run closed")
self.writer.flush()
logger.debug("Comet run flushed")


class AimTracker(GeneralTracker):
Expand Down
19 changes: 7 additions & 12 deletions tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@


if is_comet_ml_available():
from comet_ml import OfflineExperiment
from comet_ml import ExperimentConfig

if is_tensorboard_available():
import struct
Expand Down Expand Up @@ -201,16 +201,7 @@ def test_wandb(self):
assert logged_items["_step"] == "0"


# Comet has a special `OfflineExperiment` we need to use for testing
def offline_init(self, run_name: str, tmpdir: str):
self.run_name = run_name
self.writer = OfflineExperiment(project_name=run_name, offline_directory=tmpdir)
logger.info(f"Initialized offline CometML project {self.run_name}")
logger.info("Make sure to log any initial configurations with `self.store_init_configuration` before training!")


@require_comet_ml
@mock.patch.object(CometMLTracker, "__init__", offline_init)
class CometMLTest(unittest.TestCase):
@staticmethod
def get_value_from_key(log_list, key: str, is_param: bool = False):
Expand All @@ -231,7 +222,9 @@ def get_value_from_key(log_list, key: str, is_param: bool = False):

def test_init_trackers(self):
with tempfile.TemporaryDirectory() as d:
tracker = CometMLTracker("test_project_with_config", d)
tracker = CometMLTracker(
"test_project_with_config", online=False, experiment_config=ExperimentConfig(offline_directory=d)
)
accelerator = Accelerator(log_with=tracker)
config = {"num_iterations": 12, "learning_rate": 1e-2, "some_boolean": False, "some_string": "some_value"}
accelerator.init_trackers(None, config)
Expand All @@ -249,7 +242,9 @@ def test_init_trackers(self):

def test_log(self):
with tempfile.TemporaryDirectory() as d:
tracker = CometMLTracker("test_project_with_config", d)
tracker = CometMLTracker(
"test_project_with_config", online=False, experiment_config=ExperimentConfig(offline_directory=d)
)
accelerator = Accelerator(log_with=tracker)
accelerator.init_trackers(None)
values = {"total_loss": 0.1, "iteration": 1, "my_text": "some_value"}
Expand Down