Description
Authors: @KeijiBranshi @rosbo @mrisdal @neshdev at.bflynn
Summary
This RFC proposes extending torchtune to load pre-trained and fine-tuned model weights directly from the Kaggle Model Hub. The integration aims to make more models accessible within torchtune and, by streamlining the experience for Kaggle users, to drive adoption of both PyTorch/torchtune and the Kaggle Model Hub.
Motivation
This proposal aligns with PyTorch's objective of integrating with partner platforms to increase torchtune adoption (KR 3.2). By adding support for Kaggle, we can:
- Increase Model Accessibility: Provide torchtune users with a wider selection of pre-trained models and community-shared fine-tuned weights.
- Foster Community Engagement: Encourage collaboration between the PyTorch and Kaggle communities, leading to more contributions and a more diverse model ecosystem.
Other Potential Benefits
- Streamline Kaggle Competition Workflow: Enable seamless torchtune model loading within Kaggle competition notebooks, which have internet access restrictions. This eliminates the need for workarounds currently required to use PyTorch models in competitions.
- Deeper Kaggle Notebook Integration: Using kagglehub allows for better integration with Kaggle Notebooks, enabling features like automatic model detection and UI enhancements within the notebook environment.
Prior Art
Similar functionality for loading models from various hubs exists in other deep learning libraries. Keras, for example, provides a unified mechanism for loading models from both Hugging Face and Kaggle using URI schemes (see the Keras documentation).
Proposed Implementation
We propose extending torchtune's model loading mechanism to recognize and handle Kaggle Model Hub URIs. This will involve the following:
- URI Scheme Recognition: torchtune can be updated to recognize model URIs that use the kaggle:// scheme. While Hugging Face will remain the default source for models, we could also support explicit Hugging Face URIs using the hf:// scheme for increased clarity.
- Kaggle Hub Integration: Leverage the kagglehub Python library to download model weights from, and upload them to, the Kaggle Model Hub.
Using the above, we would modify torchtune's model loading logic to:
- Detect the URI scheme (kaggle:// or hf://).
- Use kagglehub to download weights from Kaggle when a kaggle:// URI is provided.
- Maintain the existing Hugging Face integration for models without a URI scheme or those explicitly using hf://.
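A minimal sketch of the dispatch logic described above, assuming kagglehub.model_download(handle), which downloads into kagglehub's cache and returns the local path, and huggingface_hub.snapshot_download for the Hugging Face path. ModelSource, parse_model_uri, and download_model are hypothetical names, not part of torchtune's current API.

```python
from dataclasses import dataclass
from typing import Optional

KAGGLE_SCHEME = "kaggle://"
HF_SCHEME = "hf://"


@dataclass(frozen=True)
class ModelSource:
    """Hypothetical container: which hub to use and the hub-specific handle."""

    hub: str  # "kaggle" or "huggingface"
    handle: str  # e.g. "metaresearch/llama-3.2/pyTorch/3b"


def parse_model_uri(uri: str) -> ModelSource:
    """Map a model URI to a hub; URIs without a scheme default to Hugging Face."""
    if uri.startswith(KAGGLE_SCHEME):
        return ModelSource("kaggle", uri[len(KAGGLE_SCHEME):])
    if uri.startswith(HF_SCHEME):
        return ModelSource("huggingface", uri[len(HF_SCHEME):])
    if "://" in uri:
        # Reject unknown schemes instead of silently treating them as HF repo IDs.
        raise ValueError(f"Unsupported model URI scheme: {uri!r}")
    return ModelSource("huggingface", uri)


def download_model(uri: str, output_dir: Optional[str] = None) -> str:
    """Download weights for `uri` and return the local directory (sketch only)."""
    source = parse_model_uri(uri)
    if not source.handle:
        raise ValueError(f"Missing model handle in URI: {uri!r}")
    if source.hub == "kaggle":
        # Imported lazily so kagglehub is only required for kaggle:// URIs.
        import kagglehub

        # Downloads into kagglehub's cache; output_dir handling is still an open question.
        return kagglehub.model_download(source.handle)
    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id=source.handle, local_dir=output_dir)
```

For example, parse_model_uri("kaggle://metaresearch/llama-3.2/pyTorch/3b") returns ModelSource(hub='kaggle', handle='metaresearch/llama-3.2/pyTorch/3b'), while a bare Hugging Face repo ID such as "meta-llama/Llama-3.2-3B" still resolves to Hugging Face, preserving backward compatibility. Rejecting unknown schemes keeps a typo like "kagle://" from being sent to Hugging Face as a repo ID.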
Example Usage:
Users will be able to download a model from Kaggle using a command like:
tune download kaggle://metaresearch/llama-3.2/pyTorch/3b \
--output-dir /tmp/llama-3.2-3b \
--kaggle-username <KAGGLE_USERNAME> \
--kaggle-api-key <KAGGLE_API_KEY>
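The proposed --kaggle-username and --kaggle-api-key flags could be wired to kagglehub through the KAGGLE_USERNAME and KAGGLE_KEY environment variables, which kagglehub reads for authentication. The sketch below is a hypothetical argparse parser, not torchtune's actual tune download implementation; build_download_parser and export_kaggle_credentials are illustrative names.

```python
import argparse
import os
from typing import MutableMapping, Optional


def build_download_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags in the example command."""
    parser = argparse.ArgumentParser(prog="tune download")
    parser.add_argument("model_uri", help="e.g. kaggle://<owner>/<model>/<framework>/<variation>")
    parser.add_argument("--output-dir", default=None)
    parser.add_argument("--kaggle-username", default=None)
    parser.add_argument("--kaggle-api-key", default=None)
    return parser


def export_kaggle_credentials(
    args: argparse.Namespace, env: Optional[MutableMapping[str, str]] = None
) -> None:
    """Expose CLI credentials to kagglehub via KAGGLE_USERNAME / KAGGLE_KEY."""
    # Default to the real process environment; tests can pass a plain dict.
    env = os.environ if env is None else env
    # Require both flags or neither, to avoid half-configured credentials.
    if bool(args.kaggle_username) != bool(args.kaggle_api_key):
        raise ValueError("--kaggle-username and --kaggle-api-key must be given together")
    if args.kaggle_username:
        env["KAGGLE_USERNAME"] = args.kaggle_username
        env["KAGGLE_KEY"] = args.kaggle_api_key
```

Treating the flags as optional and falling back to existing environment variables (or kagglehub's own credentials file) when they are omitted would also keep API keys out of shell history.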
Considerations
- Backward Compatibility: This change should not affect existing functionality for Hugging Face models.
- Dependencies: torchtune will need to add kagglehub as a dependency. Would introducing fsspec for more general URI scheme handling be a desirable enhancement, even if it adds complexity?
- --output-dir Argument: Since kagglehub downloads into a default cache folder, should the --output-dir argument be optional or required for Kaggle models? What are the preferred behaviors and implications of each approach?
- Documentation: We are willing to contribute to the torchtune documentation with instructions and examples for using Kaggle Model Hub URIs, and would greatly appreciate guidance on the documentation update process.
- Testing: Develop comprehensive tests to verify that Kaggle model loading works correctly and remains compatible with existing features.
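One possible answer to the --output-dir question, sketched under the assumption that --output-dir becomes optional for Kaggle models: without it, weights are read straight from kagglehub's cache (avoiding a second copy on disk); with it, the cached files are copied to the requested directory. materialize_download is a hypothetical helper, and cache_path stands in for the directory returned by kagglehub.model_download.

```python
import shutil
from pathlib import Path
from typing import Optional


def materialize_download(cache_path: str, output_dir: Optional[str] = None) -> str:
    """Return the directory torchtune should read weights from (hypothetical helper).

    With no output_dir, reuse kagglehub's cache directly (no extra disk usage).
    With an output_dir, copy the cached files there so later steps can find
    them at a user-chosen location.
    """
    if output_dir is None:
        return cache_path
    destination = Path(output_dir)
    # dirs_exist_ok=True (Python 3.8+) lets repeated downloads refresh an existing directory.
    shutil.copytree(cache_path, destination, dirs_exist_ok=True)
    return str(destination)
```

The trade-off is disk usage versus a predictable location: reusing the cache avoids duplicating multi-gigabyte checkpoints, while copying keeps the output directory self-contained.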
Call for Feedback
We’d love feedback from the PyTorch community on this proposal. Please share your thoughts, suggestions, and any potential concerns you may have.
Happy modeling,
The Kaggle Team