diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py index a42187fadea1..12ff40bbd680 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py @@ -568,7 +568,9 @@ def transform_images(examples): clean_images = batch["input"] # Sample noise that we'll add to the images - noise = torch.randn(clean_images.shape).to(clean_images.device) + noise = torch.randn( + clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16) + ).to(clean_images.device) bsz = clean_images.shape[0] # Sample a random timestep for each image timesteps = torch.randint( diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index d6e4b17ba889..e10e6d302457 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -557,7 +557,9 @@ def transform_images(examples): clean_images = batch["input"] # Sample noise that we'll add to the images - noise = torch.randn(clean_images.shape).to(clean_images.device) + noise = torch.randn( + clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16) + ).to(clean_images.device) bsz = clean_images.shape[0] # Sample a random timestep for each image timesteps = torch.randint(