Skip to content

Conversation

@Bissmella
Copy link

@Bissmella Bissmella commented Nov 21, 2025

What does this PR do?

This is a draft implementation of the Unified SP attention approach.

  • Implements _all_to_all_dim_exchange with custom scatter and gather indices
  • Implements TemplatedUnifiedAttention

Core implementation complete, needs:

  • Testing
  • Validation

@sayakpaul
Copy link
Member

It would be nice to get a testing script so that we can quickly check things.

@KarthikSundar2002
Copy link

I added a basic test script with a simple forward and backward op. Is it better to have a test script with flash_attention_backward and forward??

@Bissmella Bissmella force-pushed the unified-SP-attention branch from a244006 to 9dee8f8 Compare November 24, 2025 10:54
@Bissmella Bissmella marked this pull request as ready for review November 24, 2025 10:56
@Bissmella Bissmella force-pushed the unified-SP-attention branch from 9dee8f8 to 9ebcff5 Compare November 24, 2025 23:00
@sayakpaul
Copy link
Member

Let us know if this is ready for a review!

@Bissmella
Copy link
Author

Yep, ready for review! I tested it with a 4-process setup (2×2 mesh, on cpu) and everything checks out, shapes look good and gradients flow correctly. Looking forward for feedback and happy to address any issues.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for getting started on this!

grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))

return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the change here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The forward function has 12 inputs (without ctx (context)) but the backward is giving 11 output. Normally the two should be the same. I was getting an error like this while testing: "RuntimeError: function backward returned an incorrect number of gradients (expected 12, got 11)".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a reproducer?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it can be reproduced in this notebook (it happens only during the backward): https://colab.research.google.com/drive/1Ac4nVSVjKHrPpcSRlX0E3NzY0mDEmkMx?usp=sharing

@sayakpaul
Copy link
Member

I am trying with the following code:

import torch
from torch import distributed as dist
from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig

def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device

device = setup_distributed()
    
# Need to add parallel support for this.
# pipeline.transformer.set_attention_backend("flash_hub")
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  torch_dtype=torch.bfloat16,
).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)
)

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""

generator = torch.Generator().manual_seed(42)
image = pipeline(prompt, guidance_scale=3.5, num_inference_steps=50, generator=generator).images[0]

if dist.get_rank() == 0:
    image.save("output_ua.png")
if dist.is_initialized():
    dist.destroy_process_group()

Run the above with torchrun --nproc-per-node 4 check_unified_attention.py.

And it leads to:
https://pastebin.com/A7KkvXH2

@Bissmella
Copy link
Author

I am trying with the following code:

import torch
from torch import distributed as dist
from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig

def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device

device = setup_distributed()
    
# Need to add parallel support for this.
# pipeline.transformer.set_attention_backend("flash_hub")
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  torch_dtype=torch.bfloat16,
).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)
)

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""

generator = torch.Generator().manual_seed(42)
image = pipeline(prompt, guidance_scale=3.5, num_inference_steps=50, generator=generator).images[0]

if dist.get_rank() == 0:
    image.save("output_ua.png")
if dist.is_initialized():
    dist.destroy_process_group()

Run the above with torchrun --nproc-per-node 4 check_unified_attention.py.

And it leads to: https://pastebin.com/A7KkvXH2

I spent quite some time investigating this issue but wasn’t able to find the cause. I tried to reproduce it, but the model is too large for the small GPUs I can use, and native_cudnn attention also does not work on simpler GPUs.
Does this error occur with TemplatedRingAttention alone? It seems the problem arises with out, prev_out, lse, and prev_lse in the second iteration of the for loop, but none of these tensors originates directly from TemplatedUnifiedAttention. I will continue digging more into this and see if I can identify the issue.

@Bissmella
Copy link
Author

Bissmella commented Dec 8, 2025

I am trying with the following code:

import torch
from torch import distributed as dist
from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig

def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device

device = setup_distributed()
    
# Need to add parallel support for this.
# pipeline.transformer.set_attention_backend("flash_hub")
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  torch_dtype=torch.bfloat16,
).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)
)

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""

