
Error "NotImplementedError: Cannot copy out of meta tensor; no data!" on some tests #2500


Closed
antoche opened this issue Feb 26, 2023 · 7 comments · Fixed by #2517
Labels
bug Something isn't working

Comments


antoche commented Feb 26, 2023

### Describe the bug

I am hitting this exception in various tests:

```
self = <tests.pipelines.unclip.test_unclip.UnCLIPPipelineFastTests testMethod=test_cpu_offload_forward_pass>

    @unittest.skipIf(
        torch_device != "cuda" or not is_accelerate_available(),
        reason="CPU offload is only available with CUDA and `accelerate` installed",
    )
    def test_cpu_offload_forward_pass(self):
        if not self.test_cpu_offload:
            return
    
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
    
        inputs = self.get_dummy_inputs(torch_device)
        output_without_offload = pipe(**inputs)[0]
    
        pipe.enable_sequential_cpu_offload()
        inputs = self.get_dummy_inputs(torch_device)
>       output_with_offload = pipe(**inputs)[0]

tests/test_pipelines_common.py:494: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/vol/apps/python/3.9/ext_modules/pytorch/1.12.1/cuda/11.4/torch/autograd/grad_mode.py:27: in decorate_context
    return func(*args, **kwargs)
src/diffusers/pipelines/unclip/pipeline_unclip.py:397: in __call__
    text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
/vol/apps/python/3.9/ext_modules/pytorch/1.12.1/cuda/11.4/torch/nn/modules/module.py:1130: in _call_impl
    return forward_call(*input, **kwargs)
/vol/apps/python/3.9/ext_modules/pyaccelerate/0.13.1/accelerate/hooks.py:148: in new_forward
    output = old_forward(*args, **kwargs)
src/diffusers/pipelines/unclip/text_proj.py:73: in forward
    time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
/vol/apps/python/3.9/ext_modules/pytorch/1.12.1/cuda/11.4/torch/nn/modules/module.py:1130: in _call_impl
    return forward_call(*input, **kwargs)
/vol/apps/python/3.9/ext_modules/pyaccelerate/0.13.1/accelerate/hooks.py:143: in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
/vol/apps/python/3.9/ext_modules/pyaccelerate/0.13.1/accelerate/hooks.py:252: in pre_forward
    return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
/vol/apps/python/3.9/ext_modules/pyaccelerate/0.13.1/accelerate/utils/operations.py:126: in send_to_device
    return recursively_apply(_send_to_device, tensor, device, test_type=_has_to_method)
/vol/apps/python/3.9/ext_modules/pyaccelerate/0.13.1/accelerate/utils/operations.py:78: in recursively_apply
    return honor_type(
/vol/apps/python/3.9/ext_modules/pyaccelerate/0.13.1/accelerate/utils/operations.py:49: in honor_type
    return type(obj)(generator)
/vol/apps/python/3.9/ext_modules/pyaccelerate/0.13.1/accelerate/utils/operations.py:81: in <genexpr>
    recursively_apply(
/vol/apps/python/3.9/ext_modules/pyaccelerate/0.13.1/accelerate/utils/operations.py:97: in recursively_apply
    return func(data, *args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

t = tensor(..., device='meta', size=(2, 32)), device = device(type='cuda', index=0)

    def _send_to_device(t, device):
>       return t.to(device)
E       NotImplementedError: Cannot copy out of meta tensor; no data!

/vol/apps/python/3.9/ext_modules/pyaccelerate/0.13.1/accelerate/utils/operations.py:121: NotImplementedError
```

Specifically, on:
- tests/pipelines/paint_by_example/test_paint_by_example.py::PaintByExamplePipelineFastTests::test_cpu_offload_forward_pass
- tests/pipelines/unclip/test_unclip.py::UnCLIPPipelineFastTests::test_cpu_offload_forward_pass
- tests/pipelines/unclip/test_unclip_image_variation.py::UnCLIPImageVariationPipelineFastTests::test_cpu_offload_forward_pass
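
The failure is easy to reproduce in isolation: a module's weights were left on the meta device, and `send_to_device` then tried to copy a tensor that has no backing storage. A minimal sketch (the exact wording of the message varies across PyTorch versions):

```python
import torch

# A "meta" tensor carries shape and dtype metadata but allocates no storage.
t = torch.empty(2, 32, device="meta")
print(t.shape)  # torch.Size([2, 32])

# Copying it to a real device fails, since there is no data to copy:
try:
    t.to("cpu")
except NotImplementedError as e:
    print(e)  # message mentions the meta tensor having no data
```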


### Reproduction

Simply running the tests on the repo.

### Logs

```shell
See above.
```
### System Info

Tried on multiple linux machines with various Nvidia GPUs.

Running from branch v0.13.1

  • diffusers version: 0.13.1
  • Platform: Linux-4.14.240-weta-20210804-x86_64-with-glibc2.27
  • Python version: 3.9.10
  • PyTorch version (GPU?): 1.12.0a0+git664058f (True)
  • Huggingface_hub version: 0.11.1
  • Transformers version: 4.26.0
  • Accelerate version: 0.13.1
  • xFormers version: 0.0.14.dev
@antoche antoche added the bug Something isn't working label Feb 26, 2023

pcuenca commented Feb 27, 2023

I could reproduce this with accelerate version 0.13.1. I'm currently not aware of a limitation of accelerate regarding CPU offloading, but if that's the case we should verify the minimum version before use.

@patrickvonplaten @muellerzr do you happen to have any insight here?
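
A minimum-version gate of the kind suggested above could be sketched like this. This is a hypothetical helper, not diffusers code; a real implementation would use `packaging.version` (which handles pre-release tags like `0.0.14.dev`), while this simplified version assumes plain `x.y.z` strings. The `0.14.0` floor matches the accelerate release later reported as carrying the fix:

```python
# Hypothetical minimum-version check (simplified: plain "x.y.z" only).
def version_tuple(version: str) -> tuple:
    return tuple(int(part) for part in version.split(".")[:3])

def offload_supported(installed: str, minimum: str = "0.14.0") -> bool:
    return version_tuple(installed) >= version_tuple(minimum)

print(offload_supported("0.13.1"))  # False: the release reported in this issue
print(offload_supported("0.14.0"))  # True
```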

muellerzr (Contributor) commented:

cc @sgugger


sgugger commented Feb 28, 2023

There are no limitations per se, but bugs were fixed since 0.13.1. Is this an issue that is still ongoing?


pcuenca commented Feb 28, 2023

> There are no limitations per se, but bugs were fixed since 0.13.1. Is this an issue that is still ongoing?

No, it was fixed soon after, in 0.14.0. This is the first report we have received; the OP found it while running the tests. Occurrence was rare because it required offloading, an old version of accelerate, and the safety checker being enabled.
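
For context on where the failure bites: sequential CPU offload keeps weights on CPU and moves each submodule to the execution device only around its forward call (the failing `pre_forward` in the traceback is the step that moves the inputs). A rough standalone sketch of that mechanism, not accelerate's actual implementation, with `"cpu"` standing in for the GPU so it runs anywhere:

```python
import torch
from torch import nn

execution_device = "cpu"  # would be "cuda" in a real offload setup

def offload_wrap(module: nn.Module) -> nn.Module:
    # Sketch of sequential offload: swap in a forward that loads the
    # module's weights onto the execution device, runs, then releases them.
    old_forward = module.forward

    def new_forward(*args, **kwargs):
        module.to(execution_device)  # load weights for this call
        output = old_forward(*args, **kwargs)
        module.to("cpu")             # move them back afterwards
        return output

    module.forward = new_forward
    return module

layer = offload_wrap(nn.Linear(32, 2))
out = layer(torch.randn(2, 32))
print(out.shape)  # torch.Size([2, 2])
```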

gabgiani commented:

Can someone who has the training process running confirm the versions of transformers, torch, and accelerate they used, to avoid this "cannot copy out of meta tensor" error?

Thanks.

lianming03 commented:

Hello, it is possible that the error was caused by transformers==4.26.0. Please try installing version 4.28.0 and check if it resolves the issue.

pcdilley commented:

I'm getting this error with accelerate==0.29.2 and transformers==4.39.3.
