From 642badc497511d7bcb9b6a3cf8aea38e50a37f01 Mon Sep 17 00:00:00 2001 From: Hynek Kydlicek Date: Thu, 14 Nov 2024 00:43:47 +0100 Subject: [PATCH 1/3] implement translation prompt --- src/lighteval/tasks/templates/translation.py | 145 +++++++++++++++++++ tests/tasks/templates/test_translation.py | 91 ++++++++++++ 2 files changed, 236 insertions(+) create mode 100644 src/lighteval/tasks/templates/translation.py create mode 100644 tests/tasks/templates/test_translation.py diff --git a/src/lighteval/tasks/templates/translation.py b/src/lighteval/tasks/templates/translation.py new file mode 100644 index 000000000..50ba97a0d --- /dev/null +++ b/src/lighteval/tasks/templates/translation.py @@ -0,0 +1,145 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE.
+ +from typing import Callable + +from langcodes import standardize_tag +from typing_extensions import NotRequired, TypedDict + +from lighteval.tasks.templates.continuation import get_continuation_prompt_function +from lighteval.tasks.templates.multichoice import create_adapter_from_dict +from lighteval.tasks.templates.utils.formulation import Formulation, MCFFormulation +from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS +from lighteval.utils.language import Language +from lighteval.utils.utils import as_list + + +TRANSLATION_CONTEXT = ( + "{source_label}{colon}{sentence_space}{source_text}{sentence_space}{target_label}{colon}{sentence_space}" +) + + +# Defined for type hinting only +class TranslationInput(TypedDict): + """ + Input for the Translation task. + Args: + source_text: The source text to be translated + target_text: The target text, i.e. the candidate translation(s) of the source text + instruction (optional): The instruction of the Translation task (e.g. Translate the following text to Turkish) + """ + + source_text: str + target_text: str | list[str] + gold_idx: NotRequired[int | list[int]] + instruction: NotRequired[str] + + +class TranslationAdapter(TypedDict): + """ + Adapter for mapping from the dataset row into the TranslationInput format. + Args: + source_text: Column name in the row that contains the source text to be translated + target_text: Column name in the row that contains the target text, i.e. the candidate translation(s) + instruction (optional): Column name in the row that contains the instruction of the task (e.g.
Translate the following text to Turkish) + """ + + source_text: str + target_text: str + gold_idx: NotRequired[int | list[int]] + instruction: NotRequired[str] + + +def get_translation_prompt_function( + source_language: Language, + target_language: Language, + adapter: Callable[[dict], TranslationInput | None] | TranslationAdapter, + formulation: Formulation = MCFFormulation(), +): + """ + Create a templated prompt function for a Translation task. + Example tasks: + - WMT2016 + - WMT2017 + + Format: + *CF* + EN: How are you? TR: | Nasılsın? + + *Hybrid* + EN: How are you? TR: + A. Nasılsın? + B. Jak se máš? + Answer: | Nasılsın?/Jak se máš? + + *MCF* + EN: How are you? TR: + A. Nasılsın? + B. Jak se máš? + Answer: | A/B + + Args: + adapter (Callable[[dict], TranslationInput | None] | TranslationAdapter): Either a function that takes a dataset row and returns a TranslationInput (or None), or a dictionary with keys corresponding to the field names in the dataset row. + Note: Both TranslationAdapter and TranslationInput are TypedDicts; this means that the caller provides a dictionary and doesn't initialize any class! + formulation (Formulation, optional): The formulation to use for the task. Defaults to MCFFormulation(). + Returns: + Callable: A function that generates Translation prompts based on the given parameters.
+ """ + adapter_fn = create_adapter_from_dict(adapter) + continuation_prompt_fn = get_continuation_prompt_function( + Language.ENGLISH, {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, formulation + ) + translation_literals = TRANSLATION_LITERALS[source_language] + + source_label_string = standardize_tag(source_language.value).upper() + target_label_string = standardize_tag(target_language.value).upper() + + def translation_prompt( + line: dict, + task_name: str, + ): + input_data = adapter_fn(line) + if input_data is None: + return None + + context = TRANSLATION_CONTEXT.format( + source_label=source_label_string, + source_text=input_data["source_text"], + target_label=target_label_string, + target_text=input_data["target_text"], + colon=translation_literals.colon, + sentence_space=translation_literals.sentence_space, + ) + + continuations = as_list(input_data["target_text"]) + + return continuation_prompt_fn( + { + "instruction": input_data.get("instruction", ""), + "context": context, + "continuations": continuations, + "gold_idx": input_data.get("gold_idx", list(range(len(continuations)))), + }, + task_name, + ) + + return translation_prompt diff --git a/tests/tasks/templates/test_translation.py b/tests/tasks/templates/test_translation.py new file mode 100644 index 000000000..a37f8ed78 --- /dev/null +++ b/tests/tasks/templates/test_translation.py @@ -0,0 +1,91 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall 
be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from lighteval.tasks.templates.translation import get_translation_prompt_function +from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation +from lighteval.utils.language import Language + + +def test_translation_prompt_cf(): + """ + Tests that translation prompt function works correctly for CF formulation. + """ + test_input = { + "source_text": "Ahoj, jak se máš?", + "target_text": "Bonjour, comment allez-vous?", + } + + prompt_fn = get_translation_prompt_function( + source_language=Language.CZECH, + target_language=Language.FRENCH, + adapter=lambda x: { + "source_text": x["source_text"], + "target_text": x["target_text"], + }, + formulation=CFFormulation(), + ) + + doc = prompt_fn(test_input, "test_task") + assert doc is not None + + assert doc.query == "CS: Ahoj, jak se máš? FR:" + assert doc.unconditioned_query == "" + assert doc.choices == [" Bonjour, comment allez-vous?"] + assert doc.gold_index == [0] + + +def test_translation_prompt_mcf(): + """ + Tests that translation prompt function works correctly for MCF formulation. 
+ """ + test_input = { + "source_text": "Ahoj, jak se máš?", + "target_text": ["Bonjour, comment allez-vous?", "Ciao, come stai?"], + } + + prompt_fn = get_translation_prompt_function( + source_language=Language.CZECH, + target_language=Language.FRENCH, + adapter=lambda x: { + "source_text": x["source_text"], + "target_text": x["target_text"], + "gold_idx": 0, + }, + formulation=MCFFormulation(), + ) + + doc = prompt_fn(test_input, "test_task") + assert doc is not None + + assert ( + doc.query + == """\ +CS: Ahoj, jak se máš? FR: + A. Bonjour, comment allez-vous? + B. Ciao, come stai? +Answer:\ +""" + ) + assert doc.unconditioned_query == "Answer:" + assert doc.choices == [" A", " B"] + assert doc.gold_index == [0] From bfdc019c16d5a510468a3a667047888ea29b2b12 Mon Sep 17 00:00:00 2001 From: Hynek Kydlicek Date: Thu, 14 Nov 2024 00:54:08 +0100 Subject: [PATCH 2/3] add small comment about translation prompt --- src/lighteval/tasks/templates/translation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/lighteval/tasks/templates/translation.py b/src/lighteval/tasks/templates/translation.py index 50ba97a0d..602a85e8f 100644 --- a/src/lighteval/tasks/templates/translation.py +++ b/src/lighteval/tasks/templates/translation.py @@ -33,6 +33,10 @@ from lighteval.utils.utils import as_list +# Template chosen so that it's not very language-dependent, as it's not clear whether one should use the target or source language. +# It's also the best template based on https://arxiv.org/pdf/2301.07069.
+ + TRANSLATION_CONTEXT = ( "{source_label}{colon}{sentence_space}{source_text}{sentence_space}{target_label}{colon}{sentence_space}" ) From f9ba1fa7b69082e18bc118aac84f94e983c7be98 Mon Sep 17 00:00:00 2001 From: Hynek Kydlicek Date: Fri, 15 Nov 2024 16:35:30 +0100 Subject: [PATCH 3/3] change formatting to reformat language-dependent parts --- src/lighteval/tasks/templates/continuation.py | 10 ++++++- src/lighteval/tasks/templates/translation.py | 27 ++++++++++------- tests/tasks/templates/test_translation.py | 29 +++++++++++++++++++ 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/src/lighteval/tasks/templates/continuation.py b/src/lighteval/tasks/templates/continuation.py index 84c112305..6435fc8f2 100644 --- a/src/lighteval/tasks/templates/continuation.py +++ b/src/lighteval/tasks/templates/continuation.py @@ -86,6 +86,7 @@ def get_continuation_prompt_function( language: Language, adapter: Callable[[dict], ContinuationInput | None] | ContinuationDictAdapter, formulation: Formulation = MCFFormulation(), + fix_formatting: bool = True, ): """ Create a templated prompt function for a Continuation task. @@ -118,6 +119,7 @@ def get_continuation_prompt_function( adapter (Callable[[dict], ContinuationInput] | ContinuationDictAdapter): Either a function that takes a dataset row and returns a ContinuationInput, or a dictionary with keys corresponding to the field names in the dataset row. Note: Both ContinuationDictAdapter and ContinuationInput are TypeDicts, this means that the caller provides dictionary and doesn't initialize any class! formulation (Formulation, optional): The formulation (MCF/Hybrid/CF) to use for the task. Defaults to MCFFormulation(). + fix_formatting (bool, optional): Whether to fix the formatting of the text by capitalizing and fixing punctuation based on language. If False, the text will be used as-is. Defaults to True. Returns: Callable: A function that generates Continuation prompt based on the given parameters.
""" @@ -132,10 +134,16 @@ def prepare_prompt(line: dict): instruction_val = cont_input.get("instruction") instruction = f"{instruction_val}\n" if instruction_val else "" - context = f"{capitalize(fix_ending_punct(cont_input['context'], translation_literals))}" + context = ( + f"{capitalize(fix_ending_punct(cont_input['context'], translation_literals))}" + if fix_formatting + else cont_input["context"] + ) continuations = [ fix_capitalization(context, fix_ending_punct(continuation, translation_literals), translation_literals) + if fix_formatting + else continuation for continuation in cont_input["continuations"] ] diff --git a/src/lighteval/tasks/templates/translation.py b/src/lighteval/tasks/templates/translation.py index 602a85e8f..c90b99e01 100644 --- a/src/lighteval/tasks/templates/translation.py +++ b/src/lighteval/tasks/templates/translation.py @@ -27,6 +27,7 @@ from lighteval.tasks.templates.continuation import get_continuation_prompt_function from lighteval.tasks.templates.multichoice import create_adapter_from_dict +from lighteval.tasks.templates.utils.formatting_utils import capitalize, fix_ending_punct from lighteval.tasks.templates.utils.formulation import Formulation, MCFFormulation from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS from lighteval.utils.language import Language @@ -37,9 +38,7 @@ # It's also the best template based on https://arxiv.org/pdf/2301.07069. 
-TRANSLATION_CONTEXT = ( - "{source_label}{colon}{sentence_space}{source_text}{sentence_space}{target_label}{colon}{sentence_space}" -) +TRANSLATION_CONTEXT = "{source_label}{colon}{sentence_space}{source_text}{sentence_space}{target_label}{colon}" # Defined for type hinting only @@ -110,9 +109,13 @@ def get_translation_prompt_function( """ adapter_fn = create_adapter_from_dict(adapter) continuation_prompt_fn = get_continuation_prompt_function( - Language.ENGLISH, {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, formulation + Language.ENGLISH, + {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, + formulation, + fix_formatting=False, ) - translation_literals = TRANSLATION_LITERALS[source_language] + source_translation_literals = TRANSLATION_LITERALS[source_language] + target_translation_literals = TRANSLATION_LITERALS[target_language] source_label_string = standardize_tag(source_language.value).upper() target_label_string = standardize_tag(target_language.value).upper() @@ -125,16 +128,20 @@ def translation_prompt( if input_data is None: return None + source_text = capitalize(fix_ending_punct(input_data["source_text"], source_translation_literals)) + context = TRANSLATION_CONTEXT.format( source_label=source_label_string, - source_text=input_data["source_text"], + source_text=source_text, target_label=target_label_string, - target_text=input_data["target_text"], - colon=translation_literals.colon, - sentence_space=translation_literals.sentence_space, + colon=":", + sentence_space=" ", ) - continuations = as_list(input_data["target_text"]) + continuations = [ + capitalize(fix_ending_punct(text, target_translation_literals)) + for text in as_list(input_data["target_text"]) + ] return continuation_prompt_fn( { diff --git a/tests/tasks/templates/test_translation.py b/tests/tasks/templates/test_translation.py index a37f8ed78..eab59cf18 100644 --- a/tests/tasks/templates/test_translation.py +++ 
b/tests/tasks/templates/test_translation.py @@ -89,3 +89,32 @@ def test_translation_prompt_mcf(): assert doc.unconditioned_query == "Answer:" assert doc.choices == [" A", " B"] assert doc.gold_index == [0] + + +def test_translation_prompt_cf_formatting(): + """ + Tests that translation prompt function works correctly for CF formulation with formatting. + """ + test_input = { + "source_text": "How are you?", + "target_text": ["你好吗?"], + } + + prompt_fn = get_translation_prompt_function( + source_language=Language.ENGLISH, + target_language=Language.CHINESE, + adapter=lambda x: { + "source_text": x["source_text"], + "target_text": x["target_text"], + "gold_idx": 0, + }, + formulation=CFFormulation(), + ) + + doc = prompt_fn(test_input, "test_task") + assert doc is not None + + assert doc.query == "EN: How are you? ZH:" + assert doc.unconditioned_query == "" + assert doc.choices == [" 你好吗?"] + assert doc.gold_index == [0]