diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py
index 467c829..9beadcc 100644
--- a/src/together/cli/api/finetune.py
+++ b/src/together/cli/api/finetune.py
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
 import json
+import re
 from datetime import datetime, timezone
 from textwrap import wrap
 from typing import Any, Literal
-import re
 
 import click
 from click.core import ParameterSource  # type: ignore[attr-defined]
@@ -13,17 +13,17 @@
 
 from together import Together
 from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
+from together.types.finetune import (
+    DownloadCheckpointType,
+    FinetuneEventType,
+    FinetuneTrainingLimits,
+)
 from together.utils import (
     finetune_price_to_dollars,
+    format_timestamp,
     log_warn,
     log_warn_once,
     parse_timestamp,
-    format_timestamp,
-)
-from together.types.finetune import (
-    DownloadCheckpointType,
-    FinetuneTrainingLimits,
-    FinetuneEventType,
 )
 
 
@@ -340,9 +340,9 @@ def list(ctx: click.Context) -> None:
                 "Model Output Name": "\n".join(wrap(i.output_name or "", width=30)),
                 "Status": i.status,
                 "Created At": i.created_at,
-                "Price": f"""${finetune_price_to_dollars(
-                    float(str(i.total_price))
-                )}""",  # convert to string for mypy typing
+                "Price": f"""${
+                    finetune_price_to_dollars(float(str(i.total_price)))
+                }""",  # convert to string for mypy typing
             }
         )
     table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True)
diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py
index 82a7021..430cd8d 100644
--- a/src/together/resources/finetune.py
+++ b/src/together/resources/finetune.py
@@ -77,7 +77,7 @@ def create_finetune_request(
     wandb_base_url: str | None = None,
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
-    train_on_inputs: bool | Literal["auto"] = "auto",
+    train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
@@ -162,6 +162,18 @@ def create_finetune_request(
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )
 
+    if train_on_inputs is not None and training_method != "sft":
+        raise ValueError("train_on_inputs is only supported for SFT training")
+
+    if train_on_inputs is None and training_method == "sft":
+        log_warn_once(
+            "train_on_inputs is not set for SFT training, it will be set to 'auto' automatically"
+        )
+        train_on_inputs = "auto"
+
+    if dpo_beta is not None and training_method != "dpo":
+        raise ValueError("dpo_beta is only supported for DPO training")
+
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
@@ -179,7 +191,9 @@
             lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )
 
-    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT(
+        train_on_inputs=train_on_inputs
+    )
 
     if training_method == "dpo":
         training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
@@ -202,7 +216,6 @@
         wandb_base_url=wandb_base_url,
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
-        train_on_inputs=train_on_inputs,
         training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
@@ -307,7 +320,7 @@ def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
        training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -352,12 +365,12 @@ def create(
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
-            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+            train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data.
                 "auto" will automatically determine whether to mask the inputs based on the data format.
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields (Instruction format),
                 inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
@@ -695,7 +708,7 @@
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -745,7 +758,7 @@ async def create(
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields (Instruction format),
                 inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
diff --git a/src/together/types/finetune.py b/src/together/types/finetune.py
index 6325ce5..195ee35 100644
--- a/src/together/types/finetune.py
+++ b/src/together/types/finetune.py
@@ -3,7 +3,7 @@
 from enum import Enum
 from typing import List, Literal
 
-from pydantic import StrictBool, Field, field_validator
+from pydantic import Field, StrictBool, field_validator
 
 from together.types.abstract import BaseModel
 from together.types.common import (
@@ -149,6 +149,7 @@ class TrainingMethodSFT(TrainingMethod):
     """
 
     method: Literal["sft"] = "sft"
+    train_on_inputs: StrictBool | Literal["auto"] = "auto"
 
 
 class TrainingMethodDPO(TrainingMethod):
@@ -201,8 +202,6 @@ class FinetuneRequest(BaseModel):
     wandb_name: str | None = None
     # training type
     training_type: FullTrainingType | LoRATrainingType | None = None
-    # train on inputs
-    train_on_inputs: StrictBool | Literal["auto"] = "auto"
     # training method
     training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
         default_factory=TrainingMethodSFT
diff --git a/tests/unit/test_finetune_resources.py b/tests/unit/test_finetune_resources.py
index f7acdbc..6cbc634 100644
--- a/tests/unit/test_finetune_resources.py
+++ b/tests/unit/test_finetune_resources.py
@@ -247,3 +247,32 @@ def test_bad_training_method():
             training_file=_TRAINING_FILE,
             training_method="NON_SFT",
         )
+
+
+@pytest.mark.parametrize("train_on_inputs", [True, False, "auto", None])
+def test_train_on_inputs_for_sft(train_on_inputs):
+    request = create_finetune_request(
+        model_limits=_MODEL_LIMITS,
+        model=_MODEL_NAME,
+        training_file=_TRAINING_FILE,
+        training_method="sft",
+        train_on_inputs=train_on_inputs,
+    )
+    assert request.training_method.method == "sft"
+    if isinstance(train_on_inputs, bool):
+        assert request.training_method.train_on_inputs is train_on_inputs
+    else:
+        assert request.training_method.train_on_inputs == "auto"
+
+
+def test_train_on_inputs_not_supported_for_dpo():
+    with pytest.raises(
+        ValueError, match="train_on_inputs is only supported for SFT training"
+    ):
+        _ = create_finetune_request(
+            model_limits=_MODEL_LIMITS,
+            model=_MODEL_NAME,
+            training_file=_TRAINING_FILE,
+            training_method="dpo",
+            train_on_inputs=True,
+        )
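Usage sketch (reviewer note, not part of the patch): the snippet below walks through the two validation paths this change adds to create_finetune_request. It reuses the fixture names from tests/unit/test_finetune_resources.py (_MODEL_LIMITS, _MODEL_NAME, _TRAINING_FILE are test fixtures, not real models or files), and the import path is assumed to mirror the one the unit tests use.

# Illustrative only -- mirrors the unit tests above; fixture names are placeholders.
from together.resources.finetune import create_finetune_request

# SFT without an explicit train_on_inputs: log_warn_once fires and the value
# is coerced to "auto" before TrainingMethodSFT is constructed.
request = create_finetune_request(
    model_limits=_MODEL_LIMITS,
    model=_MODEL_NAME,
    training_file=_TRAINING_FILE,
    training_method="sft",
)
assert request.training_method.train_on_inputs == "auto"

# DPO with train_on_inputs set is rejected, since the flag now lives on TrainingMethodSFT.
try:
    create_finetune_request(
        model_limits=_MODEL_LIMITS,
        model=_MODEL_NAME,
        training_file=_TRAINING_FILE,
        training_method="dpo",
        train_on_inputs=True,
    )
except ValueError as err:
    print(err)  # train_on_inputs is only supported for SFT training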