Add Unified Sequence Parallel attention #12693
base: main
Conversation
It would be nice to get a testing script so that we can quickly check things.

I added a basic test script with a simple forward and backward op. Is it better to have a test script with flash_attention_backward and forward?
Force-pushed a244006 to 9dee8f8
Force-pushed 9dee8f8 to 9ebcff5
Let us know if this is ready for a review!

Yep, ready for review! I tested it with a 4-process setup (2×2 mesh, on CPU) and everything checks out: shapes look good and gradients flow correctly. Looking forward to feedback and happy to address any issues.
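(For reference, a minimal sketch of that kind of setup: 4 CPU processes on the gloo backend arranged as a 2×2 device mesh. It assumes torch ≥ 2.2 for `init_device_mesh`; the mesh dimension names and port are illustrative, not the PR's actual test code.)

```python
import os

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.device_mesh import init_device_mesh


def worker(rank, world_size=4):
    # single-machine, CPU-only rendezvous (illustrative values)
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # 2x2 mesh: one axis for the ring group, one for the Ulysses group (assumed names)
    mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("ring", "ulysses"))
    if rank == 0:
        print(mesh)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, nprocs=4)
```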
sayakpaul
left a comment
Thanks for getting started on this!
```diff
  grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))

- return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
```
Why the change here?
The forward function has 12 inputs (excluding ctx), but the backward returns only 11 outputs; normally the two should match. I was getting an error like this while testing: "RuntimeError: function backward returned an incorrect number of gradients (expected 12, got 11)".
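(For context, this is the general `torch.autograd.Function` contract rather than anything specific to this PR: backward must return exactly one value per forward input, with `None` for non-differentiable ones. A minimal sketch with illustrative names:)

```python
import torch


class TwelveInputOp(torch.autograd.Function):
    """Illustrative stand-in for an op whose forward takes 12 inputs (excluding ctx)."""

    @staticmethod
    def forward(ctx, query, key, value, *extras):  # called below with 3 tensors + 9 extras = 12 inputs
        ctx.save_for_backward(query, key, value)
        return query + key + value

    @staticmethod
    def backward(ctx, grad_out):
        # One entry per forward input: 3 gradients + 9 Nones = 12 outputs.
        # Returning only 11 values raises
        # "function ... returned an incorrect number of gradients (expected 12, got 11)".
        return (grad_out, grad_out, grad_out) + (None,) * 9


q, k, v = (torch.randn(2, 4, requires_grad=True) for _ in range(3))
out = TwelveInputOp.apply(q, k, v, *range(9))  # 12 positional inputs in total
out.sum().backward()  # succeeds because backward returns 12 values
```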
Do you have a reproducer?
Yes, it can be reproduced in this notebook (it happens only during the backward): https://colab.research.google.com/drive/1Ac4nVSVjKHrPpcSRlX0E3NzY0mDEmkMx?usp=sharing
I am trying with the following code:

```python
import torch
from torch import distributed as dist
from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device


device = setup_distributed()

# Need to add parallel support for this.
# pipeline.transformer.set_attention_backend("flash_hub")
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16,
).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)
)

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
generator = torch.Generator().manual_seed(42)
image = pipeline(prompt, guidance_scale=3.5, num_inference_steps=50, generator=generator).images[0]

if dist.get_rank() == 0:
    image.save("output_ua.png")

if dist.is_initialized():
    dist.destroy_process_group()
```

Run the above with … and it leads to: …
I spent quite some time investigating this issue but wasn't able to find the cause. I tried to reproduce it, but the model is too large for the small GPUs I can use, and …

Oooh, finally tracked it down and could reproduce it on CPU! The bug is in the … That …

I think that is perfect; I didn't know the specifics about torch 2.9. I will apply the diff. I will just do a final test of the lse on …
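(For context, the lse here is the per-query log-sum-exp that ring-style attention carries alongside each partial output so that per-block results can be merged exactly. A minimal sketch of that merge in plain torch, with assumed shapes, not the PR's actual code:)

```python
import torch


def merge_attention_partials(out1, lse1, out2, lse2):
    """Merge two softmax-normalized partial attention outputs via their log-sum-exps.

    Assumed (illustrative) shapes: out_i is (B, H, S, D), lse_i is (B, H, S).
    """
    lse = torch.logaddexp(lse1, lse2)           # combined normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)    # rescale the first partial
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)    # rescale the second partial
    return w1 * out1 + w2 * out2, lse
```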
We need to add dedicated testing for CP x attention backends, anyway. So, we can skip for now. Sufficient documentation should suffice.
Sounds good!

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
sayakpaul
left a comment
Looking good! Let's also add docs and remove the test file.
```python
    raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
    raise ValueError(
        "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
    )
if self.rotate_method != "allgather":
```
🔥
@bot /style

Style bot fixed some files and pushed the changes.

Okay, I will add the docs and then remove the test file.
Force-pushed 7b7b1f4 to e681fe4

Oops! So sorry for the force push. I just resolved a conflict in distributed_inference.md in the docs.
sayakpaul
left a comment
Very nice PR. I think it would be great to gather some benchmark numbers for ring vs. ulysses vs. unified to convey the efficiency gains.
Would it be possible to do so?

@apolinario is it possible to set up a Space with 4 GPUs (for a brief period) which could be used by @Bissmella to test this a bit?

@bot /style

Style bot fixed some files and pushed the changes.
Co-authored-by: Sayak Paul <[email protected]>
Yes, sure, I can do the benchmarking for the three methods.

Cool, I will be curious about the results!

@Bissmella any luck getting the benchmark numbers?

@sayakpaul actually, I don't have access to any GPUs. I thought about subscribing to HF Pro, but I don't know if that will give me access to enough GPUs for benchmarking.

@Bissmella makes sense. If you could provide me with a script to benchmark, I can help gather the numbers. Meanwhile, let me check internally if we can get you a Space set up.

Okay, sure, I will prepare and share a script soon.
@sayakpaul based on your earlier script for generating those images, I put together this script that reports time, throughput (inference steps/second), and memory usage for the three methods, averaged over 5 runs with different seeds, and then saves the results to a JSON file. I hope it runs without bugs :)

```python
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig
import json
from datetime import datetime


def parse_args():
    parser = argparse.ArgumentParser(description="Benchmark context parallel backends")
    parser.add_argument(
        "--model-id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Model ID to benchmark",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=50,
        help="Number of inference steps",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=5,
        help="Number of iterations for benchmarking",
    )
    parser.add_argument(
        "--backends",
        nargs="+",
        choices=["ring", "ulysses", "unified"],
        default=["ring", "ulysses", "unified"],
        help="Backends to benchmark (can specify multiple)",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default="benchmark_results.json",
        help="Path to save benchmark results JSON",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device, rank, world_size


def measure_memory():
    """Measure current GPU memory usage"""
    allocated = torch.cuda.memory_allocated() / 1024**3  # GB
    reserved = torch.cuda.memory_reserved() / 1024**3  # GB
    max_allocated = torch.cuda.max_memory_allocated() / 1024**3  # GB
    return {
        "allocated_gb": allocated,
        "reserved_gb": reserved,
        "max_allocated_gb": max_allocated,
    }


def benchmark_config(
    pipeline,
    cp_config,
    config_name,
    prompt,
    num_inference_steps,
    num_iters,
    warmup_iters=2,
):
    """Benchmark a specific context parallel configuration"""
    rank = dist.get_rank()

    # reset memory stats
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    # config
    pipeline.transformer.enable_parallelism(config=cp_config)

    if rank == 0:
        print(f"\n{'='*60}")
        print(f"Benchmarking: {config_name}")
        print(f"{'='*60}")

    # warmup
    if rank == 0:
        print(f"Warmup ({warmup_iters} iterations)...")
    for i in range(warmup_iters):
        generator = torch.Generator().manual_seed(42)
        _ = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]

    # clearing cache
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    if rank == 0:
        print(f"Running benchmark ({num_iters} iterations)...")

    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for i in range(num_iters):
        generator = torch.Generator().manual_seed(42 + i)  # different seed per iter
        image = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
    end.record()
    torch.cuda.synchronize()
    elapsed_ms = start.elapsed_time(end) / num_iters

    # measure memory
    memory_stats = measure_memory()

    if rank == 0:
        image.save(f"output_{config_name}.png")

    # gathering results
    results_tensor = torch.tensor(
        [elapsed_ms, memory_stats["max_allocated_gb"]],
        device="cuda"
    )
    if rank == 0:
        gathered = [torch.zeros_like(results_tensor) for _ in range(dist.get_world_size())]
        dist.gather(results_tensor, gather_list=gathered, dst=0)
        # averaging across ranks
        avg_time = sum(t[0].item() for t in gathered) / len(gathered)
        max_memory = max(t[1].item() for t in gathered)
        throughput = num_inference_steps / (avg_time / 1000)  # steps/sec
        print(f"\nResults for {config_name}:")
        print(f"  Average time per iteration: {avg_time:.3f} ms")
        print(f"  Throughput: {throughput:.3f} steps/sec")
        print(f"  Peak memory (max across ranks): {max_memory:.3f} GB")
        return {
            "config_name": config_name,
            "avg_time_ms": avg_time,
            "throughput_steps_per_sec": throughput,
            "peak_memory_gb": max_memory,
            "num_devices": dist.get_world_size(),
            **memory_stats,
        }
    else:
        dist.gather(results_tensor, dst=0)
        return None


def main():
    args = parse_args()
    try:
        device, rank, world_size = setup_distributed()
        if rank == 0:
            print(f"Running on {world_size} GPUs")
            print(f"Model: {args.model_id}")
            print(f"Inference steps: {args.num_inference_steps}")
            print(f"Benchmark iterations: {args.num_iters}")

        if rank == 0:
            print("\nLoading pipeline...")
        pipeline = DiffusionPipeline.from_pretrained(
            args.model_id,
            torch_dtype=torch.bfloat16,
        ).to(device)
        pipeline.transformer.set_attention_backend("_native_cudnn")

        prompt = """
        cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
        highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
        """

        results = []
        # benchmarking
        configs = []
        if "ulysses" in args.backends:
            configs.append((
                ContextParallelConfig(ulysses_degree=world_size),
                "ulysses"
            ))
        if "ring" in args.backends:
            configs.append((
                ContextParallelConfig(ring_degree=world_size),
                "ring"
            ))
        if "unified" in args.backends:
            if world_size >= 4:
                configs.append((
                    ContextParallelConfig(
                        ulysses_degree=world_size // 2,
                        ring_degree=world_size // 2
                    ),
                    "unified_balanced"
                ))
            else:
                if rank == 0:
                    print("Warning: Skipping unified (requires at least 4 GPUs)")

        for cp_config, name in configs:
            result = benchmark_config(
                pipeline,
                cp_config,
                name,
                prompt,
                args.num_inference_steps,
                args.num_iters,
            )
            if rank == 0 and result:
                results.append(result)

        if rank == 0:
            output = {
                "timestamp": datetime.now().isoformat(),
                "model_id": args.model_id,
                "num_devices": world_size,
                "num_inference_steps": args.num_inference_steps,
                "num_iters": args.num_iters,
                "results": results,
            }
            with open(args.output_json, "w") as f:
                json.dump(output, f, indent=2)
            print(f"\n{'='*60}")
            print("Summary:")
            print(f"{'='*60}")
            for r in results:
                print(f"{r['config_name']:20s}: {r['avg_time_ms']:8.3f} ms/iter, "
                      f"{r['throughput_steps_per_sec']:6.2f} steps/sec, "
                      f"{r['peak_memory_gb']:5.2f} GB")
            print(f"\nResults saved to {args.output_json}")
    except Exception as e:
        print(f"Error on rank {dist.get_rank() if dist.is_initialized() else 'unknown'}: {e}")
        raise
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Running with 4 processes: …
@Bissmella could you fix the formatting of the code so that it's easier to eyeball?

Ooh, sorry. Now it should be okay.

Got this error: https://pastebin.com/NK5EWtE8
So I created the pipeline from scratch for each method. Hopefully this one works:

```python
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig
import json
from datetime import datetime


def parse_args():
    parser = argparse.ArgumentParser(description="Benchmark context parallel backends")
    parser.add_argument(
        "--model-id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Model ID to benchmark",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=50,
        help="Number of inference steps",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=5,
        help="Number of iterations for benchmarking",
    )
    parser.add_argument(
        "--backends",
        nargs="+",
        choices=["ring", "ulysses", "unified"],
        default=["ring", "ulysses", "unified"],
        help="Backends to benchmark (can specify multiple)",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default="benchmark_results.json",
        help="Path to save benchmark results JSON",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device, rank, world_size


def measure_memory():
    """Measure current GPU memory usage"""
    allocated = torch.cuda.memory_allocated() / 1024**3  # GB
    reserved = torch.cuda.memory_reserved() / 1024**3  # GB
    max_allocated = torch.cuda.max_memory_allocated() / 1024**3  # GB
    return {
        "allocated_gb": allocated,
        "reserved_gb": reserved,
        "max_allocated_gb": max_allocated,
    }


def get_fresh_pipeline(model_id, device):
    pipeline = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
    ).to(device)
    pipeline.transformer.set_attention_backend("_native_cudnn")
    return pipeline


def benchmark_config(
    model_id,
    device,
    cp_config,
    config_name,
    prompt,
    num_inference_steps,
    num_iters,
    warmup_iters=2,
):
    """Benchmark a specific context parallel configuration"""
    rank = dist.get_rank()
    pipeline = get_fresh_pipeline(model_id, device)

    # reset memory stats
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    # config
    pipeline.transformer.enable_parallelism(config=cp_config)

    if rank == 0:
        print(f"\n{'='*60}")
        print(f"Benchmarking: {config_name}")
        print(f"{'='*60}")

    # warmup
    if rank == 0:
        print(f"Warmup ({warmup_iters} iterations)...")
    for i in range(warmup_iters):
        generator = torch.Generator().manual_seed(42)
        _ = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]

    # clearing cache
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    if rank == 0:
        print(f"Running benchmark ({num_iters} iterations)...")

    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for i in range(num_iters):
        generator = torch.Generator().manual_seed(42 + i)  # different seed per iter
        image = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]
    end.record()
    torch.cuda.synchronize()
    elapsed_ms = start.elapsed_time(end) / num_iters

    # measure memory
    memory_stats = measure_memory()

    if rank == 0:
        image.save(f"output_{config_name}.png")

    # gathering results
    results_tensor = torch.tensor(
        [elapsed_ms, memory_stats["max_allocated_gb"]],
        device="cuda"
    )

    del pipeline
    torch.cuda.empty_cache()

    if rank == 0:
        gathered = [torch.zeros_like(results_tensor) for _ in range(dist.get_world_size())]
        dist.gather(results_tensor, gather_list=gathered, dst=0)
        # averaging across ranks
        avg_time = sum(t[0].item() for t in gathered) / len(gathered)
        max_memory = max(t[1].item() for t in gathered)
        throughput = num_inference_steps / (avg_time / 1000)  # steps/sec
        print(f"\nResults for {config_name}:")
        print(f"  Average time per iteration: {avg_time:.3f} ms")
        print(f"  Throughput: {throughput:.3f} steps/sec")
        print(f"  Peak memory (max across ranks): {max_memory:.3f} GB")
        return {
            "config_name": config_name,
            "avg_time_ms": avg_time,
            "throughput_steps_per_sec": throughput,
            "peak_memory_gb": max_memory,
            "num_devices": dist.get_world_size(),
            **memory_stats,
        }
    else:
        dist.gather(results_tensor, dst=0)
        return None


def main():
    args = parse_args()
    try:
        device, rank, world_size = setup_distributed()
        if rank == 0:
            print(f"Running on {world_size} GPUs")
            print(f"Model: {args.model_id}")
            print(f"Inference steps: {args.num_inference_steps}")
            print(f"Benchmark iterations: {args.num_iters}")

        prompt = """
        cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
        highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
        """

        results = []
        # benchmarking
        configs = []
        if "ulysses" in args.backends:
            configs.append((
                ContextParallelConfig(ulysses_degree=world_size),
                "ulysses"
            ))
        if "ring" in args.backends:
            configs.append((
                ContextParallelConfig(ring_degree=world_size),
                "ring"
            ))
        if "unified" in args.backends:
            if world_size >= 4:
                configs.append((
                    ContextParallelConfig(
                        ulysses_degree=world_size // 2,
                        ring_degree=world_size // 2
                    ),
                    "unified_balanced"
                ))
            else:
                if rank == 0:
                    print("Warning: Skipping unified (requires at least 4 GPUs)")

        for cp_config, name in configs:
            result = benchmark_config(
                args.model_id,
                device,
                cp_config,
                name,
                prompt,
                args.num_inference_steps,
                args.num_iters,
            )
            if rank == 0 and result:
                results.append(result)

        if rank == 0:
            output = {
                "timestamp": datetime.now().isoformat(),
                "model_id": args.model_id,
                "num_devices": world_size,
                "num_inference_steps": args.num_inference_steps,
                "num_iters": args.num_iters,
                "results": results,
            }
            with open(args.output_json, "w") as f:
                json.dump(output, f, indent=2)
            print(f"\n{'='*60}")
            print("Summary:")
            print(f"{'='*60}")
            for r in results:
                print(f"{r['config_name']:20s}: {r['avg_time_ms']:8.3f} ms/iter, "
                      f"{r['throughput_steps_per_sec']:6.2f} steps/sec, "
                      f"{r['peak_memory_gb']:5.2f} GB")
            print(f"\nResults saved to {args.output_json}")
    except Exception as e:
        print(f"Error on rank {dist.get_rank() if dist.is_initialized() else 'unknown'}: {e}")
        raise
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
I have excellent news: …

@DN6 I have done several rounds of reviews on this PR. It would be great to have yours!

Thanks @sayakpaul.

Sure, feel free to!
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>



What does this PR do?
This is a draft implementation of the Unified SP attention approach.
- `_all_to_all_dim_exchange` with custom scatter and gather indices (see the illustrative sketch below)
- `TemplatedUnifiedAttention`

Core implementation complete, needs:
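For readers less familiar with the Ulysses half of the scheme, the sketch below (plain `torch.distributed`, illustrative names and dims, not the PR's actual helper) shows what an all-to-all "dimension exchange" does: each rank trades its sequence shard of all heads for the full sequence of a subset of heads.

```python
import torch
import torch.distributed as dist


def all_to_all_dim_exchange(x, scatter_dim, gather_dim, group=None):
    """Split x along scatter_dim, exchange one chunk with every rank, reassemble along gather_dim."""
    world_size = dist.get_world_size(group)
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(t) for t in inputs]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_dim)


# Ulysses-style use, assuming q is sequence-sharded as (B, S/P, H, D) on each rank:
# scatter over heads (dim 2), gather over sequence (dim 1) -> (B, S, H/P, D) per rank,
# so every rank can run ordinary attention over the full sequence for its head subset.
# q = all_to_all_dim_exchange(q, scatter_dim=2, gather_dim=1)
```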