From 8baa5c2da9f78a09a2a25b1ae843989f7579fad7 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 7 Jun 2023 21:20:06 +0000
Subject: [PATCH 1/7] avoid upcasting by assigning dtype to noise tensor

---
 .../unconditional_image_generation/train_unconditional.py      | 2 +-
 examples/unconditional_image_generation/train_unconditional.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
index 1b38036d82c0..92bbb3769c4f 100644
--- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
+++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -500,7 +500,7 @@ def transform_images(examples):
 
             clean_images = batch["input"]
             # Sample noise that we'll add to the images
-            noise = torch.randn(clean_images.shape).to(clean_images.device)
+            noise = torch.randn(clean_images.shape, dtype=(torch.float32 if args.mixed_precision=="no" else torch.float16)).to(clean_images.device)
             bsz = clean_images.shape[0]
             # Sample a random timestep for each image
             timesteps = torch.randint(
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 282f52101a3c..ffd5c69224d6 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -562,7 +562,7 @@ def transform_images(examples):
 
             clean_images = batch["input"]
             # Sample noise that we'll add to the images
-            noise = torch.randn(clean_images.shape).to(clean_images.device)
+            noise = torch.randn(clean_images.shape, dtype=(torch.float32 if args.mixed_precision=="no" else torch.float16)).to(clean_images.device)
             bsz = clean_images.shape[0]
             # Sample a random timestep for each image
             timesteps = torch.randint(

From 1c2617d19d81bbdbea3c661a75b9db0b5d1b294d Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 7 Jun 2023 21:27:32 +0000
Subject: [PATCH 2/7] make style

---
 .../unconditional_image_generation/train_unconditional.py | 4 +++-
 .../unconditional_image_generation/train_unconditional.py | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
index 92bbb3769c4f..f572d5e787d8 100644
--- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
+++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -500,7 +500,9 @@ def transform_images(examples):
 
             clean_images = batch["input"]
             # Sample noise that we'll add to the images
-            noise = torch.randn(clean_images.shape, dtype=(torch.float32 if args.mixed_precision=="no" else torch.float16)).to(clean_images.device)
+            noise = torch.randn(
+                clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
+            ).to(clean_images.device)
             bsz = clean_images.shape[0]
             # Sample a random timestep for each image
             timesteps = torch.randint(
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index ffd5c69224d6..4969b72a6cba 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -562,7 +562,9 @@ def transform_images(examples):
 
             clean_images = batch["input"]
             # Sample noise that we'll add to the images
-            noise = torch.randn(clean_images.shape, dtype=(torch.float32 if args.mixed_precision=="no" else torch.float16)).to(clean_images.device)
+            noise = torch.randn(
+                clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
+            ).to(clean_images.device)
             bsz = clean_images.shape[0]
             # Sample a random timestep for each image
             timesteps = torch.randint(

From b0222f079924a8411afd64598d094d4ad867fe0b Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Thu, 22 Jun 2023 16:18:06 -0700
Subject: [PATCH 3/7] Update train_unconditional.py

---
 .../unconditional_image_generation/train_unconditional.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
index f572d5e787d8..2ab58b69ef31 100644
--- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
+++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -501,7 +501,8 @@ def transform_images(examples):
             clean_images = batch["input"]
             # Sample noise that we'll add to the images
             noise = torch.randn(
-                clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
+                clean_images.shape,
+                dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
             ).to(clean_images.device)
             bsz = clean_images.shape[0]
             # Sample a random timestep for each image

From 414cfb99cc1be98937e885708d5e691956d75eb7 Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Thu, 22 Jun 2023 16:18:32 -0700
Subject: [PATCH 4/7] Update train_unconditional.py

---
 examples/unconditional_image_generation/train_unconditional.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 4969b72a6cba..47fc87a8bf8c 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -563,7 +563,8 @@ def transform_images(examples):
             clean_images = batch["input"]
             # Sample noise that we'll add to the images
             noise = torch.randn(
-                clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
+                clean_images.shape,
+                dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
             ).to(clean_images.device)
             bsz = clean_images.shape[0]
             # Sample a random timestep for each image
             timesteps = torch.randint(

From be73fe34945fc88ccdea2b48c28d9353990c9458 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 28 Jun 2023 19:40:07 +0000
Subject: [PATCH 5/7] make style

---
 .../unconditional_image_generation/train_unconditional.py      | 3 +--
 examples/unconditional_image_generation/train_unconditional.py | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
index 84d98475e577..12ff40bbd680 100644
--- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
+++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -569,8 +569,7 @@ def transform_images(examples):
             clean_images = batch["input"]
             # Sample noise that we'll add to the images
             noise = torch.randn(
-                clean_images.shape,
-                dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
+                clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
             ).to(clean_images.device)
             bsz = clean_images.shape[0]
             # Sample a random timestep for each image
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index de8d8358de19..e10e6d302457 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -558,8 +558,7 @@ def transform_images(examples):
             clean_images = batch["input"]
             # Sample noise that we'll add to the images
             noise = torch.randn(
-                clean_images.shape,
-                dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
+                clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
             ).to(clean_images.device)
             bsz = clean_images.shape[0]
             # Sample a random timestep for each image

From b64d361f9e85b9d3c194b00ac660b27c3f43d158 Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Wed, 28 Jun 2023 21:27:00 +0000
Subject: [PATCH 6/7] add unit test for pickle

---
 tests/models/test_models_unet_2d_condition.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index 24da508227d2..6e64bf8c0055 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
+from dataclasses import dataclass
 import gc
 import os
 import tempfile
@@ -33,6 +35,7 @@
     torch_all_close,
     torch_device,
 )
+from diffusers.utils import BaseOutput
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism
 
@@ -1088,3 +1091,13 @@ def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
         expected_output_slice = torch.tensor(expected_slice)
 
         assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
+
+    def test_pickle():
+        @dataclass
+        class NetParams(BaseOutput):
+            sample: torch.FloatTensor
+
+        m = NetParams(sample=torch.randn(1, 10))
+        n = copy.copy(m)
+
+        assert m == n
\ No newline at end of file

From 60446d45c183082da7355da3904d31a377174b04 Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Wed, 28 Jun 2023 21:28:22 +0000
Subject: [PATCH 7/7] revert change

---
 tests/models/test_models_unet_2d_condition.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index 6e64bf8c0055..24da508227d2 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
-from dataclasses import dataclass
 import gc
 import os
 import tempfile
@@ -35,7 +33,6 @@
     torch_all_close,
     torch_device,
 )
-from diffusers.utils import BaseOutput
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism
 
@@ -1091,13 +1088,3 @@ def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
         expected_output_slice = torch.tensor(expected_slice)
 
         assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
-
-    def test_pickle():
-        @dataclass
-        class NetParams(BaseOutput):
-            sample: torch.FloatTensor
-
-        m = NetParams(sample=torch.randn(1, 10))
-        n = copy.copy(m)
-
-        assert m == n
\ No newline at end of file
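
Note on the change carried through patches 1-5: the scripts now sample the noise tensor directly in the training dtype (float32 when mixed precision is disabled, float16 otherwise) instead of always sampling in float32, which avoids upcasting the half-precision images when the noise is added. Below is a minimal standalone sketch of that pattern; the Args container and the image batch are placeholders standing in for the scripts' argparse namespace and dataloader output, not part of the diffs above.

import torch

# Placeholder for the parsed CLI arguments; the real scripts build this with argparse.
class Args:
    mixed_precision = "fp16"  # the patches only check whether this equals "no"

args = Args()
clean_images = torch.randn(4, 3, 64, 64, dtype=torch.float16)  # placeholder image batch

# Choose the noise dtype from the mixed-precision setting instead of defaulting to float32.
noise_dtype = torch.float32 if args.mixed_precision == "no" else torch.float16
noise = torch.randn(clean_images.shape, dtype=noise_dtype).to(clean_images.device)

# The noisy images stay in float16 here rather than being silently upcast to float32.
noisy_images = clean_images + noise
print(noisy_images.dtype)  # torch.float16 with the settings above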