
[WIP] Fix nanotron compatibility #706


Open · wants to merge 10 commits into main
8 changes: 8 additions & 0 deletions src/lighteval/config/lighteval_config.py
@@ -101,3 +101,11 @@ class LightEvalConfig:
class FullNanotronConfig:
lighteval_config: LightEvalConfig
nanotron_config: "Config"

@property
def generation_parameters(self):
# Return the generation parameters from the lighteval config
# or create default generation parameters if none are set
if self.lighteval_config.generation:
return self.lighteval_config.generation
return GenerationArgs()
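
The property above gives callers a single accessor that always returns usable generation settings. Below is a minimal, standalone sketch of the fallback pattern it implements; the stand-in classes mirror the names in the diff, while their fields and defaults are assumptions rather than nanotron's actual definitions.

from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerationArgs:            # stand-in for nanotron's GenerationArgs; fields are illustrative
    temperature: Optional[float] = None
    seed: int = 42

@dataclass
class LightEvalConfigSketch:     # stand-in for LightEvalConfig
    generation: Optional[GenerationArgs] = None

@dataclass
class FullNanotronConfigSketch:  # mirrors the property added above
    lighteval_config: LightEvalConfigSketch

    @property
    def generation_parameters(self) -> GenerationArgs:
        # Prefer the generation section from the lighteval YAML; otherwise use defaults.
        if self.lighteval_config.generation:
            return self.lighteval_config.generation
        return GenerationArgs()

# A config without a generation section still yields defaults instead of None.
assert FullNanotronConfigSketch(LightEvalConfigSketch()).generation_parameters.seed == 42
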
51 changes: 24 additions & 27 deletions src/lighteval/main_nanotron.py
@@ -26,9 +26,6 @@
from typer import Option
from typing_extensions import Annotated


CACHE_DIR: str = os.getenv("HF_HOME", "/scratch")

HELP_PANEL_NAME_1 = "Common Parameters"
HELP_PANEL_NAME_2 = "Logging Parameters"
HELP_PANEL_NAME_3 = "Debug Parameters"
@@ -42,43 +39,44 @@ def nanotron(
checkpoint_config_path: Annotated[
str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.")
],
lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")],
cache_dir: Annotated[str, Option(help="Cache directory for datasets and models.")] = CACHE_DIR,
lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")]
):
"""
Evaluate models using nanotron as backend.
"""
from nanotron.config import Config, get_config_from_file
from nanotron.config.parallelism_config import ParallelismArgs

from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig
from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig, LightEvalLoggingArgs, LightEvalTasksArgs
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import htrack_block
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
from lighteval.utils.imports import NO_NANOTRON_ERROR_MSG, is_nanotron_available
from lighteval.utils.utils import EnvConfig

env_config = EnvConfig(token=os.getenv("HF_TOKEN"), cache_dir=cache_dir)

if not is_nanotron_available():
raise ImportError(NO_NANOTRON_ERROR_MSG)

with htrack_block("Load nanotron config"):
# Create nanotron config
if not checkpoint_config_path.endswith(".yaml"):
raise ValueError("The checkpoint path should point to a YAML file")
if not checkpoint_config_path.endswith(".yaml"):
raise ValueError("The checkpoint path should point to a YAML file")

model_config = get_config_from_file(
checkpoint_config_path,
config_class=Config,
model_config_class=None,
skip_unused_config_keys=True,
skip_null_keys=True,
)

# We are getting a type error, because the get_config_from_file is not correctly typed,
lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig) # type: ignore
nanotron_config = FullNanotronConfig(lighteval_config, model_config)
model_config = get_config_from_file(
checkpoint_config_path,
config_class=Config,
model_config_class=None,
skip_unused_config_keys=True,
skip_null_keys=True,
)

# We are getting a type error, because the get_config_from_file is not correctly typed,
lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig) # type: ignore
nanotron_config = FullNanotronConfig(lighteval_config, model_config)

evaluation_tracker = EvaluationTracker(
output_dir=lighteval_config.logging.output_dir,
hub_results_org=lighteval_config.logging.results_org,
@@ -92,12 +90,11 @@

pipeline_parameters = PipelineParameters(
launcher_type=ParallelismManager.NANOTRON,
env_config=env_config,
job_id=os.environ.get("SLURM_JOB_ID", 0),
nanotron_checkpoint_path=checkpoint_config_path,
dataset_loading_processes=lighteval_config.tasks.dataset_loading_processes,
custom_tasks_directory=lighteval_config.tasks.custom_tasks,
override_batch_size=lighteval_config.batch_size,
# override_batch_size=lighteval_config.batch_size,
num_fewshot_seeds=1,
max_samples=lighteval_config.tasks.max_samples,
use_chat_template=False,
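
With the --cache-dir option and EnvConfig removed here, caching and authentication are driven purely by the environment. A hedged sketch of calling the nanotron command programmatically after this change; the paths are placeholders, and the assumptions that HF_HOME/HF_TOKEN are picked up by the downstream Hugging Face libraries and that the typer decorator leaves the function directly callable are mine, not stated in the diff.

import os

# Placeholder cache location: with EnvConfig gone, nothing in lighteval forwards
# a cache_dir explicitly, so the environment (HF_HOME, HF_TOKEN) is the only knob.
os.environ.setdefault("HF_HOME", "/scratch/hf_cache")

from lighteval.main_nanotron import nanotron  # module shown in this diff

nanotron(
    checkpoint_config_path="checkpoints/10000/config.yaml",  # placeholder path
    lighteval_config_path="configs/lighteval_config.yaml",   # placeholder path
)
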
48 changes: 20 additions & 28 deletions src/lighteval/models/nanotron/nanotron_model.py
@@ -56,7 +56,7 @@
)
from lighteval.utils.imports import is_nanotron_available
from lighteval.utils.parallelism import find_executable_batch_size
from lighteval.utils.utils import EnvConfig, as_list
from lighteval.utils.utils import as_list


logger = logging.getLogger(__name__)
@@ -101,7 +101,6 @@ def __init__(
trust_remote_code: bool = False,
debug_one_layer_model: bool = False,
model_class: Optional[Type] = None,
env_config: EnvConfig = None,
):
"""Initializes a nanotron model for evaluation.
Args:
@@ -115,6 +114,10 @@ def __init__(
self._max_length = max_length
self.parallel_config = parallel_config
self.parallel_context = parallel_context
if hasattr(lighteval_config, "batch_size"):
self.batch_size = lighteval_config.batch_size
else:
self.batch_size = None

if parallel_config.pp > 1:
# To implement PP parallelism we need to think about how we want to sync the output for the PP ranks without outputs
@@ -138,7 +141,6 @@ def __init__(
self._add_special_tokens = add_special_tokens
self._tokenizer = self._create_auto_tokenizer(
pretrained=tokenizer.tokenizer_name_or_path,
env_config=env_config,
trust_remote_code=trust_remote_code,
)
self._tokenizer.model_max_length = self.max_length
@@ -230,23 +232,18 @@ def _create_auto_tokenizer(
*,
pretrained: str,
tokenizer: Optional[str] = None,
env_config: EnvConfig = None,
trust_remote_code: bool = False,
) -> transformers.PreTrainedTokenizer:
"""Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""

try:
tokenizer = AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
cache_dir=env_config.cache_dir,
token=env_config.token,
trust_remote_code=trust_remote_code,
)
except RecursionError:
tokenizer = AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
cache_dir=env_config.cache_dir,
token=env_config.token,
unk_token="<unk>",
trust_remote_code=trust_remote_code,
)
@@ -305,9 +302,9 @@ def max_length(self) -> int:
def device(self) -> Union[int, str, torch.device]:
return "cuda"

def _get_batch_size(self, max_input_length: int, override_bs: int = 0, starting_batch_size: int = 512) -> int:
if override_bs:
return override_bs
def _get_batch_size(self, max_input_length: int, starting_batch_size: int = 512) -> int:
if self.batch_size is not None:
return self.batch_size
logger.warning("Detecting largest batch size")

@find_executable_batch_size(
@@ -343,7 +340,9 @@ def tok_decode(self, tokens: torch.LongTensor) -> List[str]:
return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)

def _model_call(self, inputs: torch.Tensor) -> torch.Tensor:
return self.model(inputs)
# This is only called for detecting the batch size so we just need a mock input_mask
input_mask = torch.ones_like(inputs)
return self.model(inputs, input_mask)

def homogeneize_ending_conditions(self, ending_condition: tuple | dict | list | str) -> tuple[list, int]:
"""Ending conditions are submitted in several possible formats.
@@ -400,7 +399,7 @@ def _check_continuations_start_space(self, continuation: str) -> str:
return continuation

def loglikelihood_single_token(
self, requests: List[Tuple[str, dict]], override_bs=0
self, requests: List[Tuple[str, dict]],
) -> List[LoglikelihoodSingleTokenResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
@@ -433,11 +432,10 @@ def loglikelihood_single_token(

return self._loglikelihood_single_token(
requests,
override_bs=override_bs,
disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0),
)

def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None) -> List[LoglikelihoodResponse]:
def loglikelihood(self, requests: List[LoglikelihoodRequest]) -> List[LoglikelihoodResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
Expand All @@ -455,12 +453,11 @@ def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None)

return self._loglikelihood_tokens(
requests,
override_bs=override_bs,
disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0),
)

def loglikelihood_rolling(
self, requests: List[LoglikelihoodRollingRequest], override_bs: int = 0
self, requests: List[LoglikelihoodRollingRequest],
) -> List[LoglikelihoodResponse]:
"""This function is used to compute the log likelihood of the context for perplexity metrics."""
for request in tqdm(
@@ -471,7 +468,6 @@ def loglikelihood_rolling(

results = self._loglikelihood_tokens(
requests,
override_bs=override_bs,
disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0),
return_bool_score=False,
)
@@ -637,7 +633,7 @@ def _get_subsets(self, dataset, num_dataset_splits):

@torch.inference_mode()
def _loglikelihood_single_token(
self, requests, disable_tqdm: bool = False, override_bs: int = 0, num_dataset_splits: int = 1
self, requests, disable_tqdm: bool = False, num_dataset_splits: int = 1
) -> List[LoglikelihoodSingleTokenResponse]:
dataset = LoglikelihoodSingleTokenDataset(requests=requests)
res = []
@@ -665,7 +661,7 @@ def _loglikelihood_single_token(
context_enc = dataset[0].tokenized_context
max_context = len(context_enc[-self.max_length :])
batch_size = self._get_batch_size(
override_bs=override_bs, max_input_length=max_context, starting_batch_size=starting_batch_size
max_input_length=max_context, starting_batch_size=starting_batch_size
)

starting_batch_size = batch_size * 2 # for the next round
@@ -711,14 +707,13 @@ def _loglikelihood_single_token(
inputs, padding_length=max_context, max_context=max_context, full_attention_masks=True
)
# batched_inputs, batch_attention, input_lengths, truncated, padded

out = self.model(input_ids=batch_model.input_ids, input_mask=batch_model.input_mask)

if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
# This process got outputs

# Gather all the output across TP
out = out.transpose(0, 1).contiguous() # [batch, seq_length, vocab]
# Gather all the output across TP
out = out.view(*batch_model.input_ids.shape, -1).contiguous() # [batch, seq_length, vocab]

gathered_out = [torch.zeros_like(out) for _ in range(self.parallel_context.tp_pg.size())]
dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False)
@@ -866,7 +861,6 @@ def _loglikelihood_tokens(
self,
requests,
disable_tqdm: bool = False,
override_bs: int = -1,
num_dataset_splits: int = 1,
return_bool_score: bool = True,
) -> List[LoglikelihoodResponse]:
@@ -898,7 +892,7 @@ def _loglikelihood_tokens(
max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])

batch_size = self._get_batch_size(
override_bs=override_bs, max_input_length=max_context, starting_batch_size=starting_batch_size
max_input_length=max_context, starting_batch_size=starting_batch_size
)
starting_batch_size = batch_size * 2 # for the next round

@@ -954,7 +948,7 @@ def _loglikelihood_tokens(
dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False)
out = torch.cat(gathered_out, dim=-1)

out = out.transpose(0, 1) # [batch, seq_length, vocab]
out = out.view(*batch_model.input_ids.shape, -1) # [batch, seq_length, vocab]
multi_logits = F.log_softmax(out, dim=-1) # [batch, padding_length, vocab]

logits_sum = []
@@ -1100,7 +1094,6 @@ def greedy_until(
self,
requests: List[GreedyUntilRequest],
disable_tqdm: bool = False,
override_bs: int = -1,
num_dataset_splits: int = 1,
) -> List[GenerativeResponse]:
"""Greedy generation until a stop token is generated."""
@@ -1140,7 +1133,6 @@ def greedy_until(
max_input_length = min(len(context_enc) + max_gen, self.max_length)

batch_size = self._get_batch_size(
override_bs=override_bs,
max_input_length=max_input_length,
starting_batch_size=starting_batch_size,
)
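
The override_bs parameters removed throughout this file are replaced by a single source of truth: self.batch_size, read from the lighteval config at init, with automatic detection as the fallback. Here is a self-contained sketch of that resolution order; the probe callable stands in for the trial forward pass that find_executable_batch_size retries and is not part of the real API.

from typing import Callable, Optional

def resolve_batch_size(
    configured: Optional[int],
    probe: Callable[[int], None],
    starting_batch_size: int = 512,
) -> int:
    """Illustrative mirror of _get_batch_size after this PR: a configured
    lighteval batch_size short-circuits detection; otherwise halve on failure."""
    if configured is not None:
        return configured
    batch_size = starting_batch_size
    while batch_size > 1:
        try:
            probe(batch_size)  # e.g. one dummy forward pass at this size
            return batch_size
        except RuntimeError:   # CUDA OOM surfaces as RuntimeError
            batch_size //= 2
    return 1
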
11 changes: 5 additions & 6 deletions src/lighteval/pipeline.py
@@ -27,7 +27,7 @@
import re
import shutil
from contextlib import nullcontext
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from datetime import timedelta
from enum import Enum, auto

@@ -72,7 +72,7 @@
from nanotron.parallel.context import ParallelContext
from nanotron.utils import local_ranks_zero_first

from lighteval.models.nanotron_model import NanotronLightevalModel
from lighteval.models.nanotron.nanotron_model import NanotronLightevalModel


import logging
@@ -154,8 +154,7 @@ def __init__(
self._metric_options = metric_options or {}
self.accelerator, self.parallel_context = self._init_parallelism_manager()
self.model = self._init_model(model_config, model)

generation_parameters = model_config.generation_parameters.model_dump() if model_config else {}
generation_parameters = asdict(model_config.generation_parameters) if model_config and hasattr(model_config, "generation_parameters") else {}

self.evaluation_tracker.general_config_logger.log_model_info(generation_parameters, self.model.model_info)
self._init_random_seeds()
@@ -186,12 +185,12 @@ def _init_parallelism_manager(self):
def _init_model(self, model_config, model):
logger.info("--- LOADING MODEL ---")
if model_config is not None:
if self.parallel_context:
if self.parallel_context:
return NanotronLightevalModel(
checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path)
if self.pipeline_parameters.nanotron_checkpoint_path
else "",
nanotron_config=self.model_config,
nanotron_config=model_config,
parallel_context=self.parallel_context,
debug_one_layer_model=False,
model_class=None,
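
The pipeline change above swaps pydantic's model_dump() for dataclasses.asdict, since nanotron's generation arguments are a dataclass rather than a pydantic model, and guards against configs that carry no generation_parameters at all. A small sketch of that conversion follows; the GenerationArgsSketch dataclass is illustrative, not nanotron's real class.

from dataclasses import asdict, dataclass, is_dataclass
from typing import Any, Dict, Optional

@dataclass
class GenerationArgsSketch:  # illustrative stand-in for the generation-parameters dataclass
    temperature: Optional[float] = None
    seed: int = 42

def generation_parameters_dict(model_config: Any) -> Dict[str, Any]:
    # Mirrors the guarded conversion in Pipeline.__init__: an empty dict when there
    # is no model_config, or when it exposes no generation_parameters dataclass.
    params = getattr(model_config, "generation_parameters", None) if model_config else None
    return asdict(params) if is_dataclass(params) else {}
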
2 changes: 2 additions & 0 deletions src/lighteval/tasks/lighteval_task.py
@@ -107,6 +107,7 @@ class LightevalTaskConfig:
few_shots_select: Optional[str] = None

# Generation args
output_regex: Optional[str] = None
generation_size: Optional[int] = None
generation_grammar: Optional[TextGenerationInputGrammarType] = None
stop_sequence: Optional[ListLike[str]] = None
@@ -120,6 +121,7 @@
must_remove_duplicate_docs: bool = False

version: int = 0
frozen: bool = False

def __post_init__(self):
# If we got a Metrics enums instead of a Metric, we convert