[WIP] Nanotron fix #656

Closed · wants to merge 4 commits

8 changes: 8 additions & 0 deletions src/lighteval/config/lighteval_config.py
@@ -101,3 +101,11 @@ class LightEvalConfig:
 class FullNanotronConfig:
     lighteval_config: LightEvalConfig
     nanotron_config: "Config"
+
+    @property
+    def generation_parameters(self):
+        # Return the generation parameters from the lighteval config
+        # or create default generation parameters if none are set
+        if self.lighteval_config.generation:
+            return self.lighteval_config.generation
+        return GenerationArgs()
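
Note: a minimal usage sketch of the new property, not part of the diff. It assumes LightEvalConfig can be instantiated with all-default fields and that GenerationArgs is the dataclass defined earlier in this module; the nanotron config is left out for brevity.

    # Hypothetical usage sketch; everything except the `generation` fallback is an assumption.
    from lighteval.config.lighteval_config import (
        FullNanotronConfig,
        GenerationArgs,
        LightEvalConfig,
    )

    lighteval_config = LightEvalConfig()  # assumes every field has a default
    full_config = FullNanotronConfig(lighteval_config=lighteval_config, nanotron_config=None)

    # No `generation` section was configured, so the property falls back to defaults.
    assert isinstance(full_config.generation_parameters, GenerationArgs)
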
36 changes: 18 additions & 18 deletions src/lighteval/main_nanotron.py
@@ -49,10 +49,10 @@ def nanotron(
     Evaluate models using nanotron as backend.
     """
     from nanotron.config import Config, get_config_from_file
+    from nanotron.config.parallelism_config import ParallelismArgs

-    from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig
+    from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig, LightEvalLoggingArgs, LightEvalTasksArgs
     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.logging.hierarchical_logger import htrack_block
     from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
     from lighteval.utils.imports import NO_NANOTRON_ERROR_MSG, is_nanotron_available
     from lighteval.utils.utils import EnvConfig
@@ -61,23 +61,23 @@ def nanotron(

     if not is_nanotron_available():
         raise ImportError(NO_NANOTRON_ERROR_MSG)

-    with htrack_block("Load nanotron config"):
-        # Create nanotron config
-        if not checkpoint_config_path.endswith(".yaml"):
-            raise ValueError("The checkpoint path should point to a YAML file")
-
-        model_config = get_config_from_file(
-            checkpoint_config_path,
-            config_class=Config,
-            model_config_class=None,
-            skip_unused_config_keys=True,
-            skip_null_keys=True,
-        )
-
-        # We are getting an type error, because the get_config_from_file is not correctly typed,
-        lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig)  # type: ignore
-        nanotron_config = FullNanotronConfig(lighteval_config, model_config)
+    # Create nanotron config
+    if not checkpoint_config_path.endswith(".yaml"):
+        raise ValueError("The checkpoint path should point to a YAML file")
+
+    model_config = get_config_from_file(
+        checkpoint_config_path,
+        config_class=Config,
+        model_config_class=None,
+        skip_unused_config_keys=True,
+        skip_null_keys=True,
+    )
+
+    # Load lighteval config
+    lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig)  # type: ignore
+
+    nanotron_config = FullNanotronConfig(lighteval_config, model_config)

     evaluation_tracker = EvaluationTracker(
         output_dir=lighteval_config.logging.output_dir,
20 changes: 14 additions & 6 deletions src/lighteval/models/nanotron/nanotron_model.py
@@ -343,7 +343,14 @@ def tok_decode(self, tokens: torch.LongTensor) -> List[str]:
         return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)

     def _model_call(self, inputs: torch.Tensor) -> torch.Tensor:
-        return self.model(inputs)
+        position_ids = (
+            torch.arange(
+                inputs.shape[1], device=inputs.device, dtype=torch.int32
+            )
+            .unsqueeze(0)
+            .repeat(inputs.shape[0], 1)
+        )
+        return self.model(inputs, position_ids)

     def homogeneize_ending_conditions(self, ending_condition: tuple | dict | list | str) -> tuple[list, int]:
         """Ending conditions are submitted in several possible formats.
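
Note: a standalone shape check for the position-id construction used in the new _model_call, not part of the diff; it only exercises the torch calls, so no nanotron model is needed.

    import torch

    inputs = torch.randint(0, 100, (4, 16))  # [batch, seq_length] of token ids
    position_ids = (
        torch.arange(inputs.shape[1], device=inputs.device, dtype=torch.int32)
        .unsqueeze(0)
        .repeat(inputs.shape[0], 1)
    )
    # One position index per token, repeated across the batch: same shape as the inputs.
    assert position_ids.shape == inputs.shape
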
@@ -711,14 +718,14 @@ def _loglikelihood_single_token(
                 inputs, padding_length=max_context, max_context=max_context, full_attention_masks=True
             )
             # batched_inputs, batch_attention, input_lengths, truncated, padded
-
-            out = self.model(input_ids=batch_model.input_ids, input_mask=batch_model.input_mask)
+            position_ids = torch.arange(batch_model.input_ids.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0).repeat(batch_model.input_ids.shape[0], 1)
+            out = self.model(input_ids=batch_model.input_ids, position_ids=position_ids)

             if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
                 # This process got outputs

                 # Gather all the output accross TP
-                out = out.transpose(0, 1).contiguous()  # [batch, seq_length, vocab]
+                out = out.view(*batch_model.input_ids.shape, -1).contiguous()  # [batch, seq_length, vocab]

                 gathered_out = [torch.zeros_like(out) for _ in range(self.parallel_context.tp_pg.size())]
                 dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False)
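
Note: the transpose-to-view change here (and in _loglikelihood_tokens below) only reconstructs [batch, seq_length, vocab] correctly if the model now returns logits flattened over the batch and sequence dimensions rather than laid out as [seq_length, batch, vocab]; that layout is an assumption about the new nanotron output. A standalone sketch of the reshape, not part of the diff:

    import torch

    batch, seq_len, vocab = 2, 5, 11
    input_ids = torch.zeros(batch, seq_len, dtype=torch.long)
    flat_logits = torch.randn(batch * seq_len, vocab)  # assumed model output layout: [batch * seq_length, vocab]

    logits = flat_logits.view(*input_ids.shape, -1)  # recover [batch, seq_length, vocab]
    assert logits.shape == (batch, seq_len, vocab)
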
@@ -944,7 +951,8 @@ def _loglikelihood_tokens(
             )
             # batched_inputs, batch_attention, input_lengths, truncated, padded
             with torch.no_grad():
-                out = self.model(input_ids=batch_model.input_ids, input_mask=batch_model.input_mask)
+                position_ids = torch.arange(batch_model.input_ids.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0).repeat(batch_model.input_ids.shape[0], 1)
+                out = self.model(input_ids=batch_model.input_ids, position_ids=position_ids)

             if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
                 # This process got outputs
@@ -954,7 +962,7 @@
                 dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False)
                 out = torch.cat(gathered_out, dim=-1)

-                out = out.transpose(0, 1)  # [batch, seq_length, vocab]
+                out = out.view(*batch_model.input_ids.shape, -1)  # [batch, seq_length, vocab]
                 multi_logits = F.log_softmax(out, dim=-1)  # [batch, padding_length, vocab]

                 logits_sum = []
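
Note: a single-process stand-in for the tensor-parallel gather in the two hunks above, not part of the diff. Each TP rank holds a vocab shard; dist.all_gather collects one tensor per rank and torch.cat(dim=-1) rebuilds the full vocabulary axis. The shapes below are placeholders.

    import torch

    tp_size = 2
    tokens, vocab_shard = 10, 7                  # per-rank logits: [tokens, vocab_shard]
    local_out = torch.randn(tokens, vocab_shard)

    # dist.all_gather would fill one buffer per rank; here both "ranks" just reuse local_out.
    gathered = [local_out.clone() for _ in range(tp_size)]
    full = torch.cat(gathered, dim=-1)           # [tokens, vocab_shard * tp_size]
    assert full.shape == (tokens, vocab_shard * tp_size)
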
6 changes: 3 additions & 3 deletions src/lighteval/pipeline.py
@@ -72,7 +72,7 @@
     from nanotron.parallel.context import ParallelContext
     from nanotron.utils import local_ranks_zero_first

-    from lighteval.models.nanotron_model import NanotronLightevalModel
+    from lighteval.models.nanotron.nanotron_model import NanotronLightevalModel


 import logging
@@ -187,12 +187,12 @@ def _init_parallelism_manager(self):
     def _init_model(self, model_config, model):
         logger.info("--- LOADING MODEL ---")
         if model_config is not None:
-            if self.parallel_context:
+            if self.parallel_context:
                 return NanotronLightevalModel(
                     checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path)
                     if self.pipeline_parameters.nanotron_checkpoint_path
                     else "",
-                    nanotron_config=self.model_config,
+                    nanotron_config=model_config,
                     parallel_context=self.parallel_context,
                     debug_one_layer_model=False,
                     model_class=None,
2 changes: 2 additions & 0 deletions src/lighteval/tasks/lighteval_task.py
@@ -107,6 +107,7 @@ class LightevalTaskConfig:
     few_shots_select: Optional[str] = None

     # Generation args
+    output_regex: Optional[str] = None
     generation_size: Optional[int] = None
     generation_grammar: Optional[TextGenerationInputGrammarType] = None
     stop_sequence: Optional[ListLike[str]] = None
@@ -120,6 +121,7 @@ class LightevalTaskConfig:
     must_remove_duplicate_docs: bool = False

     version: int = 0
+    frozen: bool = False

     def __post_init__(self):
         # If we got a Metrics enums instead of a Metric, we convert
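
Note: a hypothetical task-definition fragment showing where the two reinstated fields would sit, not part of the diff. It is written as a plain dict because the full required-field list of LightevalTaskConfig is not shown here, and the regex is only an assumed example.

    task_kwargs = dict(
        output_regex=r"answer is (-?\d+)",  # assumed example: extract a final numeric answer from the generation
        generation_size=256,
        stop_sequence=["\n"],
        frozen=False,  # assumption: True would pin the task definition
    )
    # LightevalTaskConfig(name=..., prompt_function=..., **task_kwargs)  # remaining fields omitted
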