Add context_logits for eval accuracy calculation in case of multi-token prediction tasks #11753

Merged
oyilmaz-nvidia merged 11 commits into main from athitten/add_context_logits on Jan 17, 2025

Conversation

@athitten (Collaborator) commented Jan 3, 2025

What does this PR do ?

This PR adds the following changes:

  1. Uses context_logits to compute logProbs for eval benchmarks that require multi-token prediction, e.g. arc_challenge, arc_easy, winogrande, and copa. Single-token prediction benchmarks such as MMLU and lambada continue to use generation_logits, which avoids the large payload that context_logits would incur.

  2. To obtain context_logits during evals, exposes gather_context_logits and output_context_logits in the export and deploy files, mirroring the previously existing generation_logits.

  3. Introduces a requirements_eval.txt file to install lm-eval-harness in NeMo FW containers.
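To make the distinction concrete, here is a minimal sketch (not NeMo's actual implementation) of how a multi-token continuation can be scored from context logits: `context_logits[i]` are the model's output logits after consuming token `i`, so the log-probability of the token at position `i+1` is read from the log-softmax of row `i`. All names below are illustrative.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def continuation_logprob(context_logits: np.ndarray, token_ids: list, continuation_start: int) -> float:
    """Sum the log-probs of the continuation tokens given the preceding context.

    context_logits: [seq_len, vocab] logits over the full (context + continuation) sequence.
    token_ids: the full token-id sequence, same length as context_logits.
    continuation_start: index of the first continuation token.
    """
    logprobs = log_softmax(context_logits)
    total = 0.0
    for pos in range(continuation_start, len(token_ids)):
        # The logits at position pos-1 predict the token at position pos.
        total += logprobs[pos - 1, token_ids[pos]]
    return total


# Toy example: all-zero logits give a uniform distribution over a 4-token
# vocab, so each scored token contributes log(1/4).
logits = np.zeros((3, 4))
score = continuation_logprob(logits, [1, 2, 3], continuation_start=1)
```

A single-token benchmark only needs the logits of the one generated token, which is why generation_logits suffice there and keep the response payload small.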

Collection: [Note which collection this PR will affect]

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this 

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and re-add the label.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

@athitten athitten changed the base branch from main to athitten/eval_server_ready_check January 3, 2025 22:46
@athitten athitten force-pushed the athitten/add_context_logits branch from 630966e to 5fb26b1 Compare January 6, 2025 23:09
@athitten athitten force-pushed the athitten/eval_server_ready_check branch from 4bd6b47 to 8cbaef3 Compare January 7, 2025 00:42
Base automatically changed from athitten/eval_server_ready_check to main January 7, 2025 19:20
@athitten athitten force-pushed the athitten/add_context_logits branch from cd4e7e6 to 900ad07 Compare January 8, 2025 19:31
@athitten athitten changed the title Use context_logits for eval accuracy calculation Add context_logits for eval accuracy calculation in case of multi token prediction tasks Jan 8, 2025
Comment on lines +46 to +48
def _generate_tokens_logits(
self, payload, single_prediction_token, return_text: bool = False, return_logits: bool = False
):

Check notice (Code scanning / CodeQL): Explicit returns mixed with implicit (fall-through) returns. Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.
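A generic illustration of this CodeQL notice (hypothetical functions, not the PR's actual code): the first version mixes an explicit return with an implicit fall-through, which silently yields None on one path; the second makes every path's return explicit.

```python
def pick_logits_mixed(return_logits: bool, logits):
    """Flagged pattern: the else path falls through and implicitly returns None."""
    if return_logits:
        return logits
    # implicit fall-through here: returns None without saying so


def pick_logits_explicit(return_logits: bool, logits):
    """Fixed pattern: every path returns explicitly, so the None case is visibly intentional."""
    if return_logits:
        return logits
    return None
```

Both behave identically at runtime; the explicit version simply documents that returning None is deliberate rather than an oversight.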
@athitten athitten force-pushed the athitten/add_context_logits branch from 07925a5 to f0c3cb3 Compare January 9, 2025 04:23
@athitten athitten marked this pull request as ready for review January 9, 2025 06:16
@oyilmaz-nvidia (Collaborator) left a comment:

Looks good but can you please run the examples here https://docs.nvidia.com/nemo-framework/user-guide/latest/deployment/llm/optimized/tensorrt_llm.html and make sure nothing is broken?


github-actions bot commented Jan 9, 2025

[🤖]: Hi @athitten 👋,

We wanted to let you know that a CICD pipeline for this PR just finished successfully, so it might be time to merge this PR or get some approvals.

I'm just a bot, so I'll leave it to you to decide what to do next.

//cc @pablo-garay @ko3n1g

hemildesai previously approved these changes Jan 9, 2025
@hemildesai (Collaborator):

LGTM for nemo/collections/llm related changes.

@athitten (Collaborator, Author) commented Jan 10, 2025

> Looks good but can you please run the examples here https://docs.nvidia.com/nemo-framework/user-guide/latest/deployment/llm/optimized/tensorrt_llm.html and make sure nothing is broken?

Hi @oyilmaz-nvidia, I ran the scripts you pointed to with HF llama3-8b converted to nemo2. No errors; everything worked fine. Here's the output:

python scripts/deploy/nlp/deploy_triton.py --nemo_checkpoint /workspace/hf_llama3_8b_nemo2_new.nemo --model_type 'llama' --triton_model_name 'llama3-8b' --tensor_parallelism_size 1

Output of deploy_triton.py

| cuda_memory_pool_byte_size{1}    | 67108864                                 |
| min_supported_compute_capability | 6.0                                      |
| strict_readiness                 | 1                                        |
| exit_timeout                     | 30                                       |
| cache_enabled                    | 0                                        |
+----------------------------------+------------------------------------------+

I0110 00:42:14.989638 1098739 grpc_server.cc:2558] "Started GRPCInferenceService at 0.0.0.0:8001"
I0110 00:42:14.989788 1098739 http_server.cc:4713] "Started HTTPService at 0.0.0.0:8000"
I0110 00:42:15.030635 1098739 http_server.cc:362] "Started Metrics Service at 0.0.0.0:8002"
E0110 00:42:15.571184 1098739 model_repository_manager.cc:470] "Failed to set config modification time: model_config_content_name_ is empty"
I0110 00:42:15.571553 1098739 model_lifecycle.cc:472] "loading: llama3-8b:1"
I0110 00:42:17.003485 1098739 python_be.cc:2249] "TRITONBACKEND_ModelInstanceInitialize: llama3-8b_0_0 (CPU device 0)"
I0110 00:42:17.484886 1098739 model_lifecycle.cc:839] "successfully loaded 'llama3-8b'"
[01/10/2025-00:42:17] Model serving on Triton is will be started.

Inference request:

python scripts/deploy/nlp/query.py -mn 'llama3-8b' -p "Hi, how are you?" -mol 20

Output:
(Screenshots of the query output attached, taken 2025-01-09 at 4:48 PM and 4:49 PM.)

oyilmaz-nvidia previously approved these changes Jan 14, 2025
@oyilmaz-nvidia (Collaborator) left a comment:

LGTM.

@oyilmaz-nvidia oyilmaz-nvidia enabled auto-merge (squash) January 14, 2025 19:06
athitten and others added 6 commits January 16, 2025 18:13
Uses bool generation_logits_available as inputs dict does not contain it

Signed-off-by: Abhishree <abhishreetm@gmail.com>
Signed-off-by: Abhishree <abhishreetm@gmail.com>
Signed-off-by: Abhishree <abhishreetm@gmail.com>
Signed-off-by: athitten <athitten@users.noreply.github.com>
Signed-off-by: Abhishree <abhishreetm@gmail.com>
Signed-off-by: Abhishree <abhishreetm@gmail.com>
pre-commit-ci bot and others added 5 commits January 16, 2025 18:13
Signed-off-by: athitten <athitten@users.noreply.github.com>
Signed-off-by: Abhishree <abhishreetm@gmail.com>
Signed-off-by: Abhishree <abhishreetm@gmail.com>
github-actions bot (Contributor) commented:

beep boop 🤖: 🙏 The following files have warnings. If you are familiar with these, please help us improve the code base.


Your code was analyzed with PyLint. The following annotations have been identified:

************* Module nemo.collections.llm.api
nemo/collections/llm/api.py:368:0: C0301: Line too long (121/119) (line-too-long)
nemo/collections/llm/api.py:369:0: C0301: Line too long (120/119) (line-too-long)
nemo/collections/llm/api.py:445:0: C0301: Line too long (130/119) (line-too-long)
nemo/collections/llm/api.py:572:0: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/llm/api.py:15:0: W0611: Unused import os (unused-import)
************* Module nemo.deploy.nlp.query_llm
nemo/deploy/nlp/query_llm.py:29:0: C0115: Missing class docstring (missing-class-docstring)
************* Module nemo.export.trt_llm.tensorrt_llm_run
nemo/export/trt_llm/tensorrt_llm_run.py:506:0: C0301: Line too long (125/119) (line-too-long)
nemo/export/trt_llm/tensorrt_llm_run.py:510:0: C0301: Line too long (136/119) (line-too-long)
nemo/export/trt_llm/tensorrt_llm_run.py:514:0: C0301: Line too long (123/119) (line-too-long)
nemo/export/trt_llm/tensorrt_llm_run.py:557:0: C0301: Line too long (181/119) (line-too-long)
nemo/export/trt_llm/tensorrt_llm_run.py:844:0: C0301: Line too long (153/119) (line-too-long)
nemo/export/trt_llm/tensorrt_llm_run.py:524:0: C0116: Missing function or method docstring (missing-function-docstring)
nemo/export/trt_llm/tensorrt_llm_run.py:533:0: C0116: Missing function or method docstring (missing-function-docstring)
nemo/export/trt_llm/tensorrt_llm_run.py:591:0: C0116: Missing function or method docstring (missing-function-docstring)
nemo/export/trt_llm/tensorrt_llm_run.py:33:0: W0611: Unused Mapping imported from tensorrt_llm.mapping (unused-import)
************* Module nemo.export.vllm_exporter
nemo/export/vllm_exporter.py:38:0: C0116: Missing function or method docstring (missing-function-docstring)
nemo/export/vllm_exporter.py:430:4: C0116: Missing function or method docstring (missing-function-docstring)

-----------------------------------
Your code has been rated at 9.89/10

Mitigation guide:

  • Add sensible and useful docstrings to functions and methods
  • For trivial methods like getter/setters, consider adding # pylint: disable=C0116 inside the function itself
  • To disable multiple functions/methods at once, put a # pylint: disable=C0116 before the first and a # pylint: enable=C0116 after the last.
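The disable patterns from the mitigation guide look like this in practice (a hypothetical class, shown only to illustrate comment placement):

```python
class Config:
    """Example class showing both pylint C0116 suppression styles."""

    def __init__(self, value):
        self._value = value

    # Style 1: suppress the missing-docstring check for a single trivial getter.
    def value(self):  # pylint: disable=C0116
        return self._value

    # Style 2: suppress the check for a block of methods, then re-enable it.
    # pylint: disable=C0116
    def name(self):
        return "config"

    def kind(self):
        return "example"
    # pylint: enable=C0116
```

Prefer real docstrings where a method does anything non-obvious; the disable comments are only for trivial getters/setters.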

By applying these rules, we reduce the occurrence of this message in the future.

Thank you for improving NeMo's documentation!

github-actions bot (Contributor) posted the same CICD-success notification as above.

@oyilmaz-nvidia oyilmaz-nvidia merged commit ca4e4f0 into main Jan 17, 2025
389 of 393 checks passed
@oyilmaz-nvidia oyilmaz-nvidia deleted the athitten/add_context_logits branch January 17, 2025 19:49
abhinavg4 pushed a commit that referenced this pull request Jan 30, 2025
…en prediction tasks (#11753)

* Add server ready check before evaluation

Uses bool generation_logits_available as inputs dict does not contain it

Signed-off-by: Abhishree <abhishreetm@gmail.com>

* Add context logits

Signed-off-by: Abhishree <abhishreetm@gmail.com>

* Remove max_tokens_to_generate and add more comments

Signed-off-by: Abhishree <abhishreetm@gmail.com>

* Apply isort and black reformatting

Signed-off-by: athitten <athitten@users.noreply.github.com>

* Get context_logits for multi token prediction tasks

Signed-off-by: Abhishree <abhishreetm@gmail.com>

* Fix bug with single/multi token condition check

Signed-off-by: Abhishree <abhishreetm@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply isort and black reformatting

Signed-off-by: athitten <athitten@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Bugfix with output_context_logits

Signed-off-by: Abhishree <abhishreetm@gmail.com>

---------

Signed-off-by: Abhishree <abhishreetm@gmail.com>
Signed-off-by: athitten <athitten@users.noreply.github.com>
Co-authored-by: athitten <athitten@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Abhinav Garg <abhgarg@nvidia.com>
youngeunkwon0405 pushed a commit to youngeunkwon0405/NeMo that referenced this pull request Feb 10, 2025
…en prediction tasks (NVIDIA-NeMo#11753)

(commit message identical to the one above)
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>