generator = torch.Generator().manual_seed(42)
image = pipeline(prompt, guidance_scale=3.5, num_inference_steps=50, generator=generator).images[0]

if dist.get_rank() == 0:
    image.save("output_ua.png")
if dist.is_initialized():
    dist.destroy_process_group()

Run the above with torchrun --nproc-per-node 4 check_unified_attention.py.
And it leads to: https://pastebin.com/A7KkvXH2

I spent quite some time investigating this issue but wasn’t able to find the cause. I tried to reproduce it, but the model is too large for the small GPUs I can use, and native_cudnn attention also does not work on simpler GPUs. Does this error occur with TemplatedRingAttention alone? It seems the problem arises with out, prev_out, lse, and prev_lse in the second iteration of the for loop, but none of these tensors originates directly from TemplatedUnifiedAttention. I will continue digging more into this and see if I can identify the issue.

Oooh finally tracked it down and could reproduce it on cpu! The bug is in the TemplatedRingAttention forward function in these lines:

            if _parallel_config.context_parallel_config.convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)

            lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse

        out = out.to(query.dtype)
        lse = lse.squeeze(-1)

That lse = lse.unsqueeze(-1) is unnecessary and causes the issue because it is already done inside the torch.ops.aten._scaled_dot_product_cudnn_attention used by _cudnn_attention_forward_op. See https://github.com/pytorch/pytorch/blob/7a38744ffa3775ace1df4df1d613bb520eb6e456/torch/_meta_registrations.py#L5733 on meta info about the torch.ops.aten._scaled_dot_product_cudnn_attention.
So should I commit and push the fix just removing that one line?

@sayakpaul
Copy link
Member

Thanks a lot for this investigation. Indeed, that seems to be an issue in PyTorch 2.9. WDYT about the following diff?

Unfold
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index aaa45c757..0efeb2868 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -44,6 +44,7 @@ from ..utils import (
     is_xformers_version,
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+from ..utils import is_torch_version
 
 
 if TYPE_CHECKING:
@@ -1186,7 +1187,10 @@ class TemplatedRingAttention(torch.autograd.Function):
                 out = out.to(torch.float32)
                 lse = lse.to(torch.float32)
 
-            lse = lse.unsqueeze(-1)
+            # Refer to:
+            # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+            if is_torch_version("<", "2.9.0"):
+                lse = lse.unsqueeze(-1)
             if prev_out is not None:
                 out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                 lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
@@ -1400,7 +1404,10 @@ def TemplatedUnifiedAttention(
     if return_lse:
         # not sure if this is correct: Assuming (based on forward ops in ringAttention) 
         # the lse is of shape (B, S, H_LOCAL)
-        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        # Refer to:
+        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if is_torch_version("<", "2.9.0"):
+            lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
         lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx=2, gather_idx=1)
         lse = lse.squeeze(-1)
         return (output, lse)

I also coded up a simple script to compare different backends:

Unfold
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cp-backend",
        type=str,
        choices=["ring", "ulysses", "unified"],
        default="ulysses",
        help="Context parallel backend to use.",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    args = parse_args()

    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16,
    ).to(device)
    # Always using it because `ring` doesn't support default. This helps ensure a fair comparison.
    pipeline.transformer.set_attention_backend("_native_cudnn")

    if args.cp_backend == "ring":
        cp_config = ContextParallelConfig(ring_degree=world_size)
    elif args.cp_backend == "unified":
        cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
    else:
        cp_config = ContextParallelConfig(ulysses_degree=world_size)

    pipeline.transformer.enable_parallelism(config=cp_config)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        prompt,
        guidance_scale=3.5,
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    if dist.get_rank() == 0:
        image.save(f"output_{args.cp_backend}.png")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

When I ran the above with torchrun --nproc-per-node 2 check_unified_attention.py --cp-backend {ring,ulysses,unified} (I am on a node of 2 GPUs), I got:

Ring Ulysses Unified
Ring Ulysses Unified

I also changed to cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2) on a node of 4 GPUs, and ran the code with torchrun --nproc-per-node 4 check_unified_attention.py --cp-backend. I got identical output.

@Bissmella
Copy link
Author

