
Surface docstrings in interface functions #18

@RobotSail

Description

In the current version of training hub, we have detailed docstrings for each algorithm on its Algorithm class, but these are not surfaced through the interface functions that users actually call.

E.g.:

class OSFTAlgorithm(Algorithm):
    """
    Implements the Orthogonal Subspace Fine-Tuning (OSFT) algorithm,
    based on Nayak et al. (2025), arXiv:2504.07097

    This algorithm allows for continual training of pre-trained or instruction-tuned
    models without the need for a supplementary dataset to preserve the distribution
    of the model's original training data.
    """

    def __init__(self, backend: Backend, **kwargs) -> None:
        self.backend = backend
        self.kwargs = kwargs

    def train(
        self,
        model_path: str,
        ...
        **kwargs,
    ) -> Any:
        """
        This algorithm implements Continual Training using the OSFT algorithm
        with the mini-trainer backend.

        **Note:** The OSFT algorithm does not reduce the memory requirement compared
        to SFT, but it significantly reduces the data requirement for customizing an
        instruction-tuned model.

        **Note:**
            While all values of `unfreeze_rank_ratio` are valid, in practice you will seldom
            need values greater than 0.5 for general continual-learning regimes.
        
        Arguments:
            model_path (str): Local path or HuggingFace model ID to be used for fine-tuning.
            data_path (str):
                Path to the training data. When `use_processed_dataset` is True,
                this is the path to the processed dataset. When `use_processed_dataset` is False,
                this is the path to the original dataset.
            unfreeze_rank_ratio (float):
                Controls the fraction of each matrix that is unfrozen during OSFT.
                Valid values are between 0.0 and 1.0.
            effective_batch_size (int): Effective batch size for training.
            max_tokens_per_gpu (int):
                The maximum number of tokens placed on a single GPU for training.
        ...
        """

Compare this to the interface function that users actually interact with, where none of this documentation is exposed:

def osft(
    model_path: str,
    data_path: str,
    unfreeze_rank_ratio: float,
    effective_batch_size: int,
    max_tokens_per_gpu: int,
    max_seq_len: int,
    learning_rate: float,
    ckpt_output_dir: str,
    data_output_dir: str | None = None,
    backend: str = "mini-trainer",
    # Optional parameters
    target_patterns: list[str] | None = None,
    seed: int | None = None,
    use_liger: bool | None = None,
    use_processed_dataset: bool | None = None,
    unmask_messages: bool | None = None,
    lr_scheduler: str | None = None,
    warmup_steps: int | None = None,
    lr_scheduler_kwargs: dict[str, str] | None = None,
    checkpoint_at_epoch: bool | None = None,
    save_final_checkpoint: bool | None = None,
    num_epochs: int | None = None,
    # Torchrun parameters for multi-node support
    nproc_per_node: Literal['auto', 'gpu'] | int | None = None,
    nnodes: int | None = None,
    node_rank: int | None = None,
    rdzv_id: str | int | None = None,
    rdzv_endpoint: str | None = None,
    master_port: str | None = None,
    master_addr: str | None = None,
    **kwargs
) -> Any:
    from . import create_algorithm
    
    algorithm = create_algorithm('osft', backend)
    return algorithm.train(
        model_path=model_path,
        data_path=data_path,
        ckpt_output_dir=ckpt_output_dir,
        data_output_dir=data_output_dir,
        unfreeze_rank_ratio=unfreeze_rank_ratio,
        effective_batch_size=effective_batch_size,
        max_tokens_per_gpu=max_tokens_per_gpu,
        max_seq_len=max_seq_len,
        learning_rate=learning_rate,
        target_patterns=target_patterns,
        seed=seed,
        use_liger=use_liger,
        use_processed_dataset=use_processed_dataset,
        unmask_messages=unmask_messages,
        lr_scheduler=lr_scheduler,
        warmup_steps=warmup_steps,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        checkpoint_at_epoch=checkpoint_at_epoch,
        save_final_checkpoint=save_final_checkpoint,
        num_epochs=num_epochs,
        nproc_per_node=nproc_per_node,
        nnodes=nnodes,
        node_rank=node_rank,
        rdzv_id=rdzv_id,
        rdzv_endpoint=rdzv_endpoint,
        master_port=master_port,
        master_addr=master_addr,
        **kwargs
    )
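Since the interface function has no docstring of its own, help(osft) currently shows only the bare signature. One low-effort way to surface the documentation (a minimal sketch, not a final design; the with_docstring_from helper below is hypothetical and not part of training hub) is to copy the algorithm method's docstring onto the interface function at definition time:

from typing import Any, Callable

def with_docstring_from(method: Callable[..., Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    # Hypothetical helper: copies the algorithm method's docstring onto the
    # decorated interface function so help() and IDE tooltips surface it.
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        func.__doc__ = method.__doc__
        return func
    return decorator

@with_docstring_from(OSFTAlgorithm.train)
def osft(model_path: str, data_path: str, backend: str = "mini-trainer", **kwargs) -> Any:
    # Signature abbreviated for brevity; the real function keeps the full
    # parameter list shown above.
    from . import create_algorithm
    algorithm = create_algorithm('osft', backend)
    return algorithm.train(model_path=model_path, data_path=data_path, **kwargs)

With this in place, help(osft) prints the full OSFT docstring from OSFTAlgorithm.train without maintaining two copies. functools.wraps is deliberately avoided here, since it would also copy __name__ and __qualname__, which should remain osft.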
