
[enhancement]: Graphcore IPU Support #2120

Closed

brucethemoose opened this issue Dec 24, 2022 · 6 comments

Labels
enhancement (New feature or request), Inactive Issue

Comments


brucethemoose commented Dec 24, 2022

Is there an existing issue for this?

  • I have searched the existing issues

Contact Details

InvokeAI Discord user of the same name.

What should this feature add?

Hello, since y'all are migrating to diffusers anyway (#1583), would you consider adding Graphcore IPU support, as seen in the repo here?

https://www.graphcore.ai/posts/how-to-run-stable-diffusion-inference-on-ipus-with-paperspace

See ipu_models.py in the text-to-image demo; it looks like a fairly simple extension of StableDiffusionPipeline.
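
For a rough sense of what that extension involves, here is a minimal sketch of the wrapping pattern (mirroring the forward shown in the traceback further down; the class name, dtype defaults, and PopTorch options are assumptions rather than the demo's actual code):

```python
import poptorch
import torch


class IPUUNetWrapper(torch.nn.Module):
    """Sketch: compile the diffusers UNet for the IPU with PopTorch and cast
    activations to the IPU dtype around the call (names/dtypes are assumed)."""

    def __init__(self, unet, input_dtype=torch.float16, output_dtype=torch.float32):
        super().__init__()
        self.input_dtype = input_dtype
        self.output_dtype = output_dtype
        # poptorch.inferenceModel traces and compiles the module for the IPU
        # the first time it is called with concrete input shapes.
        self.unet = poptorch.inferenceModel(unet.eval(), poptorch.Options())

    def forward(self, sample, timestep, encoder_hidden_states, return_dict=True):
        sample = sample.to(self.input_dtype)
        encoder_hidden_states = encoder_hidden_states.to(self.input_dtype)
        ret = self.unet(sample, timestep, encoder_hidden_states)
        ret.sample = ret.sample.to(self.output_dtype)
        return ret


# Hypothetical usage: swap the UNet on an existing StableDiffusionPipeline.
# Other attributes the pipeline reads (e.g. unet.in_channels) would still need
# forwarding to the wrapped module.
# pipe.unet = IPUUNetWrapper(pipe.unet)
```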

I mention this because Paperspace is offering some particularly beefy free IPU instances now. For comparison, Nvidia claims the RTX 4090 delivers 165 FP16 tensor TFLOPS.

brucethemoose added the enhancement (New feature or request) label on Dec 24, 2022
@brucethemoose (Author)

(Note that I would make a crude attempt to hack this in and test it myself, but it looks like InvokeAI's diffusers implementation isn't finished yet.)

keturn (Contributor) commented Dec 24, 2022

Their pipeline code is here: https://github.com/gradient-ai/Graphcore-HuggingFace/blob/main/stable-diffusion/ipu_models.py

Seems totally plausible. I just don't know how to do it cleanly yet. So much of this stuff is very subclass-happy -- including the work-in-progress pipeline in #1583 -- and that's not the best for composition. They also monkey with the cross-attention code which might conflict with some of Invoke's own desire to monkey with cross-attention code.

If you want to work on this, I think a good next step would be to take a look at this new diffusers API for cross-attention: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py

See if you can use that to replace Graphcore's current override_attention method. If you can't, raise it as an issue on the diffusers repo, since that cross-attention API isn't released yet and this could be good design feedback for them.
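
To make the suggestion concrete, here is a hedged sketch of what a processor written against that API might look like, computing scores with a plain bmm instead of torch.baddbmm (the IPU blocker discussed below). The class name is hypothetical, the hook names (to_q, head_to_batch_dim, set_processor) come from the work-in-progress cross_attention module and may change before release, and mask handling is simplified:

```python
import torch
from diffusers.models.cross_attention import CrossAttention  # work-in-progress API


class IPUCrossAttnProcessor:
    """Sketch of a cross-attention processor that computes attention scores
    with bmm + explicit scaling rather than torch.baddbmm, which PopTorch
    cannot lower. Hook names follow the linked cross_attention module."""

    def __call__(self, attn: CrossAttention, hidden_states,
                 encoder_hidden_states=None, attention_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # self-attention

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states))

        # scale * Q @ K^T -- equivalent to the baddbmm call with beta=0
        scores = torch.bmm(query, key.transpose(-1, -2)) * attn.scale
        if attention_mask is not None:
            scores = scores + attention_mask  # simplified mask handling
        probs = scores.softmax(dim=-1)

        hidden_states = attn.batch_to_head_dim(torch.bmm(probs, value))
        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states


# Hypothetical wiring, replacing Graphcore's override_attention:
# for module in pipe.unet.modules():
#     if isinstance(module, CrossAttention):
#         module.set_processor(IPUCrossAttnProcessor())
```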

brucethemoose (Author) commented Dec 26, 2022

Hmm, it looks like the comments in that code are accurate: https://github.com/gradient-ai/Graphcore-HuggingFace/blob/main/stable-diffusion/ipu_models.py#L36

> Overriding this implementation as the torch.baddbmm op is not registered.

I can't find it at the moment, but I read in a Graphcore blog post that torch.baddbmm is not supported on the IPU. Replacing either the attention or the sliced-attention override with the diffusers version always ends up executing this call, which trips up the compiler:

https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py#L184

Cross Attention Error
---------------------------------------------------------------------------
  Error                                     Traceback (most recent call last)
  <ipython-input-10-3a9910a550ac> in <module>
  ----> 1 pipe("apple", height=image_height, width=image_width, num_inference_steps=25, guidance_scale=9);
  
  /usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
       26         def decorate_context(*args, **kwargs):
       27             with self.__class__():
  ---> 28                 return func(*args, **kwargs)
       29         return cast(F, decorate_context)
       30 
  
  /usr/local/lib/python3.8/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py in __call__(self, prompt, height, width, num_inference_steps, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, output_type, return_dict, callback, callback_steps)
      527 
      528                 # predict the noise residual
  --> 529                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
      530 
      531                 # perform guidance
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
     1101                 or _global_forward_hooks or _global_forward_pre_hooks):
  -> 1102             return forward_call(*input, **kwargs)
     1103         # Do not call functions when jit is used
     1104         full_backward_hooks, non_full_backward_hooks = [], []
  
  /notebooks/stable-diffusion/ipu_models.py in forward(self, sample, timestep, encoder_hidden_states, return_dict)
      105         encoder_hidden_states = encoder_hidden_states.to(self.input_dtype)
      106 
  --> 107         ret = self.unet(sample, timestep, encoder_hidden_states)
      108 
      109         ret.sample = ret.sample.to(self.output_dtype)
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in __call__(self, *args, **kwargs)
      919 
      920         if not self.isCompiled():
  --> 921             self._compile(in_tensors)
      922 
      923         if not self._is_attached:
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_impl.py in wrapper(self, *args, **kwargs)
      356         def wrapper(self, *args, **kwargs):
      357             with self._profiling.tracepoint(label):  # pylint: disable=protected-access
  --> 358                 return func(self, *args, **kwargs)
      359 
      360         return wrapper
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in _compile(self, in_tensors)
      644                 self._executable = poptorch_core.compileWithTrace(*trace_args)
      645             else:
  --> 646                 self._executable = self._compileWithDispatch(
      647                     in_tensors_trace_view)
      648 
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in _compileWithDispatch(self, in_tensors, executable_filename)
      593                                    **in_tensors.kwargs)
      594             else:
  --> 595                 ctx.compile(*in_tensors.args, **in_tensors.kwargs)
      596             self._outputs_structure = ctx.ipu._outputs_structure  # pylint: disable=protected-access
      597 
  
  /usr/local/lib/python3.8/dist-packages/poptorch/experimental.py in compile(self, *args, **kwargs)
      339 
      340     def compile(self, *args, **kwargs):
  --> 341         return self._compileOrLoadExecutable(args, kwargs)
      342 
      343     def loadExecutable(self, filename, *args, **kwargs):
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_impl.py in wrapper(*args, **kwargs)
      162     def wrapper(*args, **kwargs):
      163         with OnExit():
  --> 164             return func(*args, **kwargs)
      165 
      166     return wrapper
  
  /usr/local/lib/python3.8/dist-packages/poptorch/experimental.py in _compileOrLoadExecutable(self, args, kwargs, filename)
      380                                                       tensor_args)
      381 
  --> 382             result = self.func(*args, **kwargs)
      383             if result is not None:
      384                 ipu.outputs(result)
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
     1101                 or _global_forward_hooks or _global_forward_pre_hooks):
  -> 1102             return forward_call(*input, **kwargs)
     1103         # Do not call functions when jit is used
     1104         full_backward_hooks, non_full_backward_hooks = [], []
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_condition.py in forward(self, sample, timestep, encoder_hidden_states, class_labels, attention_mask, return_dict)
      422         for downsample_block in self.down_blocks:
      423             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
  --> 424                 sample, res_samples = downsample_block(
      425                     hidden_states=sample,
      426                     temb=emb,
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py in forward(self, hidden_states, temb, encoder_hidden_states, attention_mask)
      775             else:
      776                 hidden_states = resnet(hidden_states, temb)
  --> 777                 hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
      778 
      779             output_states += (hidden_states,)
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, timestep, return_dict)
      214         # 2. Blocks
      215         for block in self.transformer_blocks:
  --> 216             hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
      217 
      218         # 3. Output
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, timestep, attention_mask)
      488             )
      489         else:
  --> 490             hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
      491 
      492         if self.attn2 is not None:
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, attention_mask)
      638                 hidden_states = self._attention(query, key, value, attention_mask)
      639             else:
  --> 640                 hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
      641 
      642         # linear proj
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask)
      695                 key_slice = key_slice.float()
      696 
  --> 697             attn_slice = torch.baddbmm(
      698                 torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
      699                 query_slice,
  
  Error: In poptorch/source/dispatch_tracer/dispatchers/MLIRDispatch.cpp:372: 'poptorch_cpp_error': No shape inference handler for aten::baddbmm

Not sure if there is an easy way to work around that.
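
One possible workaround (a hedged sketch, not tested on an IPU): in the diffusers source the baddbmm's first argument is an uninitialised torch.empty tensor used with beta=0, and the scaling comes from alpha=self.scale, so the call reduces to a scaled bmm that the poptorch dispatcher does know how to handle. A local patch of _sliced_attention could swap the call for a helper like this:

```python
import torch


def scaled_attention_scores(query_slice: torch.Tensor,
                            key_slice: torch.Tensor,
                            scale: float) -> torch.Tensor:
    """Equivalent to diffusers' torch.baddbmm(torch.empty(...), query_slice,
    key_slice.transpose(-1, -2), beta=0, alpha=scale): with beta=0 the empty
    tensor never contributes, so bmm plus scaling gives the same result
    without hitting the missing aten::baddbmm shape-inference handler."""
    return torch.bmm(query_slice, key_slice.transpose(-1, -2)) * scale
```

Whether the rest of the sliced-attention path then compiles cleanly is a separate question; the processor sketch above sidesteps it by replacing the whole attention computation.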

@Lime-Cakes

I worked with the mentioned code a few weeks ago (on optimum, but with the same attention override). The code targets the older attention implementation (before diffusers refactored much of the attention code). It should be possible to get it working without too much hassle, since it just replaces attention with simpler/less efficient code.

There is another major problem that requires solving, though: you will need to compile the model, and cache the compiled model somewhere, for every possible resolution.

The first time the model is run, it compiles itself. This process takes around 15 minutes on the free Pod16 machine offered a while ago. Once it's compiled, feeding it a different resolution triggers an error, and the only way to avoid that error is to delete the model and recompile.

The recompile can be skipped if you load a previously compiled model. So unless you want the user to wait 15 minutes for an image every time they pick a different resolution, caching all possible resolutions is essentially the only option.
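
To illustrate the caching scheme, a rough sketch (the helper, the in-memory dict, and the reliance on PopTorch's on-disk executable cache are assumptions, not the demo's actual code):

```python
import poptorch

# (height, width) -> UNet compiled for that shape; compiled at most once per size.
_compiled_unets = {}


def unet_for_resolution(build_unet, height, width, cache_dir="./poptorch_cache"):
    """Return a PopTorch-compiled UNet for this resolution. The first call per
    size pays the full compile; later runs can reuse the on-disk executable."""
    key = (height, width)
    if key not in _compiled_unets:
        opts = poptorch.Options()
        # Reuse a previously compiled executable from disk when the traced
        # graph (and therefore the resolution) matches, instead of recompiling.
        opts.enableExecutableCaching(cache_dir)
        _compiled_unets[key] = poptorch.inferenceModel(build_unet().eval(), opts)
    return _compiled_unets[key]
```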

brucethemoose (Author) commented Dec 28, 2022

> I worked with the mentioned code a few weeks ago (on optimum, but with the same attention override). The code targets the older attention implementation (before diffusers refactored much of the attention code). It should be possible to get it working without too much hassle, since it just replaces attention with simpler/less efficient code.
>
> There is another major problem that requires solving, though: you will need to compile the model, and cache the compiled model somewhere, for every possible resolution.
>
> The first time the model is run, it compiles itself. This process takes around 15 minutes on the free Pod16 machine offered a while ago. Once it's compiled, feeding it a different resolution triggers an error, and the only way to avoid that error is to delete the model and recompile.
>
> The recompile can be skipped if you load a previously compiled model. So unless you want the user to wait 15 minutes for an image every time they pick a different resolution, caching all possible resolutions is essentially the only option.

Yes, I noticed this on the Paperspace Pod4 demo more recently. I also noticed that it only seems to compile on a single thread, and that changing num_images_per_prompt to 2 takes so long to compile that the kernel times out (after ~30 min) before it finishes.

The free Pod4 instance is a 56-thread VM with gobs of RAM, so in a hypothetical Paperspace notebook, a set of resolution/model/batch combinations could be selected via the UI at the start, compiled in parallel at startup, and stored. But if that is too difficult, I would be content with a single resolution/model combination for each existing pipeline.
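
To illustrate the parallel pre-compilation idea, a rough sketch (assuming PopTorch's compileAndExport/poptorch.load precompilation APIs and hypothetical build_unet/make_dummy_inputs helpers; whether several Poplar compiles can safely run concurrently on one host is untested here):

```python
from concurrent.futures import ProcessPoolExecutor

import poptorch


def precompile_unet(args):
    """Compile the UNet for one (height, width, batch) combination in its own
    process and export the executable to disk for later poptorch.load()."""
    height, width, batch = args
    unet = build_unet()  # hypothetical factory returning a fresh diffusers UNet
    sample, timestep, text_emb = make_dummy_inputs(height, width, batch)  # hypothetical
    compiled = poptorch.inferenceModel(unet.eval(), poptorch.Options())
    path = f"unet_{height}x{width}_b{batch}.popef"
    compiled.compileAndExport(path, sample, timestep, text_emb)
    return path


combos = [(512, 512, 1), (512, 768, 1), (768, 768, 1)]
with ProcessPoolExecutor(max_workers=len(combos)) as pool:
    executables = list(pool.map(precompile_unet, combos))
# Later: compiled_unet = poptorch.load("unet_512x512_b1.popef")
```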

On the topic of an InvokeAI Paperspace notebook: instead of recreating the UI, I think a workaround like this would allow the user to access the notebook directly:

https://nni.readthedocs.io/en/stable/sharings/nni_colab_support.html

@github-actions

There has been no activity in this issue for 14 days. If this issue is still being experienced, please reply with an updated confirmation that the issue is still being experienced with the latest release.

Millu closed this as not planned on Feb 14, 2024