I think that is perfect, I didn't know specific about torch 2.9. I will apply the diff.
Thanks a lot for sharing your script and those amazing photos. Should I convert your script to a test and add it in tests? I think that would be good. Or replace the existing one? I can put some more time on cleaning and adding standard test.

I will just do final test on lse on TemplatedUnifiedAttention and correct if anything wrong.
There is a similar issue to this earlier comment in the backward of TemplatedUnifiedAttention and misses one None in the output. Should I add it?

@sayakpaul
Copy link
Member

Should I convert your script to a test and add it in tests? I think that would be good. Or replace the existing one? I can put some more time on cleaning and adding standard test.

We need to add dedicated testing for CP x attention backends, anyway. So, we can skip for now. Sufficient documentation should suffice.

There is a similar issue to this #12693 (comment) in the backward of TemplatedUnifiedAttention and misses one None in the output. Should I add it?

Sounds good!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good! Let's also add docs and remove test file.

Comment on lines -92 to 93
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if self.rotate_method != "allgather":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔥

@sayakpaul
Copy link
Member

@bot /style

@github-actions
Copy link
Contributor

github-actions bot commented Dec 11, 2025

Style bot fixed some files and pushed the changes.

@Bissmella
Copy link
Author

Okay I will add the docs and then remove the test file.

@Bissmella Bissmella force-pushed the unified-SP-attention branch from 7b7b1f4 to e681fe4 Compare December 11, 2025 09:34
@Bissmella
Copy link
Author

Oups! so sorry for the force push. Just resolved a conflict in the distributed_inference.md in docs.
I added the docs and removed the test file.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice PR. I think it would be great to gather some benchmark numbers between ring vs. ulysses vs. unified to convey the efficiency gains.

Would it be possible to do so?

@apolinario is it possible set up a Space with 4 GPUs (for a brief period) which could be used by @Bissmella to test this a bit?

@sayakpaul
Copy link
Member

@bot /style

@github-actions
Copy link
Contributor

github-actions bot commented Dec 13, 2025

Style bot fixed some files and pushed the changes.

@Bissmella
Copy link
Author

Yes sure, I can do the benchmarking for the three methods.

@sayakpaul
Copy link
Member

Cool, I will be curious for the results!

@sayakpaul
Copy link
Member

@Bissmella any luck getting the benchmark numbers?

@Bissmella
Copy link
Author

Bissmella commented Dec 25, 2025

@sayakpaul, actually, I don't have access to any GPUs. I thought about subscribing to Pro in HF, but I don't know if that will give me access to enough GPUs for benchmarking.
I was waiting for @apolinario to see if I could get access to some GPUs for a few hours. If not, then I will see and find something and will definitely do the benchmarking.

@sayakpaul
Copy link
Member

@Bissmella makes sense. If you could provide me with a script to benchmark, I can help gather the numbers. Meanwhile, let me check internally if we can get you a Space set up.

@Bissmella
Copy link
Author

Okay sure, I will prepare and share a script soon.

@Bissmella
Copy link
Author

Bissmella commented Dec 26, 2025

@sayakpaul based on your earlier script for generating those images I did this script that reports time, throughput (inference steps/seconds), and memory usage for the three methods averaged over 5 runs with different seeds and then saves it to a json file. I hope it runs without bugs :) :

unfold
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig
import json
from datetime import datetime


