@@ -75,8 +75,8 @@ def __init__(
7575 self .engine_config = engine_config
7676 self .optimizer_config = optimizer_config
7777 self .checkpoint_config = checkpoint_config
78- assert self . engine_config . data_parallel_mode == "fsdp2" , " VeOmniEngine only supports fsdp2."
79-
78+ # VeOmniEngine only supports fsdp2.
79+ self . data_parallel_mode = "fsdp2"
8080 self .rank = dist .get_rank ()
8181
8282 parallel_state .init_parallel_state (
@@ -88,7 +88,7 @@ def __init__(
8888 pp_size = self .engine_config .pipeline_parallel_size ,
8989 cp_size = self .engine_config .context_parallel_size ,
9090 ulysses_size = self .engine_config .ulysses_parallel_size ,
91- dp_mode = self .engine_config . data_parallel_mode ,
91+ dp_mode = self .data_parallel_mode ,
9292 )
9393
9494 if self .engine_config .full_determinism :
@@ -155,7 +155,7 @@ def _build_optimizer(self, module):
155155 )
156156 get_optimizer_pre_hook = getattr (module , "get_optimizer_pre_hook" , None )
157157 if get_optimizer_pre_hook is not None :
158- optimizer_pre_hook = get_optimizer_pre_hook (module , module .config , self .engine_config . data_parallel_mode )
158+ optimizer_pre_hook = get_optimizer_pre_hook (module , module .config , self .data_parallel_mode )
159159 optimizer .register_step_pre_hook (optimizer_pre_hook )
160160
161161 return optimizer
0 commit comments