     replace_return_docstrings,
     strtobool,
 )
-from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files
+from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
     ENV_VARS_TRUE_VALUES,
     is_sagemaker_mp_enabled,
@@ -382,92 +382,6 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix
     return False


-def shard_checkpoint(
-    state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
-):
-    """
-    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
-    given size.
-
-    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
-    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
-    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
-    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
-
-    <Tip warning={true}>
-
-    If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
-    have a size greater than `max_shard_size`.
-
-    </Tip>
-
-    Args:
-        state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
-        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
-            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
-            (like `"5MB"`).
-        weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
-            The name of the model save file.
-    """
-    logger.warning(
-        "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using "
-        "split_torch_state_dict_into_shards from huggingface_hub library"
-    )
-    max_shard_size = convert_file_size_to_int(max_shard_size)
-
-    sharded_state_dicts = [{}]
-    last_block_size = 0
-    total_size = 0
-    storage_id_to_block = {}
-
-    for key, weight in state_dict.items():
-        # when bnb serialization is used the weights in the state dict can be strings
-        # check: https://github.com/huggingface/transformers/pull/24416 for more details
-        if isinstance(weight, str):
-            continue
-        else:
-            storage_id = id_tensor_storage(weight)
-
-        # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
-        if storage_id in storage_id_to_block and weight.device != torch.device("meta"):
-            block_id = storage_id_to_block[storage_id]
-            sharded_state_dicts[block_id][key] = weight
-            continue
-
-        weight_size = weight.numel() * dtype_byte_size(weight.dtype)
-        # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
-        # weight in the current shard.
-        if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
-            sharded_state_dicts.append({})
-            last_block_size = 0
-
-        sharded_state_dicts[-1][key] = weight
-        last_block_size += weight_size
-        total_size += weight_size
-        storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1
-
-    # If we only have one shard, we return it
-    if len(sharded_state_dicts) == 1:
-        return {weights_name: sharded_state_dicts[0]}, None
-
-    # Otherwise, let's build the index
-    weight_map = {}
-    shards = {}
-    for idx, shard in enumerate(sharded_state_dicts):
-        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
-        shard_file = shard_file.replace(
-            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
-        )
-        shards[shard_file] = shard
-        for key in shard.keys():
-            weight_map[key] = shard_file
-
-    # Add the metadata
-    metadata = {"total_size": total_size}
-    index = {"metadata": metadata, "weight_map": weight_map}
-    return shards, index
-
-
 def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     """
     This is the same as
|
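For callers migrating off the removed helper, the deprecation warning above points at split_torch_state_dict_into_shards from the huggingface_hub library. Below is a minimal sketch of that replacement path, assuming a recent huggingface_hub and safetensors are installed; the save_sharded wrapper, the save directory argument, and the "5GB" limit are illustrative, not part of this PR.

import json
import os

import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file


def save_sharded(state_dict: dict[str, torch.Tensor], save_directory: str) -> None:
    # Hypothetical wrapper: plan the split, then write each shard. Unlike the
    # removed `shard_checkpoint`, the hub helper only maps tensor names to
    # shard filenames; the caller materializes and saves the shards itself.
    split = split_torch_state_dict_into_shards(state_dict, max_shard_size="5GB")
    for filename, tensor_names in split.filename_to_tensors.items():
        shard = {name: state_dict[name] for name in tensor_names}
        safe_save_file(shard, os.path.join(save_directory, filename), metadata={"format": "pt"})
    if split.is_sharded:
        # Same index layout as the old function's second return value:
        # {"metadata": {"total_size": ...}, "weight_map": {...}}.
        index = {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
        with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
            json.dump(index, f, indent=2)

The hub helper should also keep tensors that share underlying storage in the same shard, mirroring the storage_id_to_block bookkeeping in the removed code, and its default filename pattern yields model-00001-of-00002.safetensors-style names.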