Skip to content

Commit 7a52769

Browse files
authored
[docs] More API stuff (huggingface#3835)
* clean up loaders * clean up rest of main class apis * apply feedback
1 parent eeefc1e commit 7a52769

8 files changed

+305
-398
lines changed

configuration_utils.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,17 @@ def __setitem__(self, name, value):
8181

8282
class ConfigMixin:
8383
r"""
84-
Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
85-
methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
86-
- [`~ConfigMixin.from_config`]
87-
- [`~ConfigMixin.save_config`]
84+
Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
85+
provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
86+
saving classes that inherit from [`ConfigMixin`].
8887
8988
Class attributes:
9089
- **config_name** (`str`) -- A filename under which the config should be stored when calling
9190
[`~ConfigMixin.save_config`] (should be overridden by subclass).
9291
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
9392
overridden by subclass).
9493
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
95-
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
94+
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
9695
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
9796
subclass).
9897
"""
@@ -139,12 +138,12 @@ def __getattr__(self, name: str) -> Any:
139138

140139
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
141140
"""
142-
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
141+
Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
143142
[`~ConfigMixin.from_config`] class method.
144143
145144
Args:
146145
save_directory (`str` or `os.PathLike`):
147-
Directory where the configuration JSON file will be saved (will be created if it does not exist).
146+
Directory where the configuration JSON file is saved (will be created if it does not exist).
148147
"""
149148
if os.path.isfile(save_directory):
150149
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -164,15 +163,14 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
164163
165164
Parameters:
166165
config (`Dict[str, Any]`):
167-
A config dictionary from which the Python class will be instantiated. Make sure to only load
168-
configuration files of compatible classes.
166+
A config dictionary from which the Python class is instantiated. Make sure to only load configuration
167+
files of compatible classes.
169168
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
170169
Whether kwargs that are not consumed by the Python class should be returned or not.
171-
172170
kwargs (remaining dictionary of keyword arguments, *optional*):
173171
Can be used to update the configuration object (after it is loaded) and instantiate the Python class.
174-
`**kwargs` are directly passed to the underlying scheduler/model's `__init__` method and eventually
175-
overwrite same named arguments in `config`.
172+
`**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
173+
overwrite the same named arguments in `config`.
176174
177175
Returns:
178176
[`ModelMixin`] or [`SchedulerMixin`]:
@@ -280,16 +278,16 @@ def load_config(
280278
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
281279
cached versions if they exist.
282280
resume_download (`bool`, *optional*, defaults to `False`):
283-
Whether or not to resume downloading the model weights and configuration files. If set to False, any
281+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
284282
incompletely downloaded files are deleted.
285283
proxies (`Dict[str, str]`, *optional*):
286284
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
287285
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
288286
output_loading_info (`bool`, *optional*, defaults to `False`):
289287
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
290-
local_files_only(`bool`, *optional*, defaults to `False`):
291-
Whether to only load local model weights and configuration files or not. If set to True, the model
292-
wont be downloaded from the Hub.
288+
local_files_only (`bool`, *optional*, defaults to `False`):
289+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
290+
won't be downloaded from the Hub.
293291
use_auth_token (`str` or `bool`, *optional*):
294292
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
295293
`diffusers-cli login` (stored in `~/.huggingface`) is used.
@@ -307,14 +305,6 @@ def load_config(
307305
`dict`:
308306
A dictionary of all the parameters stored in a JSON configuration file.
309307
310-
<Tip>
311-
312-
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
313-
`huggingface-cli login`. You can also activate the special
314-
["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to use this method in a
315-
firewalled environment.
316-
317-
</Tip>
318308
"""
319309
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
320310
force_download = kwargs.pop("force_download", False)
@@ -536,10 +526,11 @@ def config(self) -> Dict[str, Any]:
536526

537527
def to_json_string(self) -> str:
538528
"""
539-
Serializes this instance to a JSON string.
529+
Serializes the configuration instance to a JSON string.
540530
541531
Returns:
542-
`str`: String containing all the attributes that make up this configuration instance in JSON format.
532+
`str`:
533+
String containing all the attributes that make up the configuration instance in JSON format.
543534
"""
544535
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
545536
config_dict["_class_name"] = self.__class__.__name__
@@ -560,11 +551,11 @@ def to_json_saveable(value):
560551

561552
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
562553
"""
563-
Save this instance to a JSON file.
554+
Save the configuration instance's parameters to a JSON file.
564555
565556
Args:
566557
json_file_path (`str` or `os.PathLike`):
567-
Path to the JSON file in which this configuration instance's parameters will be saved.
558+
Path to the JSON file to save a configuration instance's parameters.
568559
"""
569560
with open(json_file_path, "w", encoding="utf-8") as writer:
570561
writer.write(self.to_json_string())

image_processor.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,18 @@
2626

2727
class VaeImageProcessor(ConfigMixin):
2828
"""
29-
Image Processor for VAE
29+
Image processor for VAE.
3030
3131
Args:
3232
do_resize (`bool`, *optional*, defaults to `True`):
3333
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
34-
`height` and `width` arguments from `preprocess` method
34+
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
3535
vae_scale_factor (`int`, *optional*, defaults to `8`):
36-
VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
37-
factor.
36+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
3837
resample (`str`, *optional*, defaults to `lanczos`):
3938
Resampling filter to use when resizing the image.
4039
do_normalize (`bool`, *optional*, defaults to `True`):
41-
Whether to normalize the image to [-1,1]
40+
Whether to normalize the image to [-1,1].
4241
do_convert_rgb (`bool`, *optional*, defaults to `False`):
4342
Whether to convert the images to RGB format.
4443
"""
@@ -75,7 +74,7 @@ def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
7574
@staticmethod
7675
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
7776
"""
78-
Convert a PIL image or a list of PIL images to numpy arrays.
77+
Convert a PIL image or a list of PIL images to NumPy arrays.
7978
"""
8079
if not isinstance(images, list):
8180
images = [images]
@@ -87,7 +86,7 @@ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.nd
8786
@staticmethod
8887
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
8988
"""
90-
Convert a numpy image to a pytorch tensor
89+
Convert a NumPy image to a PyTorch tensor.
9190
"""
9291
if images.ndim == 3:
9392
images = images[..., None]
@@ -98,22 +97,22 @@ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
9897
@staticmethod
9998
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
10099
"""
101-
Convert a pytorch tensor to a numpy image
100+
Convert a PyTorch tensor to a NumPy image.
102101
"""
103102
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
104103
return images
105104

106105
@staticmethod
107106
def normalize(images):
108107
"""
109-
Normalize an image array to [-1,1]
108+
Normalize an image array to [-1,1].
110109
"""
111110
return 2.0 * images - 1.0
112111

113112
@staticmethod
114113
def denormalize(images):
115114
"""
116-
Denormalize an image array to [0,1]
115+
Denormalize an image array to [0,1].
117116
"""
118117
return (images / 2 + 0.5).clamp(0, 1)
119118

@@ -132,7 +131,7 @@ def resize(
132131
width: Optional[int] = None,
133132
) -> PIL.Image.Image:
134133
"""
135-
Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
134+
Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
136135
"""
137136
if height is None:
138137
height = image.height
@@ -152,7 +151,7 @@ def preprocess(
152151
width: Optional[int] = None,
153152
) -> torch.Tensor:
154153
"""
155-
Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors"
154+
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
156155
"""
157156
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
158157
if isinstance(image, supported_formats):
@@ -255,18 +254,17 @@ def postprocess(
255254

256255
class VaeImageProcessorLDM3D(VaeImageProcessor):
257256
"""
258-
Image Processor for VAE LDM3D.
257+
Image processor for VAE LDM3D.
259258
260259
Args:
261260
do_resize (`bool`, *optional*, defaults to `True`):
262261
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
263262
vae_scale_factor (`int`, *optional*, defaults to `8`):
264-
VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
265-
factor.
263+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
266264
resample (`str`, *optional*, defaults to `lanczos`):
267265
Resampling filter to use when resizing the image.
268266
do_normalize (`bool`, *optional*, defaults to `True`):
269-
Whether to normalize the image to [-1,1]
267+
Whether to normalize the image to [-1,1].
270268
"""
271269

272270
config_name = CONFIG_NAME
@@ -284,7 +282,7 @@ def __init__(
284282
@staticmethod
285283
def numpy_to_pil(images):
286284
"""
287-
Convert a numpy image or a batch of images to a PIL image.
285+
Convert a NumPy image or a batch of images to a PIL image.
288286
"""
289287
if images.ndim == 3:
290288
images = images[None, ...]
@@ -310,7 +308,7 @@ def rgblike_to_depthmap(image):
310308

311309
def numpy_to_depth(self, images):
312310
"""
313-
Convert a numpy depth image or a batch of images to a PIL image.
311+
Convert a NumPy depth image or a batch of images to a PIL image.
314312
"""
315313
if images.ndim == 3:
316314
images = images[None, ...]

0 commit comments

Comments
 (0)