def parse_args():
    parser = argparse.ArgumentParser(description="Benchmark context parallel backends")
    parser.add_argument(
        "--model-id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Model ID to benchmark",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=50,
        help="Number of inference steps",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=5,
        help="Number of iterations for benchmarking",
    )
    parser.add_argument(
        "--backends",
        nargs="+",
        choices=["ring", "ulysses", "unified"],
        default=["ring", "ulysses", "unified"],
        help="Backends to benchmark (can specify multiple)",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default="benchmark_results.json",
        help="Path to save benchmark results JSON",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device, rank, world_size


def measure_memory():
    """Measure current GPU memory usage"""
    allocated = torch.cuda.memory_allocated() / 1024**3  # GB
    reserved = torch.cuda.memory_reserved() / 1024**3  # GB
    max_allocated = torch.cuda.max_memory_allocated() / 1024**3  # GB
    return {
        "allocated_gb": allocated,
        "reserved_gb": reserved,
        "max_allocated_gb": max_allocated,
    }


def benchmark_config(
    pipeline,
    cp_config,
    config_name,
    prompt,
    num_inference_steps,
    num_iters,
    warmup_iters=2,
):
    """Benchmark a specific context parallel configuration"""
    rank = dist.get_rank()
    
    #reset memory stats
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    
    #config
    pipeline.transformer.enable_parallelism(config=cp_config)
    
    if rank == 0:
        print(f"\n{'='*60}")
        print(f"Benchmarking: {config_name}")
        print(f"{'='*60}")
    
    #warmup
    if rank == 0:
        print(f"Warmup ({warmup_iters} iterations)...")
    
    for i in range(warmup_iters):
        generator = torch.Generator().manual_seed(42)
        _ = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
    
    #clearing cache
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    

    if rank == 0:
        print(f"Running benchmark ({num_iters} iterations)...")
    
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    start.record()
    for i in range(num_iters):
        generator = torch.Generator().manual_seed(42 + i)  #different seed per iter
        image = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
    end.record()
    
    torch.cuda.synchronize()
    elapsed_ms = start.elapsed_time(end) / num_iters
    
    #measure memory
    memory_stats = measure_memory()
    
    if rank == 0:
        image.save(f"output_{config_name}.png")
    
    #gathering results
    results_tensor = torch.tensor(
        [elapsed_ms, memory_stats["max_allocated_gb"]], 
        device="cuda"
    )
    
    if rank == 0:
        gathered = [torch.zeros_like(results_tensor) for _ in range(dist.get_world_size())]
        dist.gather(results_tensor, gather_list=gathered, dst=0)
        
        #averaging across ranks
        avg_time = sum(t[0].item() for t in gathered) / len(gathered)
        max_memory = max(t[1].item() for t in gathered)
        
        throughput = num_inference_steps / (avg_time / 1000)  # steps/sec
        
        print(f"\nResults for {config_name}:")
        print(f"  Average time per iteration: {avg_time:.3f} ms")
        print(f"  Throughput: {throughput:.3f} steps/sec")
        print(f"  Peak memory (max across ranks): {max_memory:.3f} GB")
        
        return {
            "config_name": config_name,
            "avg_time_ms": avg_time,
            "throughput_steps_per_sec": throughput,
            "peak_memory_gb": max_memory,
            "num_devices": dist.get_world_size(),
            **memory_stats,
        }
    else:
        dist.gather(results_tensor, dst=0)
        return None


def main():
    args = parse_args()
    
    try:
        device, rank, world_size = setup_distributed()
        
        if rank == 0:
            print(f"Running on {world_size} GPUs")
            print(f"Model: {args.model_id}")
            print(f"Inference steps: {args.num_inference_steps}")
            print(f"Benchmark iterations: {args.num_iters}")
        
        if rank == 0:
            print("\nLoading pipeline...")
        
        pipeline = DiffusionPipeline.from_pretrained(
            args.model_id,
            torch_dtype=torch.bfloat16,
        ).to(device)
        
        pipeline.transformer.set_attention_backend("_native_cudnn")
        
        prompt = """
        cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
        highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
        """
        
        results = []
        
        #benchmarking
        configs = []
        
        if "ulysses" in args.backends:
            configs.append((
                ContextParallelConfig(ulysses_degree=world_size),
                "ulysses"
            ))
        
        if "ring" in args.backends:
            configs.append((
                ContextParallelConfig(ring_degree=world_size),
                "ring"
            ))
        
        if "unified" in args.backends:
            if world_size >= 4:
                configs.append((
                    ContextParallelConfig(
                        ulysses_degree=world_size // 2,
                        ring_degree=world_size // 2
                    ),
                    "unified_balanced"
                ))
            else:
                if rank == 0:
                    print("Warning: Skipping unified (requires at least 4 GPUs)")
        
        for cp_config, name in configs:
            result = benchmark_config(
                pipeline,
                cp_config,
                name,
                prompt,
                args.num_inference_steps,
                args.num_iters,
            )
            if rank == 0 and result:
                results.append(result)
        
        if rank == 0:
            output = {
                "timestamp": datetime.now().isoformat(),
                "model_id": args.model_id,
                "num_devices": world_size,
                "num_inference_steps": args.num_inference_steps,
                "num_iters": args.num_iters,
                "results": results,
            }
            
            with open(args.output_json, "w") as f:
                json.dump(output, f, indent=2)
            
            print(f"\n{'='*60}")
            print("Summary:")
            print(f"{'='*60}")
            for r in results:
                print(f"{r['config_name']:20s}: {r['avg_time_ms']:8.3f} ms/iter, "
                      f"{r['throughput_steps_per_sec']:6.2f} steps/sec, "
                      f"{r['peak_memory_gb']:5.2f} GB")
            
            print(f"\nResults saved to {args.output_json}")
    
    except Exception as e:
        print(f"Error on rank {dist.get_rank() if dist.is_initialized() else 'unknown'}: {e}")
        raise
    
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()

running with 4 processes:

torchrun --nproc_per_node=4 benchmark.py

@sayakpaul
Copy link
Member

@Bissmella could you fix the formatting of the code so that it's easier to eyeball?

@Bissmella
Copy link
Author

Ooh sorry. now it should be okay.

@sayakpaul
Copy link
Member

Got this error: https://pastebin.com/NK5EWtE8

@Bissmella
Copy link
Author

So created pipeline from scratch for each method. Hopefully this one works:

unfold
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig
import json
from datetime import datetime


def parse_args():
    parser = argparse.ArgumentParser(description="Benchmark context parallel backends")
    parser.add_argument(
        "--model-id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Model ID to benchmark",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=50,
        help="Number of inference steps",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=5,
        help="Number of iterations for benchmarking",
    )
    parser.add_argument(
        "--backends",
        nargs="+",
        choices=["ring", "ulysses", "unified"],
        default=["ring", "ulysses", "unified"],
        help="Backends to benchmark (can specify multiple)",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default="benchmark_results.json",
        help="Path to save benchmark results JSON",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device, rank, world_size


def measure_memory():
    """Measure current GPU memory usage"""
    allocated = torch.cuda.memory_allocated() / 1024**3  # GB
    reserved = torch.cuda.memory_reserved() / 1024**3  # GB
    max_allocated = torch.cuda.max_memory_allocated() / 1024**3  # GB
    return {
        "allocated_gb": allocated,
        "reserved_gb": reserved,
        "max_allocated_gb": max_allocated,
    }

def get_fresh_pipeline(model_id, device):
  pipeline = DiffusionPipeline.from_pretrained(
      model_id,
      torch_dtype=torch.bfloat16,
  ).to(device)
  
  pipeline.transformer.set_attention_backend("_native_cudnn")

  return pipeline

def benchmark_config(
    model_id,
    device,
    cp_config,
    config_name,
    prompt,
    num_inference_steps,
    num_iters,
    warmup_iters=2,
):
    """Benchmark a specific context parallel configuration"""
    rank = dist.get_rank()
    
    pipeline = get_fresh_pipeline(model_id, device)

    #reset memory stats
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    
    #config
    pipeline.transformer.enable_parallelism(config=cp_config)
    
    if rank == 0:
        print(f"\n{'='*60}")
        print(f"Benchmarking: {config_name}")
        print(f"{'='*60}")
    
    #warmup
    if rank == 0:
        print(f"Warmup ({warmup_iters} iterations)...")
    
    for i in range(warmup_iters):
        generator = torch.Generator().manual_seed(42)
        _ = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
    
    #clearing cache
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    

    if rank == 0:
        print(f"Running benchmark ({num_iters} iterations)...")
    
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    start.record()
    for i in range(num_iters):
        generator = torch.Generator().manual_seed(42 + i)  #different seed per iter
        image = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
    end.record()
    
    torch.cuda.synchronize()
    elapsed_ms = start.elapsed_time(end) / num_iters
    
    #measure memory
    memory_stats = measure_memory()
    
    if rank == 0:
        image.save(f"output_{config_name}.png")
    
    #gathering results
    results_tensor = torch.tensor(
        [elapsed_ms, memory_stats["max_allocated_gb"]], 
        device="cuda"
    )

    del pipeline
    torch.cuda.empty_cache()

    if rank == 0:
        gathered = [torch.zeros_like(results_tensor) for _ in range(dist.get_world_size())]
        dist.gather(results_tensor, gather_list=gathered, dst=0)
        
        #averaging across ranks
        avg_time = sum(t[0].item() for t in gathered) / len(gathered)
        max_memory = max(t[1].item() for t in gathered)
        
        throughput = num_inference_steps / (avg_time / 1000)  # steps/sec
        
        print(f"\nResults for {config_name}:")
        print(f"  Average time per iteration: {avg_time:.3f} ms")
        print(f"  Throughput: {throughput:.3f} steps/sec")
        print(f"  Peak memory (max across ranks): {max_memory:.3f} GB")
        
        return {
            "config_name": config_name,
            "avg_time_ms": avg_time,
            "throughput_steps_per_sec": throughput,
            "peak_memory_gb": max_memory,
            "num_devices": dist.get_world_size(),
            **memory_stats,
        }
    else:
        dist.gather(results_tensor, dst=0)
        return None


def main():
    args = parse_args()
    
    try:
        device, rank, world_size = setup_distributed()
        
        if rank == 0:
            print(f"Running on {world_size} GPUs")
            print(f"Model: {args.model_id}")
            print(f"Inference steps: {args.num_inference_steps}")
            print(f"Benchmark iterations: {args.num_iters}")
        
        prompt = """
        cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
        highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
        """
        
        results = []
        
        #benchmarking
        configs = []
        
        if "ulysses" in args.backends:
            configs.append((
                ContextParallelConfig(ulysses_degree=world_size),
                "ulysses"
            ))
        
        if "ring" in args.backends:
            configs.append((
                ContextParallelConfig(ring_degree=world_size),
                "ring"
            ))
        
        if "unified" in args.backends:
            if world_size >= 4:
                configs.append((
                    ContextParallelConfig(
                        ulysses_degree=world_size // 2,
                        ring_degree=world_size // 2
                    ),
                    "unified_balanced"
                ))
            else:
                if rank == 0:
                    print("Warning: Skipping unified (requires at least 4 GPUs)")
        
        for cp_config, name in configs:
            result = benchmark_config(
                args.model_id,
                device,
                cp_config,
                name,
                prompt,
                args.num_inference_steps,
                args.num_iters,
            )
            if rank == 0 and result:
                results.append(result)
        
        if rank == 0:
            output = {
                "timestamp": datetime.now().isoformat(),
                "model_id": args.model_id,
                "num_devices": world_size,
                "num_inference_steps": args.num_inference_steps,
                "num_iters": args.num_iters,
                "results": results,
            }
            
            with open(args.output_json, "w") as f:
                json.dump(output, f, indent=2)
            
            print(f"\n{'='*60}")
            print("Summary:")
            print(f"{'='*60}")
            for r in results:
                print(f"{r['config_name']:20s}: {r['avg_time_ms']:8.3f} ms/iter, "
                      f"{r['throughput_steps_per_sec']:6.2f} steps/sec, "
                      f"{r['peak_memory_gb']:5.2f} GB")
            
            print(f"\nResults saved to {args.output_json}")
    
    except Exception as e:
        print(f"Error on rank {dist.get_rank() if dist.is_initialized() else 'unknown'}: {e}")
        raise
    
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()

@sayakpaul
Copy link
Member

I have excellent news:

============================================================
Summary:
============================================================
ulysses             : 6670.789 ms/iter,   7.50 steps/sec, 33.85 GB
ring                : 13076.492 ms/iter,   3.82 steps/sec, 56.02 GB
unified_balanced    : 11068.705 ms/iter,   4.52 steps/sec, 33.85 GB

@DN6 I have done several rounds of reviews on this PR. It would be great to have yours!

@Bissmella
Copy link
Author

Thanks @sayakpaul.
btw I wanted to ask if it would be okay if I start working on the issue # 8673 (regarding attention masking in batch inference)? sorry if irrelevant.

@sayakpaul
Copy link
Member

Sure, feel free to!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants