@@ -112,6 +112,9 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
112112
113113def parse_args ():
114114 parser = argparse .ArgumentParser (description = "Simple example of a training script." )
115+ parser .add_argument (
"--input_pertubation" , type = float , default = 0 , help = "The scale of input perturbation. Recommended 0.1."
117+ )
115118 parser .add_argument (
116119 "--pretrained_model_name_or_path" ,
117120 type = str ,
@@ -801,15 +804,19 @@ def collate_fn(examples):
801804 noise += args .noise_offset * torch .randn (
802805 (latents .shape [0 ], latents .shape [1 ], 1 , 1 ), device = latents .device
803806 )
804-
807+ if args .input_pertubation :
808+ new_noise = noise + args .input_pertubation * torch .randn_like (noise )
805809 bsz = latents .shape [0 ]
806810 # Sample a random timestep for each image
807811 timesteps = torch .randint (0 , noise_scheduler .config .num_train_timesteps , (bsz ,), device = latents .device )
808812 timesteps = timesteps .long ()
809813
810814 # Add noise to the latents according to the noise magnitude at each timestep
811815 # (this is the forward diffusion process)
812- noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
816+ if args .input_pertubation :
817+ noisy_latents = noise_scheduler .add_noise (latents , new_noise , timesteps )
818+ else :
819+ noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
813820
814821 # Get the text embedding for conditioning
815822 encoder_hidden_states = text_encoder (batch ["input_ids" ])[0 ]
0 commit comments