DiT Pipeline #1806
Conversation
Amazing work @kashif, thanks a lot for quickly integrating this! I left a few comments, mostly nits. I'm wondering whether it's possible to leverage the existing Transformer2DModel for this, or whether we should keep DiT as a new model class. Given the performance of the model, I think there will be many more models like this, so I'm leaning toward the latter. Curious to hear what you think @patrickvonplaten @pcuenca @anton-l @williamberman @yiyixuxu
A few more things need to be addressed before we can merge:
- Check whether it works with the existing scheduler or needs any changes.
- Add tests.
- Add docs.
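On the first checklist item: the commit log shows the PR switched to DDIMScheduler, so the pipeline would rely on the deterministic DDIM update. Below is a minimal, self-contained sketch of that update with eta = 0; `ddim_step` and its scalar arguments are illustrative names for this comment, not diffusers API.

```python
import math

def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0), on scalars for illustration.

    x_t: current noisy sample, eps: the model's noise prediction,
    alpha_bar_*: cumulative alpha products at the current and previous timesteps.
    """
    # Recover the predicted clean sample implied by the noise prediction.
    pred_x0 = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps) / math.sqrt(alpha_bar_t)
    # Re-noise it to the previous (less noisy) timestep.
    return math.sqrt(alpha_bar_prev) * pred_x0 + math.sqrt(1.0 - alpha_bar_prev) * eps
```

With `alpha_bar_prev = 1.0` the step returns the predicted clean sample directly, which is the behavior expected at the final denoising step.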
src/diffusers/models/dit.py
Outdated
```python
class DiT(ModelMixin, ConfigMixin):
    """
    Diffusion model with a Transformer backbone.
    """

    @register_to_config
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
```
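As a sanity check on this config: with non-overlapping patches, the transformer's sequence length follows directly from `input_size` and `patch_size`. A small hypothetical helper (not part of the PR) makes the arithmetic explicit:

```python
def patch_grid(input_size=32, patch_size=2, in_channels=4):
    """Sequence length and raw per-token size when a latent of shape
    (in_channels, input_size, input_size) is split into non-overlapping
    patch_size x patch_size patches. Illustrative helper, not diffusers API.
    """
    assert input_size % patch_size == 0, "latent must tile evenly into patches"
    num_patches = (input_size // patch_size) ** 2      # tokens seen by the transformer
    token_dim = in_channels * patch_size * patch_size  # raw values per patch, pre-projection
    return num_patches, token_dim
```

With the defaults above (32x32 latent, patch size 2, 4 channels) this gives 256 tokens of 16 raw values each, which are then projected to `hidden_size`.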
Very cool!
```python
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


class DiTPipeline(DiffusionPipeline):
```
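DiT is class-conditional, and the commit log mentions fixing the code path when guidance is off. The usual classifier-free guidance blend can be sketched as follows; this is an illustrative helper operating on plain lists, not the pipeline's actual tensor code:

```python
def apply_guidance(eps_uncond, eps_cond, guidance_scale):
    """Blend unconditional and class-conditional noise predictions.

    With guidance_scale == 1.0 this reduces to the conditional prediction,
    matching the "guidance off" path. Illustrative sketch, not diffusers API.
    """
    return [u + guidance_scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]
```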
Very clean!
Failing tests seem to be flaky; I can't reproduce them locally. All slow pipeline tests were passing locally, so let's merge it ❤️ @kashif I made some changes. Official model weights are now:
Also added the ImageNet classes directly to the config. Great job on the PR!
I should be the one to thank you!!
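Since the ImageNet label table now lives in the pipeline config, callers can map human-readable class names to the integer ids the model was trained on. A toy sketch of that lookup, with a hypothetical helper name and a two-entry table standing in for the full 1000-class mapping:

```python
def label_ids(queries, id2label):
    """Map class names to the integer ids the model expects.

    id2label: {int_id: class_name}, as could be stored in a pipeline config.
    Illustrative sketch; names and error handling are assumptions.
    """
    name2id = {name.lower(): i for i, name in id2label.items()}
    missing = [q for q in queries if q.lower() not in name2id]
    if missing:
        raise ValueError(f"unknown labels: {missing}")
    return [name2id[q.lower()] for q in queries]
```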
Squashed commit messages:
* added dit model
* import
* initial pipeline
* initial convert script
* initial pipeline
* make style
* raise valueerror
* single function
* rename classes
* use DDIMScheduler
* timesteps embedder
* samples to cpu
* fix var names
* fix numpy type
* use timesteps class for proj
* fix typo
* fix arg name
* flip_sin_to_cos and better var names
* fix C shape cal
* make style
* remove unused imports
* cleanup
* add back patch_size
* initial dit doc
* typo
* Update docs/source/api/pipelines/dit.mdx (Co-authored-by: Suraj Patil <[email protected]>)
* added copyright license headers
* added example usage and toc
* fix variable names asserts
* remove comment
* added docs
* fix typo
* upstream changes
* set proper device for drop_ids
* added initial dit pipeline test
* update docs
* fix imports
* make fix-copies
* isort
* fix imports
* get rid of more magic numbers
* fix code when guidance is off
* remove block_kwargs
* cleanup script
* removed to_2tuple
* use FeedForward class instead of another MLP
* style
* work on merging DiTBlock with BasicTransformerBlock
* added missing final_dropout and args to BasicTransformerBlock
* use norm from block
* fix arg
* remove unused arg
* fix call to class_embedder
* use timesteps
* make style
* attn_output gets multiplied
* removed commented code
* use Transformer2D
* use self.is_input_patches
* fix flags
* fixed conversion to use Transformer2DModel
* fixes for pipeline
* remove dit.py
* fix timesteps device
* use randn_tensor and fix fp16 inf.
* timesteps_emb already the right dtype
* fix dit test class
* fix test and style
* fix norm2 usage in vq-diffusion
* added author names to pipeline and ImageNet labels link
* fix tests
* use norm_type as string
* rename dit to transformer
* fix name
* fix test
* set norm_type = "layer" by default
* fix tests
* do not skip common tests
* Update src/diffusers/models/attention.py (Co-authored-by: Suraj Patil <[email protected]>)
* revert AdaLayerNorm API
* fix norm_type name
* make sure all components are in eval mode
* revert norm2 API
* compact
* finish deprecation
* add slow tests
* remove @
* refactor some stuff
* upload
* Update src/diffusers/pipelines/dit/pipeline_dit.py
* finish more
* finish docs
* improve docs
* finish docs

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: William Berman <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
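Several commits above touch the timestep embedder and the flip_sin_to_cos flag. A minimal pure-Python sketch of what such a sinusoidal embedding computes (illustrative only; the real implementation operates on tensors, and `dim` is assumed even here):

```python
import math

def timestep_embedding(t, dim, flip_sin_to_cos=True, max_period=10000.0):
    """Sinusoidal embedding of a scalar timestep t into a dim-sized vector,
    in the spirit of the Timesteps class mentioned in the commit log.
    flip_sin_to_cos swaps the order of the two halves of the output.
    """
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to ~1/max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    sin = [math.sin(t * f) for f in freqs]
    cos = [math.cos(t * f) for f in freqs]
    return cos + sin if flip_sin_to_cos else sin + cos
```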
DiT pipeline: https://github.com/facebookresearch/DiT