Skip to content

Commit 09616b9

Browse files
authored
Add typing to speech_to_text_finetune.py (#15326)
* Fix typing Signed-off-by: Alexandre Caulier <alexandre.caulier.a@gmail.com> * Add typing Signed-off-by: Alexandre Caulier <alexandre.caulier.a@gmail.com> * Add types following copilot review Signed-off-by: Alexandre Caulier <alexandre.caulier.a@gmail.com> * isort Signed-off-by: Alexandre Caulier <alexandre.caulier.a@gmail.com> --------- Signed-off-by: Alexandre Caulier <alexandre.caulier.a@gmail.com>
1 parent 1ef5781 commit 09616b9

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

examples/asr/asr_adapters/train_asr_adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
"""
8484
import os
8585
from dataclasses import is_dataclass
86+
from typing import Union
8687

8788
import lightning.pytorch as pl
8889
from omegaconf import DictConfig, OmegaConf, open_dict
@@ -126,7 +127,7 @@ def update_model_cfg(original_cfg, new_cfg):
126127
return new_cfg
127128

128129

129-
def add_global_adapter_cfg(model, global_adapter_cfg):
130+
def add_global_adapter_cfg(model: ASRModel, global_adapter_cfg: Union[DictConfig, dict]):
130131
# Convert to DictConfig from dict or Dataclass
131132
if is_dataclass(global_adapter_cfg):
132133
global_adapter_cfg = OmegaConf.structured(global_adapter_cfg)

examples/asr/speech_to_text_finetune.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,10 @@
5454
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
5555
"""
5656
import time
57+
from typing import Union
58+
5759
import lightning.pytorch as pl
58-
from omegaconf import OmegaConf
60+
from omegaconf import DictConfig, OmegaConf
5961

6062
from nemo.collections.asr.models import ASRModel
6163
from nemo.core.config import hydra_runner
@@ -65,7 +67,7 @@
6567
from nemo.utils.trainer_utils import resolve_trainer_cfg
6668

6769

68-
def get_base_model(trainer, cfg):
70+
def get_base_model(trainer: pl.Trainer, cfg: DictConfig) -> ASRModel:
6971
"""
7072
Returns the base model to be fine-tuned.
7173
Currently supports two types of initializations:
@@ -112,7 +114,7 @@ def get_base_model(trainer, cfg):
112114
return asr_model
113115

114116

115-
def check_vocabulary(asr_model, cfg):
117+
def check_vocabulary(asr_model: ASRModel, cfg: DictConfig) -> ASRModel:
116118
"""
117119
Checks if the decoder and vocabulary of the model needs to be updated.
118120
If either of them needs to be updated, it updates them and returns the updated model.
@@ -139,7 +141,7 @@ def check_vocabulary(asr_model, cfg):
139141
return asr_model
140142

141143

142-
def update_tokenizer(asr_model, tokenizer_dir, tokenizer_type):
144+
def update_tokenizer(asr_model: ASRModel, tokenizer_dir: Union[str, DictConfig], tokenizer_type: str) -> ASRModel:
143145
"""
144146
Updates the tokenizer of the model and also reinitializes the decoder if the vocabulary size
145147
of the new tokenizer differs from that of the loaded model.
@@ -173,7 +175,7 @@ def update_tokenizer(asr_model, tokenizer_dir, tokenizer_type):
173175
return asr_model
174176

175177

176-
def setup_dataloaders(asr_model, cfg):
178+
def setup_dataloaders(asr_model: ASRModel, cfg: DictConfig) -> ASRModel:
177179
"""
178180
Sets up the training, validation and test dataloaders for the model.
179181
Args:

0 commit comments

Comments
 (0)