Skip to content

Commit f3d570c

Browse files
hari10599IliaLarchenkopatrickvonplatenyiyixuxuyiyixuxu
authored
feat: allow disk offload for diffuser models (#3285)
* allow disk offload for diffuser models * sort import * add max_memory argument * Changed sample[0] to images[0] (#3304) A pipeline object stores the results in `images` not in `sample`. Current code blocks don't work. * Typo in tutorial (#3295) * Torch compile graph fix (#3286) * fix more * Fix more * fix more * Apply suggestions from code review * fix * make style * make fix-copies * fix * make sure torch compile * Clean * fix test * Postprocessing refactor img2img (#3268) * refactor img2img VaeImageProcessor.postprocess * remove copy from for init, run_safety_checker, decode_latents Co-authored-by: Sayak Paul <[email protected]> --------- Co-authored-by: yiyixuxu <[email protected]> Co-authored-by: Sayak Paul <[email protected]> * [Torch 2.0 compile] Fix more torch compile breaks (#3313) * Fix more torch compile breaks * add tests * Fix all * fix controlnet * fix more * Add Horace He as co-author. > > Co-authored-by: Horace He <[email protected]> * Add Horace He as co-author. Co-authored-by: Horace He <[email protected]> --------- Co-authored-by: Horace He <[email protected]> * fix: scale_lr and sync example readme and docs. (#3299) * fix: scale_lr and sync example readme and docs. * fix doc link. * Update stable_diffusion.mdx (#3310) fixed import statement * Fix missing variable assign in DeepFloyd-IF-II (#3315) Fix missing variable assign lol * Correct doc build for patch releases (#3316) Update build_documentation.yml * Add Stable Diffusion RePaint to community pipelines (#3320) * Add Stable Diffsuion RePaint to community pipelines - Adds Stable Diffsuion RePaint to community pipelines - Add Readme enty for pipeline * Fix: Remove wrong import - Remove wrong import - Minor change in comments * Fix: Code formatting of stable_diffusion_repaint * Fix: ruff errors in stable_diffusion_repaint * Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (#3314) * fix multistep dpmsolver for cosine schedule (deepfloy-if) * fix a typo * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule * add test, fix style --------- Co-authored-by: Patrick von Platen <[email protected]> * [docs] Improve LoRA docs (#3311) * update docs * add to toctree * apply feedback * Added input pretubation (#3292) * Added input pretubation * Fixed spelling * Update write_own_pipeline.mdx (#3323) * update controlling generation doc with latest goodies. (#3321) * [Quality] Make style (#3341) * Fix config dpm (#3343) * Add the SDE variant of DPM-Solver and DPM-Solver++ (#3344) * add SDE variant of DPM-Solver and DPM-Solver++ * add test * fix typo * fix typo * Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275) The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument. * Rename --only_save_embeds to --save_as_full_pipeline (#3206) * Set --only_save_embeds to False by default Due to how the option is named, it makes more sense to behave like this. * Refactor only_save_embeds to save_as_full_pipeline * [AudioLDM] Generalise conversion script (#3328) Co-authored-by: Patrick von Platen <[email protected]> * Fix TypeError when using prompt_embeds and negative_prompt (#2982) * test: Added test case * fix: fixed type checking issue on _encode_prompt * fix: fixed copies consistency * fix: one copy was not sufficient * Fix pipeline class on README (#3345) Update README.md * Inpainting: typo in docs (#3331) Typo in docs Co-authored-by: Patrick von Platen <[email protected]> * Add `use_Karras_sigmas` to LMSDiscreteScheduler (#3351) * add karras sigma to lms discrete scheduler * add test for lms_scheduler karras * reformat test lms * Batched load of textual inversions (#3277) * Batched load of textual inversions - Only call resize_token_embeddings once per batch as it is the most expensive operation - Allow pretrained_model_name_or_path and token to be an optional list - Remove Dict from type annotation pretrained_model_name_or_path as it was not supported in this function - Add comment that single files (e.g. .pt/.safetensors) are supported - Add comment for token parameter - Convert token override log message from warning to info * Update src/diffusers/loaders.py Check for duplicate tokens Co-authored-by: Patrick von Platen <[email protected]> * Update condition for None tokens --------- Co-authored-by: Patrick von Platen <[email protected]> * make fix-copies * [docs] Fix docstring (#3334) fix docstring Co-authored-by: Patrick von Platen <[email protected]> * if dreambooth lora (#3360) * update IF stage I pipelines add fixed variance schedulers and lora loading * added kv lora attn processor * allow loading into alternative lora attn processor * make vae optional * throw away predicted variance * allow loading into added kv lora layer * allow load T5 * allow pre compute text embeddings * set new variance type in schedulers * fix copies * refactor all prompt embedding code class prompts are now included in pre-encoding code max tokenizer length is now configurable embedding attention mask is now configurable * fix for when variance type is not defined on scheduler * do not pre compute validation prompt if not present * add example test for if lora dreambooth * add check for train text encoder and pre compute text embeddings * Postprocessing refactor all others (#3337) * add text2img * fix-copies * add * add all other pipelines * add * add * add * add * add * make style * style + fix copies --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> * [docs] Improve safetensors docstring (#3368) * clarify safetensor docstring * fix typo * apply feedback * add: a warning message when using xformers in a PT 2.0 env. (#3365) * add: a warning message when using xformers in a PT 2.0 env. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> --------- Co-authored-by: Patrick von Platen <[email protected]> * StableDiffusionInpaintingPipeline - resize image w.r.t height and width (#3322) * StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t to passed input height and width. Default is already set to 512. This addresses the common tensor mismatch error. Also moved type check into relevant funciton to keep main pipeline body tidy. * Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution * Added a resolution test to StableDiffusionInpaintPipelineSlowTests this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width --------- Co-authored-by: Patrick von Platen <[email protected]> * make style * [docs] Adapt a model (#3326) * first draft * apply feedback * conv_in.weight thrown away * [docs] Load safetensors (#3333) * safetensors * apply feedback * apply feedback * Apply suggestions from code review --------- Co-authored-by: Patrick von Platen <[email protected]> * make style * [Docs] Fix stable_diffusion.mdx typo (#3398) Fix typo in last code block. Correct "prommpts" to "prompt" * Support ControlNet v1.1 shuffle properly (#3340) * add inferring_controlnet_cond_batch * Revert "add inferring_controlnet_cond_batch" This reverts commit abe8d63. * set guess_mode to True whenever global_pool_conditions is True Co-authored-by: Patrick von Platen <[email protected]> * nit * add integration test --------- Co-authored-by: Patrick von Platen <[email protected]> * [Tests] better determinism (#3374) * enable deterministic pytorch and cuda operations. * disable manual seeding. * make style && make quality for unet_2d tests. * enable determinism for the unet2dconditional model. * add CUBLAS_WORKSPACE_CONFIG for better reproducibility. * relax tolerance (very weird issue, though). * revert to torch manual_seed() where needed. * relax more tolerance. * better placement of the cuda variable and relax more tolerance. * enable determinism for 3d condition model. * relax tolerance. * add: determinism to alt_diffusion. * relax tolerance for alt diffusion. * dance diffusion. * dance diffusion is flaky. * test_dict_tuple_outputs_equivalent edit. * fix two more tests. * fix more ddim tests. * fix: argument. * change to diff in place of difference. * fix: test_save_load call. * test_save_load_float16 call. * fix: expected_max_diff * fix: paint by example. * relax tolerance. * add determinism to 1d unet model. * torch 2.0 regressions seem to be brutal * determinism to vae. * add reason to skipping. * up tolerance. * determinism to vq. * determinism to cuda. * determinism to the generic test pipeline file. * refactor general pipelines testing a bit. * determinism to alt diffusion i2i * up tolerance for alt diff i2i and audio diff * up tolerance. * determinism to audioldm * increase tolerance for audioldm lms. * increase tolerance for paint by paint. * increase tolerance for repaint. * determinism to cycle diffusion and sd 1. * relax tol for cycle diffusion 🚲 * relax tol for sd 1.0 * relax tol for controlnet. * determinism to img var. * relax tol for img variation. * tolerance to i2i sd * make style * determinism to inpaint. * relax tolerance for inpaiting. * determinism for inpainting legacy * relax tolerance. * determinism to instruct pix2pix * determinism to model editing. * model editing tolerance. * panorama determinism * determinism to pix2pix zero. * determinism to sag. * sd 2. determinism * sd. tolerance * disallow tf32 matmul. * relax tolerance is all you need. * make style and determinism to sd 2 depth * relax tolerance for depth. * tolerance to diffedit. * tolerance to sd 2 inpaint. * up tolerance. * determinism in upscaling. * tolerance in upscaler. * more tolerance relaxation. * determinism to v pred. * up tol for v_pred * unclip determinism * determinism to unclip img2img * determinism to text to video. * determinism to last set of tests * up tol. * vq cumsum doesn't have a deterministic kernel * relax tol * relax tol * [docs] Add transformers to install (#3388) add transformers to install * [deepspeed] partial ZeRO-3 support (#3076) * [deepspeed] partial ZeRO-3 support * cleanup * improve deepspeed fixes * Improve * make style --------- Co-authored-by: Patrick von Platen <[email protected]> * Add omegaconf for tests (#3400) Add omegaconfg * Fix various bugs with LoRA Dreambooth and Dreambooth script (#3353) * Improve checkpointing lora * fix more * Improve doc string * Update src/diffusers/loaders.py * make stytle * Apply suggestions from code review * Update src/diffusers/loaders.py * Apply suggestions from code review * Apply suggestions from code review * better * Fix all * Fix multi-GPU dreambooth * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Fix all * make style * make style --------- Co-authored-by: Pedro Cuenca <[email protected]> * Fix docker file (#3402) * up * up * fix: deepseepd_plugin retrieval from accelerate state (#3410) * [Docs] Add `sigmoid` beta_scheduler to docstrings of relevant Schedulers (#3399) * Add `sigmoid` beta scheduler to `DDPMScheduler` docstring * Add `sigmoid` beta scheduler to `RePaintScheduler` docstring --------- Co-authored-by: Patrick von Platen <[email protected]> * Don't install accelerate and transformers from source (#3415) * Don't install transformers and accelerate from source (#3414) * Improve fast tests (#3416) Update pr_tests.yml * attention refactor: the trilogy (#3387) * Replace `AttentionBlock` with `Attention` * use _from_deprecated_attn_block check re: @patrickvonplaten * [Docs] update the PT 2.0 optimization doc with latest findings (#3370) * add: benchmarking stats for A100 and V100. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * address patrick's comments. * add: rtx 4090 stats * ⚔ benchmark reports done * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * 3313 pr link. * add: plots. Co-authored-by: Pedro <[email protected]> * fix formattimg * update number percent. --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> * Fix style rendering (#3433) * Fix style rendering. * Fix typo * unCLIP scheduler do not use note (#3417) * Replace deprecated command with environment file (#3409) Co-authored-by: Patrick von Platen <[email protected]> * fix warning message pipeline loading (#3446) * add stable diffusion tensorrt img2img pipeline (#3419) * add stable diffusion tensorrt img2img pipeline Signed-off-by: Asfiya Baig <[email protected]> * update docstrings Signed-off-by: Asfiya Baig <[email protected]> --------- Signed-off-by: Asfiya Baig <[email protected]> * Refactor controlnet and add img2img and inpaint (#3386) * refactor controlnet and add img2img and inpaint * First draft to get pipelines to work * make style * Fix more * Fix more * More tests * Fix more * Make inpainting work * make style and more tests * Apply suggestions from code review * up * make style * Fix imports * Fix more * Fix more * Improve examples * add test * Make sure import is correctly deprecated * Make sure everything works in compile mode * make sure authorship is correctly attributed * [Scheduler] DPM-Solver (++) Inverse Scheduler (#3335) * Add DPM-Solver Multistep Inverse Scheduler * Add draft tests for DiffEdit * Add inverse sde-dpmsolver steps to tune image diversity from inverted latents * Fix tests --------- Co-authored-by: Patrick von Platen <[email protected]> * [Docs] Fix incomplete docstring for resnet.py (#3438) Fix incomplete docstrings for resnet.py * fix tiled vae blend extent range (#3384) fix tiled vae bleand extent range * Small update to "Next steps" section (#3443) Small update to "Next steps" section: - PyTorch 2 is recommended. - Updated improvement figures. * Allow arbitrary aspect ratio in IFSuperResolutionPipeline (#3298) * Update pipeline_if_superresolution.py Allow arbitrary aspect ratio in IFSuperResolutionPipeline by using the input image shape * IFSuperResolutionPipeline: allow the user to override the height and width through the arguments * update IFSuperResolutionPipeline width/height doc string to match StableDiffusionInpaintPipeline conventions --------- Co-authored-by: Patrick von Platen <[email protected]> * Adding 'strength' parameter to StableDiffusionInpaintingPipeline (#3424) * Added explanation of 'strength' parameter * Added get_timesteps function which relies on new strength parameter * Added `strength` parameter which defaults to 1. * Swapped ordering so `noise_timestep` can be calculated before masking the image this is required when you aren't applying 100% noise to the masked region, e.g. strength < 1. * Added strength to check_inputs, throws error if out of range * Changed `prepare_latents` to initialise latents w.r.t strength inspired from the stable diffusion img2img pipeline, init latents are initialised by converting the init image into a VAE latent and adding noise (based upon the strength parameter passed in), e.g. random when strength = 1, or the init image at strength = 0. * WIP: Added a unit test for the new strength parameter in the StableDiffusionInpaintingPipeline still need to add correct regression values * Created a is_strength_max to initialise from pure random noise * Updated unit tests w.r.t new strength parameter + fixed new strength unit test * renamed parameter to avoid confusion with variable of same name * Updated regression values for new strength test - now passes * removed 'copied from' comment as this method is now different and divergent from the cpy * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py Co-authored-by: Patrick von Platen <[email protected]> * Ensure backwards compatibility for prepare_mask_and_masked_image created a return_image boolean and initialised to false * Ensure backwards compatibility for prepare_latents * Fixed copy check typo * Fixes w.r.t backward compibility changes * make style * keep function argument ordering same for backwards compatibility in callees with copied from statements * make fix-copies --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: William Berman <[email protected]> * [WIP] Bugfix - Pipeline.from_pretrained is broken when the pipeline is partially downloaded (#3448) Added bugfix using f strings. * Fix gradient checkpointing bugs in freezing part of models (requires_grad=False) (#3404) * gradient checkpointing bug fix * bug fix; changes for reviews * reformat * reformat --------- Co-authored-by: Patrick von Platen <[email protected]> * Make dreambooth lora more robust to orig unet (#3462) * Make dreambooth lora more robust to orig unet * up * Reduce peak VRAM by releasing large attention tensors (as soon as they're unnecessary) (#3463) Release large tensors in attention (as soon as they're no longer required). Reduces peak VRAM by nearly 2 GB for 1024x1024 (even after slicing), and the savings scale up with image size. * Add min snr to text2img lora training script (#3459) add min snr to text2img lora training script * Add inpaint lora scale support (#3460) * add inpaint lora scale support * add inpaint lora scale test --------- Co-authored-by: yueyang.hyy <[email protected]> * [From ckpt] Fix from_ckpt (#3466) * Correct from_ckpt * make style * Update full dreambooth script to work with IF (#3425) * Add IF dreambooth docs (#3470) * parameterize pass single args through tuple (#3477) * attend and excite tests disable determinism on the class level (#3478) * dreambooth docs torch.compile note (#3471) * dreambooth docs torch.compile note * Update examples/dreambooth/README.md Co-authored-by: Sayak Paul <[email protected]> * Update examples/dreambooth/README.md Co-authored-by: Pedro Cuenca <[email protected]> --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> * add: if entry in the dreambooth training docs. (#3472) * [docs] Textual inversion inference (#3473) * add textual inversion inference to docs * add to toctree --------- Co-authored-by: Sayak Paul <[email protected]> * [docs] Distributed inference (#3376) * distributed inference * move to inference section * apply feedback * update with split_between_processes * apply feedback * [{Up,Down}sample1d] explicit view kernel size as number elements in flattened indices (#3479) explicit view kernel size as number elements in flattened indices * mps & onnx tests rework (#3449) * Remove ONNX tests from PR. They are already a part of push_tests.yml. * Remove mps tests from PRs. They are already performed on push. * Fix workflow name for fast push tests. * Extract mps tests to a workflow. For better control/filtering. * Remove --extra-index-url from mps tests * Increase tolerance of mps test This test passes in my Mac (Ventura 13.3) but fails in the CI hardware (Ventura 13.2). I ran the local tests following the same steps that exist in the CI workflow. * Temporarily run mps tests on pr So we can test. * Revert "Temporarily run mps tests on pr" Tests passed, go back to running on push. --------- Signed-off-by: Asfiya Baig <[email protected]> Co-authored-by: Ilia Larchenko <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: yiyixuxu <[email protected]> Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Horace He <[email protected]> Co-authored-by: Umar <[email protected]> Co-authored-by: Mylo <[email protected]> Co-authored-by: Markus Pobitzer <[email protected]> Co-authored-by: Cheng Lu <[email protected]> Co-authored-by: Steven Liu <[email protected]> Co-authored-by: Isamu Isozaki <[email protected]> Co-authored-by: Cesar Aybar <[email protected]> Co-authored-by: Will Rice <[email protected]> Co-authored-by: Adrià Arrufat <[email protected]> Co-authored-by: Sanchit Gandhi <[email protected]> Co-authored-by: At-sushi <[email protected]> Co-authored-by: Lucca Zenóbio <[email protected]> Co-authored-by: Lysandre Debut <[email protected]> Co-authored-by: Isotr0py <[email protected]> Co-authored-by: pdoane <[email protected]> Co-authored-by: Will Berman <[email protected]> Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Rupert Menneer <[email protected]> Co-authored-by: sudowind <[email protected]> Co-authored-by: Takuma Mori <[email protected]> Co-authored-by: Stas Bekman <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Laureηt <[email protected]> Co-authored-by: Jongwoo Han <[email protected]> Co-authored-by: asfiyab-nvidia <[email protected]> Co-authored-by: clarencechen <[email protected]> Co-authored-by: Laureηt <[email protected]> Co-authored-by: superlabs-dev <[email protected]> Co-authored-by: Dev Aggarwal <[email protected]> Co-authored-by: Vimarsh Chaturvedi <[email protected]> Co-authored-by: 7eu7d7 <[email protected]> Co-authored-by: cmdr2 <[email protected]> Co-authored-by: wfng92 <[email protected]> Co-authored-by: Glaceon-Hyy <[email protected]> Co-authored-by: yueyang.hyy <[email protected]>
1 parent 2b56e8c commit f3d570c

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

src/diffusers/models/modeling_utils.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
398398
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
399399
more information about each option see [designing a device
400400
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
401+
max_memory (`Dict`, *optional*):
402+
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
403+
GPU and the available CPU RAM if unset.
404+
offload_folder (`str` or `os.PathLike`, *optional*):
405+
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
406+
offload_state_dict (`bool`, *optional*):
407+
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
408+
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
409+
`True` when there is some disk offload.
401410
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
402411
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
403412
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
@@ -439,6 +448,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
439448
torch_dtype = kwargs.pop("torch_dtype", None)
440449
subfolder = kwargs.pop("subfolder", None)
441450
device_map = kwargs.pop("device_map", None)
451+
max_memory = kwargs.pop("max_memory", None)
452+
offload_folder = kwargs.pop("offload_folder", None)
453+
offload_state_dict = kwargs.pop("offload_state_dict", False)
442454
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
443455
variant = kwargs.pop("variant", None)
444456
use_safetensors = kwargs.pop("use_safetensors", None)
@@ -510,6 +522,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
510522
revision=revision,
511523
subfolder=subfolder,
512524
device_map=device_map,
525+
max_memory=max_memory,
526+
offload_folder=offload_folder,
527+
offload_state_dict=offload_state_dict,
513528
user_agent=user_agent,
514529
**kwargs,
515530
)
@@ -614,7 +629,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
614629
else: # else let accelerate handle loading and dispatching.
615630
# Load weights and dispatch according to the device_map
616631
# by default the device_map is None and the weights are loaded on the CPU
617-
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
632+
accelerate.load_checkpoint_and_dispatch(
633+
model,
634+
model_file,
635+
device_map,
636+
max_memory=max_memory,
637+
offload_folder=offload_folder,
638+
offload_state_dict=offload_state_dict,
639+
dtype=torch_dtype,
640+
)
618641

619642
loading_info = {
620643
"missing_keys": [],

src/diffusers/pipelines/pipeline_utils.py

+21
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,9 @@ def load_sub_model(
354354
provider: Any,
355355
sess_options: Any,
356356
device_map: Optional[Union[Dict[str, torch.device], str]],
357+
max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
358+
offload_folder: Optional[Union[str, os.PathLike]],
359+
offload_state_dict: bool,
357360
model_variants: Dict[str, str],
358361
name: str,
359362
from_flax: bool,
@@ -416,6 +419,9 @@ def load_sub_model(
416419
# This makes sure that the weights won't be initialized which significantly speeds up loading.
417420
if is_diffusers_model or is_transformers_model:
418421
loading_kwargs["device_map"] = device_map
422+
loading_kwargs["max_memory"] = max_memory
423+
loading_kwargs["offload_folder"] = offload_folder
424+
loading_kwargs["offload_state_dict"] = offload_state_dict
419425
loading_kwargs["variant"] = model_variants.pop(name, None)
420426
if from_flax:
421427
loading_kwargs["from_flax"] = True
@@ -808,6 +814,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
808814
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
809815
more information about each option see [designing a device
810816
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
817+
max_memory (`Dict`, *optional*):
818+
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
819+
GPU and the available CPU RAM if unset.
820+
offload_folder (`str` or `os.PathLike`, *optional*):
821+
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
822+
offload_state_dict (`bool`, *optional*):
823+
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
824+
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
825+
`True` when there is some disk offload.
811826
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
812827
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
813828
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
@@ -873,6 +888,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
873888
provider = kwargs.pop("provider", None)
874889
sess_options = kwargs.pop("sess_options", None)
875890
device_map = kwargs.pop("device_map", None)
891+
max_memory = kwargs.pop("max_memory", None)
892+
offload_folder = kwargs.pop("offload_folder", None)
893+
offload_state_dict = kwargs.pop("offload_state_dict", False)
876894
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
877895
variant = kwargs.pop("variant", None)
878896
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
@@ -1046,6 +1064,9 @@ def load_module(name, value):
10461064
provider=provider,
10471065
sess_options=sess_options,
10481066
device_map=device_map,
1067+
max_memory=max_memory,
1068+
offload_folder=offload_folder,
1069+
offload_state_dict=offload_state_dict,
10491070
model_variants=model_variants,
10501071
name=name,
10511072
from_flax=from_flax,

0 commit comments

Comments
 (0)