Migrate train_on_inputs to sft-specific params #297

Open · wants to merge 3 commits into base: main
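
Summary of the change: train_on_inputs is no longer a top-level field of the fine-tuning request; it becomes an SFT-specific parameter carried by TrainingMethodSFT, and it is rejected for DPO jobs. A minimal usage sketch of the client-level call (hedged: the model name and file ID are placeholders, a valid API key is assumed, and the resource is assumed to be exposed as client.fine_tuning as in the published SDK; the keyword arguments themselves come from the diff below):

from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

job = client.fine_tuning.create(
    model="example/base-model",   # placeholder model name
    training_file="file-1234",    # placeholder uploaded-file ID
    training_method="sft",
    train_on_inputs="auto",       # now forwarded to TrainingMethodSFT(...)
)

# Before this PR the value was serialized as a top-level train_on_inputs field;
# after it, it travels inside training_method, e.g. {"method": "sft", "train_on_inputs": "auto"}.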
20 changes: 10 additions & 10 deletions src/together/cli/api/finetune.py
Review comment from the contributor (author) on this file: formatting only

@@ -1,10 +1,10 @@
 from __future__ import annotations

 import json
+import re
 from datetime import datetime, timezone
 from textwrap import wrap
 from typing import Any, Literal
-import re

 import click
 from click.core import ParameterSource  # type: ignore[attr-defined]
@@ -13,17 +13,17 @@

 from together import Together
 from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
+from together.types.finetune import (
+    DownloadCheckpointType,
+    FinetuneEventType,
+    FinetuneTrainingLimits,
+)
 from together.utils import (
     finetune_price_to_dollars,
+    format_timestamp,
     log_warn,
     log_warn_once,
     parse_timestamp,
-    format_timestamp,
 )
-from together.types.finetune import (
-    DownloadCheckpointType,
-    FinetuneTrainingLimits,
-    FinetuneEventType,
-)


@@ -340,9 +340,9 @@ def list(ctx: click.Context) -> None:
                 "Model Output Name": "\n".join(wrap(i.output_name or "", width=30)),
                 "Status": i.status,
                 "Created At": i.created_at,
-                "Price": f"""${finetune_price_to_dollars(
-                    float(str(i.total_price))
-                )}""",  # convert to string for mypy typing
+                "Price": f"""${
+                    finetune_price_to_dollars(float(str(i.total_price)))
+                }""",  # convert to string for mypy typing
             }
         )
     table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True)
29 changes: 21 additions & 8 deletions src/together/resources/finetune.py
@@ -77,7 +77,7 @@ def create_finetune_request(
     wandb_base_url: str | None = None,
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
-    train_on_inputs: bool | Literal["auto"] = "auto",
+    train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
@@ -162,6 +162,18 @@
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )

+    if train_on_inputs is not None and training_method != "sft":
+        raise ValueError("train_on_inputs is only supported for SFT training")
+
+    if train_on_inputs is None and training_method == "sft":
+        log_warn_once(
+            "train_on_inputs is not set for SFT training, it will be set to 'auto' automatically"
+        )
+        train_on_inputs = "auto"
+
+    if dpo_beta is not None and training_method != "dpo":
+        raise ValueError("dpo_beta is only supported for DPO training")
+
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
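
In effect, the checks added above give train_on_inputs and dpo_beta method-specific semantics. A hedged sketch of the resulting behavior (make_request is a stand-in for create_finetune_request with the required model, training-file, and limits arguments already bound, e.g. via functools.partial in a test fixture):

import pytest


def check_param_validation(make_request):
    # SFT with train_on_inputs left unset: warns once, then behaves as "auto".
    request = make_request(training_method="sft", train_on_inputs=None)
    assert request.training_method.train_on_inputs == "auto"

    # SFT with an explicit bool: forwarded as-is to TrainingMethodSFT.
    request = make_request(training_method="sft", train_on_inputs=False)
    assert request.training_method.train_on_inputs is False

    # DPO with train_on_inputs set: rejected.
    with pytest.raises(ValueError, match="only supported for SFT training"):
        make_request(training_method="dpo", train_on_inputs=True)

    # Non-DPO job with dpo_beta set: rejected.
    with pytest.raises(ValueError, match="only supported for DPO training"):
        make_request(training_method="sft", dpo_beta=0.1)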
@@ -179,7 +191,9 @@
             lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )

-    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT(
+        train_on_inputs=train_on_inputs
+    )
     if training_method == "dpo":
         training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)

@@ -202,7 +216,6 @@
         wandb_base_url=wandb_base_url,
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
-        train_on_inputs=train_on_inputs,
         training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
@@ -307,7 +320,7 @@ def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -352,12 +365,12 @@
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
-            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+            train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data.
                 "auto" will automatically determine whether to mask the inputs based on the data format.
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
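
For reference, the three dataset shapes named in the docstring above behave as follows under "auto" (toy records, hedged; they are not taken from the repo):

# General format (single "text" field): inputs are NOT masked under "auto".
general_example = {"text": "Q: What is 2+2?\nA: 4"}

# Conversational format ("messages" list): user turns ARE masked under "auto".
conversational_example = {
    "messages": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "4"},
    ]
}

# Instruction format ("prompt"/"completion" fields): the prompt IS masked under "auto".
instruction_example = {"prompt": "What is 2+2?", "completion": "4"}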
@@ -695,7 +708,7 @@ async def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -745,7 +758,7 @@
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
5 changes: 2 additions & 3 deletions src/together/types/finetune.py
@@ -3,7 +3,7 @@
 from enum import Enum
 from typing import List, Literal

-from pydantic import StrictBool, Field, field_validator
+from pydantic import Field, StrictBool, field_validator

 from together.types.abstract import BaseModel
 from together.types.common import (
@@ -149,6 +149,7 @@ class TrainingMethodSFT(TrainingMethod):
     """

     method: Literal["sft"] = "sft"
+    train_on_inputs: StrictBool | Literal["auto"] = "auto"


 class TrainingMethodDPO(TrainingMethod):
@@ -201,8 +202,6 @@ class FinetuneRequest(BaseModel):
     wandb_name: str | None = None
     # training type
     training_type: FullTrainingType | LoRATrainingType | None = None
-    # train on inputs
-    train_on_inputs: StrictBool | Literal["auto"] = "auto"
     # training method
     training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
         default_factory=TrainingMethodSFT
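
At the type level the field simply moves from FinetuneRequest into TrainingMethodSFT, so it is serialized under training_method rather than at the top level. A small sketch (hedged: assumes these pydantic models expose the usual v2 model_dump()):

from together.types.finetune import TrainingMethodSFT

method = TrainingMethodSFT(train_on_inputs=False)
print(method.model_dump())
# Expected shape (roughly): {"method": "sft", "train_on_inputs": False}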
29 changes: 29 additions & 0 deletions tests/unit/test_finetune_resources.py
@@ -247,3 +247,32 @@ def test_bad_training_method():
         training_file=_TRAINING_FILE,
         training_method="NON_SFT",
     )
+
+
+@pytest.mark.parametrize("train_on_inputs", [True, False, "auto", None])
+def test_train_on_inputs_for_sft(train_on_inputs):
+    request = create_finetune_request(
+        model_limits=_MODEL_LIMITS,
+        model=_MODEL_NAME,
+        training_file=_TRAINING_FILE,
+        training_method="sft",
+        train_on_inputs=train_on_inputs,
+    )
+    assert request.training_method.method == "sft"
+    if isinstance(train_on_inputs, bool):
+        assert request.training_method.train_on_inputs is train_on_inputs
+    else:
+        assert request.training_method.train_on_inputs == "auto"
+
+
+def test_train_on_inputs_not_supported_for_dpo():
+    with pytest.raises(
+        ValueError, match="train_on_inputs is only supported for SFT training"
+    ):
+        _ = create_finetune_request(
+            model_limits=_MODEL_LIMITS,
+            model=_MODEL_NAME,
+            training_file=_TRAINING_FILE,
+            training_method="dpo",
+            train_on_inputs=True,
+        )