Commit f0fafa2

borisdayma, awaelchli, rohitgr7, and SeanNaren authored
feat(wandb): add sync_step (#5351)
* docs(wandb): add details to args
* feat(wandb): no sync between trainer and W&B steps
* style: pep8
* tests(wandb): test sync_step
* docs(wandb): add references
* docs(wandb): fix typo
* feat(wandb): more explicit warning
* feat(wandb): order of args
* style: Apply suggestions from code review
* style: long line

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
1 parent 0c9960b commit f0fafa2
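
In short, `sync_step=True` (the default) keeps W&B's internal step aligned with the Trainer's global step, while `sync_step=False` lets W&B keep its own step counter and records the Trainer step as an ordinary `trainer_step` metric. A minimal usage sketch of the new argument (the project name is a placeholder; `offline=True` just avoids needing W&B credentials):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Decouple wandb's internal step from the Trainer's global step; the
# Trainer step is then logged as a regular 'trainer_step' metric.
wandb_logger = WandbLogger(project='my-project', offline=True, sync_step=False)
trainer = Trainer(logger=wandb_logger)
```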

File tree

2 files changed: +38 −18 lines

pytorch_lightning/loggers/wandb.py
tests/loggers/test_wandb.py

pytorch_lightning/loggers/wandb.py

Lines changed: 26 additions & 18 deletions
@@ -49,19 +49,20 @@ class WandbLogger(LightningLoggerBase):
 
     Args:
         name: Display name for the run.
-        save_dir: Path where data is saved.
+        save_dir: Path where data is saved (wandb dir by default).
         offline: Run offline (data can be streamed later to wandb servers).
         id: Sets the version, mainly used to resume a previous run.
+        version: Same as id.
         anonymous: Enables or explicitly disables anonymous logging.
-        version: Sets the version, mainly used to resume a previous run.
         project: The name of the project to which this run will belong.
         log_model: Save checkpoints in wandb dir to upload on W&B servers.
-        experiment: WandB experiment object.
         prefix: A string to put at the beginning of metric keys.
+        sync_step: Sync Trainer step with wandb step.
+        experiment: WandB experiment object. Automatically set when creating a run.
         \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
             :func:`wandb.init` can be passed as keyword arguments in this logger.
 
-    Example::
+    Example:
 
         .. code-block:: python
 
@@ -74,9 +75,9 @@ class WandbLogger(LightningLoggerBase):
     make sure to use `commit=False` so the logging step does not increase.
 
     See Also:
-        - `Tutorial <https://app.wandb.ai/cayush/pytorchlightning/reports/
-          Use-Pytorch-Lightning-with-Weights-%26-Biases--Vmlldzo2NjQ1Mw>`__
-          on how to use W&B with Pytorch Lightning.
+        - `Tutorial <https://colab.research.google.com/drive/16d1uctGaw2y9KhGBlINNTsWpmlXdJwRW?usp=sharing>`__
+          on how to use W&B with PyTorch Lightning
+        - `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__
 
     """
 
@@ -86,14 +87,15 @@ def __init__(
         self,
         name: Optional[str] = None,
         save_dir: Optional[str] = None,
-        offline: bool = False,
+        offline: Optional[bool] = False,
         id: Optional[str] = None,
-        anonymous: bool = False,
+        anonymous: Optional[bool] = False,
         version: Optional[str] = None,
         project: Optional[str] = None,
-        log_model: bool = False,
+        log_model: Optional[bool] = False,
         experiment=None,
-        prefix: str = '',
+        prefix: Optional[str] = '',
+        sync_step: Optional[bool] = True,
         **kwargs
     ):
         if wandb is None:
@@ -102,13 +104,14 @@ def __init__(
         super().__init__()
         self._name = name
         self._save_dir = save_dir
-        self._anonymous = 'allow' if anonymous else None
+        self._offline = offline
         self._id = version or id
+        self._anonymous = 'allow' if anonymous else None
         self._project = project
-        self._experiment = experiment
-        self._offline = offline
         self._log_model = log_model
         self._prefix = prefix
+        self._sync_step = sync_step
+        self._experiment = experiment
         self._kwargs = kwargs
         # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
         self._step_offset = 0
@@ -164,11 +167,16 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
 
         metrics = self._add_prefix(metrics)
-        if step is not None and step + self._step_offset < self.experiment.step:
+        if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
             self.warning_cache.warn(
-                'Trying to log at a previous step. Use `commit=False` when logging metrics manually.'
-            )
-        self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
+                'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
+                ' or try logging with `commit=False` when calling manually `wandb.log`.')
+        if self._sync_step:
+            self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
+        elif step is not None:
+            self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
+        else:
+            self.experiment.log(metrics)
 
     @property
     def save_dir(self) -> Optional[str]:
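
The reworked `log_metrics` above is the heart of the change: with `sync_step` enabled it forwards the offset-corrected Trainer step to `wandb.log` as before; otherwise it folds that step into the metrics dict (or omits it entirely when no step is given). A short sketch of what each branch emits, with placeholder project name and metric values:

```python
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project='my-project', offline=True, sync_step=False)

# sync_step=False with a step: equivalent to
#   wandb.log({'acc': 0.9, 'trainer_step': 3})
logger.log_metrics({'acc': 0.9}, step=3)

# sync_step=False without a step: equivalent to wandb.log({'acc': 0.9})
logger.log_metrics({'acc': 0.9})

# With the default sync_step=True, the first call would instead be
# wandb.log({'acc': 0.9}, step=3), tying wandb's step to the Trainer's.
```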

tests/loggers/test_wandb.py

Lines changed: 12 additions & 0 deletions
@@ -40,6 +40,18 @@ def test_wandb_logger_init(wandb, recwarn):
     wandb.init.assert_called_once()
     wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)
 
+    # test sync_step functionality
+    wandb.init().log.reset_mock()
+    wandb.init.reset_mock()
+    wandb.run = None
+    wandb.init().step = 0
+    logger = WandbLogger(sync_step=False)
+    logger.log_metrics({'acc': 1.0})
+    wandb.init().log.assert_called_once_with({'acc': 1.0})
+    wandb.init().log.reset_mock()
+    logger.log_metrics({'acc': 1.0}, step=3)
+    wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3})
+
     # mock wandb step
     wandb.init().step = 0
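
These assertions pass because the test receives `wandb` as a `MagicMock` patched over the real module, so `wandb.init()` always returns the same child mock and `wandb.init().log` records calls instead of reaching a server. A standalone sketch of that pattern, assuming the `pytorch_lightning.loggers.wandb.wandb` patch target used by this test file:

```python
from unittest import mock

from pytorch_lightning.loggers import WandbLogger


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_sync_step_disabled(wandb):
    # `wandb` is a MagicMock: no real run is created.
    wandb.run = None
    wandb.init().step = 0
    logger = WandbLogger(sync_step=False)
    logger.log_metrics({'acc': 1.0}, step=3)
    # The Trainer step must appear as a plain 'trainer_step' metric.
    wandb.init().log.assert_called_with({'acc': 1.0, 'trainer_step': 3})
```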
