 from MaxText import pyconfig
 from MaxText import tokenizer
 from MaxText import train_utils
-from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter

 # Tunix Imports
 from tunix.distillation import distillation_trainer
@@ -337,21 +336,18 @@ def __next__(self) -> MaxTextTrainingInput:
 # Model Loading
 # -----------------------------------------------------------------------------
 def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh) -> nnx.Module:
-  """Loads a MaxText model and wraps it in a Tunix adapter.
+  """Loads a MaxText model.

   Args:
     config: The configuration object for this specific model (Student or Teacher).
     mesh: The global device mesh for sharding weights.

   Returns:
-    A TunixMaxTextAdapter instance wrapping the loaded MaxText model.
+    The loaded MaxText model.
   """
   max_logging.log(f"Initializing model: {config.model_name}...")
   model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
-
-  with mesh:
-    tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=True)
-    return tunix_model
+  return model


 # -----------------------------------------------------------------------------
@@ -408,22 +404,28 @@ def labels_fn(targets, **kwargs):
     mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]
     return one_hot * mask

-  def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs):
-    """Forward pass wrapper for the MaxText models (Student and Teacher)."""
-    del kwargs  # Unused
-    # Tunix adapter ensures __call__ signature matches this
-    outputs = model(
-        input_tokens=input_tokens,
-        positions=positions,
-        cache=cache,
-        attention_mask=attention_mask,
-        decoder_segment_ids=decoder_segment_ids,  # Support sequence packing
-    )
-    return outputs[0]  # Return logits only
+  def create_forward_fn(config):
+    """Creates a forward function closure that binds the specific model configuration."""
+
+    def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs):
+      """Forward pass wrapper adapted for raw MaxText models."""
+      del kwargs  # Unused
+      del attention_mask  # Unused
+      del cache  # Unused
+
+      logits = model(
+          decoder_input_tokens=input_tokens,
+          decoder_positions=positions,
+          decoder_segment_ids=decoder_segment_ids,
+          enable_dropout=config.enable_dropout,
+      )
+      return logits
+
+    return model_forward_fn

-  # Both Student and Teacher use the same forward logic via the adapter
-  student_forward_fn = model_forward_fn
-  teacher_forward_fn = model_forward_fn
+  # Create forward functions for both Student and Teacher
+  student_forward_fn = create_forward_fn(student_config)
+  teacher_forward_fn = create_forward_fn(teacher_config)

   # Use Monitored strategy to enable KL/Soft/Hard Loss logging
   strategy = MonitoredLogitStrategy(
@@ -438,7 +440,10 @@ def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_seg
   optimizer = get_distillation_optimizer(student_config, student_config.steps)

   checkpointing_options = checkpoint.CheckpointManagerOptions(
-      save_interval_steps=student_config.checkpoint_period, max_to_keep=student_config.max_num_checkpoints_to_keep
+      save_interval_steps=student_config.checkpoint_period,
+      max_to_keep=student_config.max_num_checkpoints_to_keep,
+      enable_async_checkpointing=student_config.async_checkpointing,
+      create=True,
   )

   profiler_options = None
@@ -477,7 +482,7 @@ def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_seg
   trainer._has_aux = True  # pylint: disable=protected-access

   # 6. Configure Input Mapping
-  # Maps the attributes of MaxTextTrainingInput to the kwargs expected by the models
+  # Maps the attributes of MaxTextTrainingInput to the kwargs expected by model_forward_fn
   trainer = trainer.with_gen_model_input_fn(
       lambda batch: {
           "input_tokens": batch.input_tokens,
@@ -560,13 +565,12 @@ def main(argv: Sequence[str]) -> None:
   # We isolate the Teacher from Student CLI arguments (like pruning params).
   teacher_overrides = global_config.teacher_overrides

-  # Ensure load_parameters_path is set (check overrides, then env var)
+  # Ensure load_parameters_path is set in overrides
   if not teacher_overrides.get("load_parameters_path"):
-    ckpt_path = os.environ.get("TEACHER_CHECKPOINT_PATH")
-    if ckpt_path:
-      teacher_overrides["load_parameters_path"] = ckpt_path
-    else:
-      max_logging.log("Warning: No load_parameters_path found for Teacher.")
+    raise ValueError(
+        "Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
+        "in your config or arguments."
+    )

   # Construct sanitized argv: [script_name, config_file]
   # This ensures flags like `num_query_heads=16` passed in CLI don't affect the Teacher.
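
A minimal sketch of the closure pattern introduced by `create_forward_fn` in this diff: each per-model config is bound once, so the trainer only needs to pass the model and batch tensors. `FakeConfig` and `FakeModel` below are hypothetical stand-ins for illustration, not MaxText or Tunix APIs.

```python
# Illustrative sketch only: FakeConfig/FakeModel are stand-ins, not MaxText/Tunix classes.
from dataclasses import dataclass


@dataclass
class FakeConfig:
  enable_dropout: bool


class FakeModel:
  def __call__(self, decoder_input_tokens, decoder_positions, decoder_segment_ids, enable_dropout):
    # A real MaxText model would return logits; echo the tokens for illustration.
    return decoder_input_tokens


def create_forward_fn(config):
  """Binds one model's config so the trainer can call the returned function with model + batch."""

  def model_forward_fn(model, input_tokens, positions, attention_mask=None, decoder_segment_ids=None, cache=None, **kwargs):
    del attention_mask, cache, kwargs  # not used by the raw MaxText call signature
    return model(
        decoder_input_tokens=input_tokens,
        decoder_positions=positions,
        decoder_segment_ids=decoder_segment_ids,
        enable_dropout=config.enable_dropout,
    )

  return model_forward_fn


# Student and Teacher each get their own bound forward function.
student_forward_fn = create_forward_fn(FakeConfig(enable_dropout=False))
print(student_forward_fn(FakeModel(), input_tokens=[1, 2, 3], positions=[0, 1, 2]))
```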