[RLlib] Removes device infos from state when saving RLModules to checkpoints/states. #43906
Conversation
Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
I'm not sure I understand the exact need for this additional option to the API. My main argument against it would be: when stuff gets stored to a checkpoint, it should be stored in a device-independent fashion. So the issue at hand here is NOT the loading from the checkpoint, but the saving to the checkpoint beforehand, which - I'm guessing - probably happened as torch.cuda tensors, NOT in numpy format. Can we rather take the opposite approach, to keep the mental model of what a checkpoint should be clean? Always save weights (and other tensor/matrix states) as numpy arrays, never as torch or tf tensors. When loading from a checkpoint, the sequence should be something like:
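(A minimal sketch of such a sequence, for illustration only. It assumes RLlib's `convert_to_numpy` and `convert_to_torch_tensor` helpers; the `save_module_state`/`load_module_state` names are hypothetical and not the actual `RLModule` API.)

```python
import pathlib

import torch

from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor


def save_module_state(module: torch.nn.Module, path: pathlib.Path) -> None:
    # Strip all device information before anything hits disk: the checkpoint
    # only ever contains numpy arrays.
    torch.save(convert_to_numpy(module.state_dict()), path)


def load_module_state(module: torch.nn.Module, path: pathlib.Path) -> None:
    # 1) Read back the plain numpy state (no CUDA deserialization involved).
    numpy_state = torch.load(path)
    # 2) Determine the device the (already built) module currently lives on.
    device = next(module.parameters()).device
    # 3) Convert the arrays to tensors on that device and load them.
    module.load_state_dict(convert_to_torch_tensor(numpy_state, device=device))
```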
I agree with your argument that we should ensure that checkpointing is device-independent. This should be the cleanest way of doing it. We should investigate where exactly this device-dependent checkpointing takes place and fix the problem there. I am, however, not so sure if 3. describes how the workflow runs right now. Here is what makes me wonder: if the
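(For illustration, and not code from this PR: a minimal repro of the kind of device-dependence being discussed. It assumes the checkpoint file is produced on a GPU machine and then loaded on a CPU-only machine.)

```python
import torch

# On a GPU machine: saving a state dict of CUDA tensors bakes the device
# into the pickled checkpoint.
model = torch.nn.Linear(4, 2).to("cuda")
torch.save(model.state_dict(), "module_state.pt")

# On a CPU-only machine: this raises a RuntimeError asking for `map_location`,
# because the stored tensors remember the CUDA device they were saved from.
state = torch.load("module_state.pt")

# Storing numpy arrays instead (as this PR does) removes the device entirely,
# so the same file loads anywhere without a `map_location` argument.
```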
Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
…store state dict now in numpy format, which makes it device-agnostic. Signed-off-by: Simon Zehnder <simon.zehnder@gmail.com>
rllib/core/rl_module/rl_module.py (Outdated)

    Args:
        checkpoint_dir_path: The directory to load the checkpoint from.
        map_location: The device on which the module resides.
rllib/core/rl_module/marl_module.py (Outdated)

    modules_to_load: The modules whose state is to be loaded from the path. If
        this is None, all modules that are checkpointed will be loaded into this
        marl module.
    map_location: The device the module resides on.
 def save_state(self, dir: Union[str, pathlib.Path]) -> None:
     path = str(pathlib.Path(dir) / self._module_state_file_name())
-    torch.save(self.state_dict(), path)
+    torch.save(convert_to_numpy(self.state_dict()), path)
Perfect! This should work.
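(For completeness, a sketch of what the matching load side could look like. This is illustrative only and not necessarily the exact implementation in this PR: the numpy state deserializes on any machine and is only moved to a concrete device once it is converted back to torch tensors.)

```python
def load_state(self, dir: Union[str, pathlib.Path]) -> None:
    path = str(pathlib.Path(dir) / self._module_state_file_name())
    # Plain numpy arrays: no CUDA context is needed to deserialize them.
    numpy_state = torch.load(path)
    # Re-materialize on whatever device this module currently uses.
    device = next(self.parameters()).device
    self.load_state_dict(convert_to_torch_tensor(numpy_state, device=device))
```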
sven1977 left a comment
Looks good! Thanks for this important fix @simonsays1980 !
Just two nits on the docstrings.
    def load_state(
        self,
        dir: Union[str, pathlib.Path],
    ) -> None:
Suggested change (collapse the signature to one line):

    def load_state(self, dir: Union[str, pathlib.Path]) -> None:
Signed-off-by: Sven Mika <sven@anyscale.io>
Why are these changes needed?
When loading an `RLModule` on CPU from a checkpoint/state that was created from a replica on GPU, an error occurs. This PR fixes this error by forcing the module to save its state in the form of NumPy `ndarray`s.
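(A quick, illustrative way to verify that a stored state is device-agnostic; the file name and the flat dict layout are assumptions, not the exact checkpoint format.)

```python
import numpy as np
import torch

# Every leaf of the stored state should be a numpy array, never a torch.Tensor.
state = torch.load("module_state.pt")
assert all(isinstance(v, np.ndarray) for v in state.values()), (
    "Checkpoint still contains device-bound torch tensors."
)
```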
Related issue number
Closes #43905
Checks
- I've signed off every commit (using `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.