Better Cross Attention Support For Graphcore IPUs #1830


Closed
brucethemoose opened this issue Dec 26, 2022 · 9 comments
Labels: stale (Issues that haven't received updates)

Comments

@brucethemoose commented Dec 26, 2022

Hello. In Graphcore's Stable Diffusion demo, it seems they replace your cross attention implementation with their own, to work around some issues:

https://github.com/gradient-ai/Graphcore-HuggingFace/blob/main/stable-diffusion/ipu_models.py#L35-L93

I modified their demo to pull the latest diffusers release and skip the override, and it seems that the issue is still present, as I elaborated on here:

invoke-ai/InvokeAI#2120 (comment)

In the above issue, I was told that their cross attention override makes Graphcore IPU compatibility in InvokeAI messy, that your new cross attention API isn't released yet, and that this might be useful feedback for you.

I'm a humble end user in way over my head, so let me know if I am way off base here :P

@patrickvonplaten (Contributor)

Sorry, just to better understand: Graphcore worked with diffusers before we merged #1639?

And now it doesn't work anymore?

@patrickvonplaten (Contributor)

I guess we should ping gradient-ai to update their code to the new format soon?
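For reference, here is a minimal sketch of what "the new format" could look like: a custom attention processor in the style of the diffusers attention-processor API that was landing around this time, computing scores with plain `torch.bmm` instead of `torch.baddbmm`. This is an illustration, not Graphcore's actual code, and the helper names should be verified against the installed diffusers version:

```python
# A sketch only, not Graphcore's actual override: a drop-in attention
# processor that computes scores with torch.bmm instead of torch.baddbmm
# (the op the IPU compiler rejects below). Helper names follow the
# diffusers CrossAttention module of this era (reshape_heads_to_batch_dim,
# to_out as [Linear, Dropout], scale); verify against your installed version.
import torch


class BmmCrossAttnProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Project to query/key/value; falls back to self-attention when no
        # encoder states are given.
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Fold heads into the batch dim: (b, seq, dim) -> (b*heads, seq, dim/heads)
        query = attn.reshape_heads_to_batch_dim(query)
        key = attn.reshape_heads_to_batch_dim(key)
        value = attn.reshape_heads_to_batch_dim(value)

        # torch.bmm + softmax in place of the baddbmm call; attention_mask is
        # ignored for brevity (Stable Diffusion inference does not pass one).
        scores = torch.bmm(query, key.transpose(-1, -2)) * attn.scale
        hidden_states = torch.bmm(scores.softmax(dim=-1), value)

        hidden_states = attn.reshape_batch_dim_to_heads(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states


# Usage, assuming the processor API is available in your diffusers build:
# pipe.unet.set_attn_processor(BmmCrossAttnProcessor())
```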

@brucethemoose (Author) commented Dec 27, 2022

> Sorry, just to better understand: Graphcore worked with diffusers before we merged #1639?
>
> And now it doesn't work anymore?

Their Stable Diffusion demo still works with the diffusers pipeline, but only with this override that was apparently posted 3 weeks ago, before that merge you mentioned: https://github.com/gradient-ai/Graphcore-HuggingFace/blob/main/stable-diffusion/ipu_models.py#L35-L93

If I comment out either of these two lines, so that the demo uses the regular diffusers cross attention code:
https://github.com/gradient-ai/Graphcore-HuggingFace/blob/main/stable-diffusion/ipu_models.py#L92-L93

I get this error:

Cross Attention Error
---------------------------------------------------------------------------
  Error                                     Traceback (most recent call last)
  <ipython-input-10-3a9910a550ac> in <module>
  ----> 1 pipe("apple", height=image_height, width=image_width, num_inference_steps=25, guidance_scale=9);
  
  /usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
       26         def decorate_context(*args, **kwargs):
       27             with self.__class__():
  ---> 28                 return func(*args, **kwargs)
       29         return cast(F, decorate_context)
       30 
  
  /usr/local/lib/python3.8/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py in __call__(self, prompt, height, width, num_inference_steps, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, output_type, return_dict, callback, callback_steps)
      527 
      528                 # predict the noise residual
  --> 529                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
      530 
      531                 # perform guidance
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
     1101                 or _global_forward_hooks or _global_forward_pre_hooks):
  -> 1102             return forward_call(*input, **kwargs)
     1103         # Do not call functions when jit is used
     1104         full_backward_hooks, non_full_backward_hooks = [], []
  
  /notebooks/stable-diffusion/ipu_models.py in forward(self, sample, timestep, encoder_hidden_states, return_dict)
      105         encoder_hidden_states = encoder_hidden_states.to(self.input_dtype)
      106 
  --> 107         ret = self.unet(sample, timestep, encoder_hidden_states)
      108 
      109         ret.sample = ret.sample.to(self.output_dtype)
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in __call__(self, *args, **kwargs)
      919 
      920         if not self.isCompiled():
  --> 921             self._compile(in_tensors)
      922 
      923         if not self._is_attached:
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_impl.py in wrapper(self, *args, **kwargs)
      356         def wrapper(self, *args, **kwargs):
      357             with self._profiling.tracepoint(label):  # pylint: disable=protected-access
  --> 358                 return func(self, *args, **kwargs)
      359 
      360         return wrapper
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in _compile(self, in_tensors)
      644                 self._executable = poptorch_core.compileWithTrace(*trace_args)
      645             else:
  --> 646                 self._executable = self._compileWithDispatch(
      647                     in_tensors_trace_view)
      648 
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in _compileWithDispatch(self, in_tensors, executable_filename)
      593                                    **in_tensors.kwargs)
      594             else:
  --> 595                 ctx.compile(*in_tensors.args, **in_tensors.kwargs)
      596             self._outputs_structure = ctx.ipu._outputs_structure  # pylint: disable=protected-access
      597 
  
  /usr/local/lib/python3.8/dist-packages/poptorch/experimental.py in compile(self, *args, **kwargs)
      339 
      340     def compile(self, *args, **kwargs):
  --> 341         return self._compileOrLoadExecutable(args, kwargs)
      342 
      343     def loadExecutable(self, filename, *args, **kwargs):
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_impl.py in wrapper(*args, **kwargs)
      162     def wrapper(*args, **kwargs):
      163         with OnExit():
  --> 164             return func(*args, **kwargs)
      165 
      166     return wrapper
  
  /usr/local/lib/python3.8/dist-packages/poptorch/experimental.py in _compileOrLoadExecutable(self, args, kwargs, filename)
      380                                                       tensor_args)
      381 
  --> 382             result = self.func(*args, **kwargs)
      383             if result is not None:
      384                 ipu.outputs(result)
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
     1101                 or _global_forward_hooks or _global_forward_pre_hooks):
  -> 1102             return forward_call(*input, **kwargs)
     1103         # Do not call functions when jit is used
     1104         full_backward_hooks, non_full_backward_hooks = [], []
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_condition.py in forward(self, sample, timestep, encoder_hidden_states, class_labels, attention_mask, return_dict)
      422         for downsample_block in self.down_blocks:
      423             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
  --> 424                 sample, res_samples = downsample_block(
      425                     hidden_states=sample,
      426                     temb=emb,
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py in forward(self, hidden_states, temb, encoder_hidden_states, attention_mask)
      775             else:
      776                 hidden_states = resnet(hidden_states, temb)
  --> 777                 hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
      778 
      779             output_states += (hidden_states,)
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, timestep, return_dict)
      214         # 2. Blocks
      215         for block in self.transformer_blocks:
  --> 216             hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
      217 
      218         # 3. Output
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, timestep, attention_mask)
      488             )
      489         else:
  --> 490             hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
      491 
      492         if self.attn2 is not None:
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, attention_mask)
      638                 hidden_states = self._attention(query, key, value, attention_mask)
      639             else:
  --> 640                 hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
      641 
      642         # linear proj
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask)
      695                 key_slice = key_slice.float()
      696 
  --> 697             attn_slice = torch.baddbmm(
      698                 torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
      699                 query_slice,
  
  Error: In poptorch/source/dispatch_tracer/dispatchers/MLIRDispatch.cpp:372: 'poptorch_cpp_error': No shape inference handler for aten::baddbmm

And I was told this override makes integration with InvokeAI difficult.
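For anyone who wants to dig into the failing op itself: diffusers calls `torch.baddbmm` with `beta=0` in `_sliced_attention`, which makes it mathematically just a scaled `torch.bmm`; the first argument is ignored. A minimal sketch of the equivalence (on CPU, outside poptorch):

```python
import torch

# Shapes as in the sliced-attention path: (batch*heads, q_len, head_dim).
q = torch.randn(2, 4, 8)
k = torch.randn(2, 16, 8)
scale = 8 ** -0.5

# The call poptorch rejects. diffusers passes beta=0, so the (uninitialized)
# first argument is ignored entirely, per the torch.baddbmm docs.
a = torch.baddbmm(
    torch.empty(2, 4, 16, dtype=q.dtype, device=q.device),
    q,
    k.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

# The same computation without baddbmm.
b = torch.bmm(q, k.transpose(-1, -2)) * scale

assert torch.allclose(a, b)
```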

@patrickvonplaten (Contributor)

Uff, sorry, I'm really not an expert on Graphcore IPUs.

@Lime-Cakes (Contributor)

The better option might be to wait for Graphcore to finish adding support for all of its missing operations, assuming they do add them. Poptorch is still rather new and unstable, so supporting it is likely to be difficult.

@brucethemoose (Author) commented Jan 8, 2023

> The better option might be to wait for Graphcore to finish adding support for all of its missing operations, assuming they do add them. Poptorch is still rather new and unstable, so supporting it is likely to be difficult.

One thing I realized a few days ago is that my testing environment (a Paperspace Gradient POD4 notebook) is pretty old. It's running Ubuntu 20.04, Python 3.8, and what I suspect is an old build of poptorch, and I can't build or update poptorch without some kind of Graphcore login to download the SDK.

torch.baddbmm may already be implemented, for all I know. I've emailed Paperspace to see if they can provide newer IPU images, or at least newer versions of poptorch.
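A quick way to confirm which builds the image actually ships, assuming the standard pip-installed wheels:

```python
# Print the installed poptorch and torch builds; poptorch wheels expose
# __version__ like most pip packages (an assumption worth double-checking
# with `pip show poptorch` if the attribute is missing).
import torch
import poptorch

print("poptorch:", poptorch.__version__)
print("torch:", torch.__version__)
```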

@Lime-Cakes (Contributor)

> > The better option might be to wait for Graphcore to finish adding support for all of its missing operations, assuming they do add them. Poptorch is still rather new and unstable, so supporting it is likely to be difficult.
>
> One thing I realized a few days ago is that my testing environment (a Paperspace Gradient POD4 notebook) is pretty old. It's running Ubuntu 20.04, Python 3.8, and what I suspect is an old build of poptorch, and I can't build or update poptorch without some kind of Graphcore login to download the SDK.
>
> torch.baddbmm may already be implemented, for all I know. I've emailed Paperspace to see if they can provide newer IPU images, or at least newer versions of poptorch.

It's not in 3.0: I tested with poptorch 3.0 and it does not support baddbmm. Paperspace is on 2.6. You might have to contact Graphcore instead.

@brucethemoose (Author) commented Jan 8, 2023

There are some commits made specifically for baddbmm in the recent 3.1 release:

https://github.com/graphcore/poptorch/search?q=baddbmm&type=commits

@Lime-Cakes Did you manage to update poptorch to 3.0 in a Gradient notebook, or were you testing in some other environment? I want to give 3.1 a shot (as I want to raise an issue on the poptorch repo if necessary), but I have no clue how to update it without a login.

@github-actions bot commented Feb 2, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label (Issues that haven't received updates) on Feb 2, 2023