@@ -103,23 +103,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
103103 if isinstance (trainer , fl .Fabric ):
104104 raise NotImplementedError ("Fabric is not supported yet." )
105105
106- trainer_ckpt_path = self .get_trainer_ckpt_path (model )
107- if trainer_ckpt_path :
108- trainer .ckpt_path = trainer_ckpt_path
109- trainer .checkpoint_callback .last_model_path = trainer_ckpt_path
110- # Load artifacts
111- if getattr (self .restore_config , 'load_artifacts' , False ):
112- if isinstance (trainer_ckpt_path , AdapterPath ):
113- # load tokenizer from the base model during peft resume, in case the first peft checkpoint
114- # is deleted before the current peft checkpoint is saved
115- context_path = trainer_ckpt_path .base_model_path / "context"
116- if not context_path .exists ():
117- context_path = trainer_ckpt_path .base_model_path
118- else :
119- context_path = self .get_context_path (model )
120- model = _try_restore_tokenizer (model , context_path )
121-
122- elif self .restore_config :
106+ if self .restore_config :
123107 new_path = self ._extract_path (
124108 model = model ,
125109 path = self .restore_config .path ,
@@ -139,6 +123,21 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
139123
140124 _try_restore_tokenizer (model , context_path )
141125
126+ elif (trainer_ckpt_path := self .get_trainer_ckpt_path (model )) is not None :
127+ trainer .ckpt_path = trainer_ckpt_path
128+ trainer .checkpoint_callback .last_model_path = trainer_ckpt_path
129+ # Load artifacts
130+ if getattr (self .restore_config , 'load_artifacts' , False ):
131+ if isinstance (trainer_ckpt_path , AdapterPath ):
132+ # load tokenizer from the base model during peft resume, in case the first peft checkpoint
133+ # is deleted before the current peft checkpoint is saved
134+ context_path = trainer_ckpt_path .base_model_path / "context"
135+ if not context_path .exists ():
136+ context_path = trainer_ckpt_path .base_model_path
137+ else :
138+ context_path = self .get_context_path (model )
139+ model = _try_restore_tokenizer (model , context_path )
140+
142141 def _extract_path (
143142 self , model : Optional [io .ConnectorMixin ], path : str , adapter_path : Optional [str ] = None
144143 ) -> BasePath :
0 commit comments