Commit c1d161d

Refactor model loading in distillation to remove Tunix adapter.
1 parent 724b115 commit c1d161d

File tree

1 file changed: +34 -30 lines changed

src/MaxText/distillation/train_distill.py

Lines changed: 34 additions & 30 deletions
@@ -54,7 +54,6 @@
 from MaxText import pyconfig
 from MaxText import tokenizer
 from MaxText import train_utils
-from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 
 # Tunix Imports
 from tunix.distillation import distillation_trainer
@@ -337,21 +336,18 @@ def __next__(self) -> MaxTextTrainingInput:
 # Model Loading
 # -----------------------------------------------------------------------------
 def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh) -> nnx.Module:
-  """Loads a MaxText model and wraps it in a Tunix adapter.
+  """Loads a MaxText model.
 
   Args:
     config: The configuration object for this specific model (Student or Teacher).
     mesh: The global device mesh for sharding weights.
 
   Returns:
-    A TunixMaxTextAdapter instance wrapping the loaded MaxText model.
+    The loaded MaxText model.
   """
   max_logging.log(f"Initializing model: {config.model_name}...")
   model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
-
-  with mesh:
-    tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=True)
-  return tunix_model
+  return model
 
 
 # -----------------------------------------------------------------------------
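
Note: with the Tunix adapter removed, get_maxtext_model now returns the raw nnx.Module directly. A minimal sketch of the resulting call pattern, assuming student_config/teacher_config are pyconfig.HyperParameters instances and mesh is the global jax.sharding.Mesh, as elsewhere in this file:

# Sketch only, not part of this diff.
student_model = get_maxtext_model(student_config, mesh)  # raw nnx.Module, no TunixMaxTextAdapter wrapper
teacher_model = get_maxtext_model(teacher_config, mesh)
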
@@ -408,22 +404,28 @@ def labels_fn(targets, **kwargs):
     mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]
     return one_hot * mask
 
-  def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs):
-    """Forward pass wrapper for the MaxText models (Student and Teacher)."""
-    del kwargs  # Unused
-    # Tunix adapter ensures __call__ signature matches this
-    outputs = model(
-        input_tokens=input_tokens,
-        positions=positions,
-        cache=cache,
-        attention_mask=attention_mask,
-        decoder_segment_ids=decoder_segment_ids,  # Support sequence packing
-    )
-    return outputs[0]  # Return logits only
+  def create_forward_fn(config):
+    """Creates a forward function closure that binds the specific model configuration."""
+
+    def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_segment_ids=None, cache=None, **kwargs):
+      """Forward pass wrapper adapted for raw MaxText models."""
+      del kwargs  # Unused
+      del attention_mask  # Unused
+      del cache  # Unused
+
+      logits = model(
+          decoder_input_tokens=input_tokens,
+          decoder_positions=positions,
+          decoder_segment_ids=decoder_segment_ids,
+          enable_dropout=config.enable_dropout,
+      )
+      return logits
+
+    return model_forward_fn
 
-  # Both Student and Teacher use the same forward logic via the adapter
-  student_forward_fn = model_forward_fn
-  teacher_forward_fn = model_forward_fn
+  # Create forward functions for both Student and Teacher
+  student_forward_fn = create_forward_fn(student_config)
+  teacher_forward_fn = create_forward_fn(teacher_config)
 
   # Use Monitored strategy to enable KL/Soft/Hard Loss logging
   strategy = MonitoredLogitStrategy(
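
Note: the closure exists so each model binds its own configuration (for example enable_dropout) rather than sharing a single global forward function. A minimal usage sketch; the batch attribute names other than input_tokens are assumptions for illustration:

# Sketch only, not part of this diff. teacher_config.enable_dropout would typically be False.
teacher_forward = create_forward_fn(teacher_config)
teacher_logits = teacher_forward(
    teacher_model,
    input_tokens=batch.input_tokens,
    positions=batch.positions,                       # assumed attribute name
    attention_mask=None,                             # accepted but ignored by the raw MaxText call
    decoder_segment_ids=batch.decoder_segment_ids,   # assumed attribute name
)
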
@@ -438,7 +440,10 @@ def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_seg
   optimizer = get_distillation_optimizer(student_config, student_config.steps)
 
   checkpointing_options = checkpoint.CheckpointManagerOptions(
-      save_interval_steps=student_config.checkpoint_period, max_to_keep=student_config.max_num_checkpoints_to_keep
+      save_interval_steps=student_config.checkpoint_period,
+      max_to_keep=student_config.max_num_checkpoints_to_keep,
+      enable_async_checkpointing=student_config.async_checkpointing,
+      create=True,
   )
 
   profiler_options = None
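
Note: enable_async_checkpointing and create are standard Orbax CheckpointManagerOptions fields. A standalone sketch, assuming `checkpoint` in this file refers to the Orbax checkpoint package; the literal values stand in for the student_config lookups above:

import orbax.checkpoint as ocp

# Sketch only, not part of this diff.
options = ocp.CheckpointManagerOptions(
    save_interval_steps=1000,         # student_config.checkpoint_period
    max_to_keep=5,                    # student_config.max_num_checkpoints_to_keep
    enable_async_checkpointing=True,  # student_config.async_checkpointing: save in a background thread
    create=True,                      # create the top-level checkpoint directory if it does not exist
)
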
@@ -477,7 +482,7 @@ def model_forward_fn(model, input_tokens, positions, attention_mask, decoder_seg
   trainer._has_aux = True  # pylint: disable=protected-access
 
   # 6. Configure Input Mapping
-  # Maps the attributes of MaxTextTrainingInput to the kwargs expected by the models
+  # Maps the attributes of MaxTextTrainingInput to the kwargs expected by model_forward_fn
   trainer = trainer.with_gen_model_input_fn(
       lambda batch: {
           "input_tokens": batch.input_tokens,
@@ -560,13 +565,12 @@ def main(argv: Sequence[str]) -> None:
   # We isolate the Teacher from Student CLI arguments (like pruning params).
   teacher_overrides = global_config.teacher_overrides
 
-  # Ensure load_parameters_path is set (check overrides, then env var)
+  # Ensure load_parameters_path is set in overrides
   if not teacher_overrides.get("load_parameters_path"):
-    ckpt_path = os.environ.get("TEACHER_CHECKPOINT_PATH")
-    if ckpt_path:
-      teacher_overrides["load_parameters_path"] = ckpt_path
-    else:
-      max_logging.log("Warning: No load_parameters_path found for Teacher.")
+    raise ValueError(
+        "Teacher model path is missing! You must provide 'teacher_overrides.load_parameters_path' "
+        "in your config or arguments."
+    )
 
   # Construct sanitized argv: [script_name, config_file]
   # This ensures flags like `num_query_heads=16` passed in CLI don't affect the Teacher.
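
Note: the TEACHER_CHECKPOINT_PATH environment-variable fallback is gone, so the teacher checkpoint path must now be supplied explicitly through teacher_overrides. An illustrative sketch of the expected shape; the path below is a placeholder, not a value from this repository:

# Sketch only, not part of this diff.
teacher_overrides = {
    "load_parameters_path": "gs://your-bucket/teacher-checkpoint/0/items",  # placeholder path
}
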
