|
32 | 32 | from transformers import PretrainedConfig
33 | 33 |
34 | 34 | from verl.utils.torch_dtypes import PrecisionType
| 35 | +from verl.utils.model import normalize_model_name |
| 36 | +import verl.utils.megatron.tensor_parallel as tp_utils |
35 | 37 |
36 | 38 |
37 | 39 | def get_model_config(model):

@@ -619,3 +621,140 @@ def broadcast_str_from_megatron_pp(obj: Any):
619 | 621 | torch.distributed.broadcast_object_list(object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group())
620 | 622 |
621 | 623 | return obj_output[0]
| 624 | + |
| 625 | +def default_tp_concat_fn(layer_name_mapping, name, train_params, infer_params, model_config, convert_qkv_gate_up_by_simple_split=False): |
| 626 | +    """Concatenate the tensor-parallel shards of a parameter back into one full tensor for inference. |
| 627 | +    name: name of the parameter |
| 628 | +    train_params: the training-side parameter, used to look up its tensor-parallel partition dim |
| 629 | +    infer_params (List[torch.Tensor]): the parameter shards all-gathered from the tensor model parallel group |
| 630 | +    model_config: huggingface model_config |
| 631 | +    convert_qkv_gate_up_by_simple_split: if True, return the qkv / gate_up pieces as a list instead of one fused tensor |
| 632 | +    TODO(zhangchi.usc1992): the implementation is currently ad hoc. We can move this function into the model definition |
| 633 | +    so that it is model-agnostic; if a model doesn't implement it, we can raise an error to force the user to disable the TP HybridEngine. |
| 634 | +    """ |
| 635 | + from megatron.core import mpu |
| 636 | + |
| 637 | + if layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name: |
| 638 | +        # fused qkv weight: split each tensor-parallel shard into its q, k and v rows, |
| 639 | +        # then concatenate the q, k and v pieces separately across shards. |
| 640 | + q_lst = [] |
| 641 | + k_lst = [] |
| 642 | + v_lst = [] |
| 643 | +        assert model_config.num_attention_heads % model_config.num_key_value_heads == 0 |
| 644 | +        num_q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads |
| 645 | +        assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0 |
| 646 | +        kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) |
| 647 | +        num_query_groups_per_partition = model_config.num_key_value_heads // mpu.get_tensor_model_parallel_world_size() |
| 648 | +        # within each query group of a shard, the rows are laid out as [q, ..., q, k, v] along dim 0 |
| 649 | +        split_size = [ |
| 650 | +            kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, |
| 651 | +            kv_size_per_tp // num_query_groups_per_partition, |
| 652 | +            kv_size_per_tp // num_query_groups_per_partition, |
| 653 | +        ] |
| 654 | +        for infer_param in infer_params: |
| 655 | +            # one chunk per query group; split each chunk into its q, k and v rows |
| 656 | +            for chunk in infer_param.chunk(num_query_groups_per_partition): |
| 657 | +                q, k, v = chunk.split(split_size) |
| 658 | + q_lst.append(q) |
| 659 | + k_lst.append(k) |
| 660 | + v_lst.append(v) |
| 661 | + q = torch.cat(q_lst, dim=0) |
| 662 | + k = torch.cat(k_lst, dim=0) |
| 663 | + v = torch.cat(v_lst, dim=0) |
| 664 | + if not convert_qkv_gate_up_by_simple_split: |
| 665 | + infer_params = torch.cat((q, k, v), dim=0) |
| 666 | + else: |
| 667 | + infer_params = [q, k, v] |
| 668 | + |
| 669 | + elif layer_name_mapping.get("gate_proj_layer_name") in name: |
| 670 | +        # fused gate_up weight: split each shard into its gate and up halves, then concatenate each across shards |
| 671 | + gate_lst = [] |
| 672 | + up_lst = [] |
| 673 | + for infer_param in infer_params: |
| 674 | + gate, up = infer_param.chunk(2) |
| 675 | + gate_lst.append(gate) |
| 676 | + up_lst.append(up) |
| 677 | + gate = torch.cat(gate_lst, dim=0) |
| 678 | + up = torch.cat(up_lst, dim=0) |
| 679 | + if not convert_qkv_gate_up_by_simple_split: |
| 680 | + infer_params = torch.cat((gate, up), dim=0) |
| 681 | + else: |
| 682 | + infer_params = [gate, up] |
| 683 | + |
| 684 | + else: |
| 685 | +        # default case: concatenate the shards along the parameter's tensor-parallel partition dim |
| 686 | + infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(train_params)) |
| 687 | + |
| 688 | + return infer_params |
| 689 | + |
| 690 | + |
| 691 | +def per_tensor_generator(actor_module, model_config, weight_converter, layer_name_mapping, convert_qkv_gate_up_by_simple_split=True): |
| 692 | + from megatron.core import parallel_state as mpu |
| 693 | + pp_rank = mpu.get_pipeline_model_parallel_rank() |
| 694 | + pp_size = mpu.get_pipeline_model_parallel_world_size() |
| 695 | + vpp_size = len(actor_module) |
| 696 | + all_gather_group = mpu.get_tensor_model_parallel_group() |
| 697 | + all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group) |
| 698 | + |
| 699 | + def tensor_generator(): |
| 700 | + for scan_vpp_idx in range(vpp_size): |
| 701 | + yield from actor_module[scan_vpp_idx].named_parameters() |
| 702 | + |
| 703 | +    # first, make sure every rank sees the parameter metadata of the full model across all pp/vpp stages |
| 704 | + meta_info = [] |
| 705 | + for scan_vpp_idx in range(vpp_size): |
| 706 | + for idx, (name, _) in enumerate(actor_module[scan_vpp_idx].named_parameters()): |
| 707 | + meta_info.append((pp_rank, scan_vpp_idx, idx, name)) |
| 708 | + |
| 709 | + obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() |
| 710 | + torch.distributed.all_gather_object( |
| 711 | + object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group() |
| 712 | + ) |
| 713 | + layer_list_meta = [item for sublist in obj_spec_output for item in sublist] |
| 714 | + |
| 715 | + gen_func = tensor_generator() |
| 716 | + |
| 717 | +    # lazily walk the full model: each pp rank yields its own tensors, which are then broadcast to the other ranks |
| 718 | + for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta: |
| 719 | + if cur_pp_rank == pp_rank: |
| 720 | + try: |
| 721 | + cur_name, cur_tensor = next(gen_func) |
| 722 | + except StopIteration: |
| 723 | + cur_name, cur_tensor = None, None |
| 724 | + cur_name = normalize_model_name( |
| 725 | + name, cur_pp_rank, scan_vpp_idx, pp_size, vpp_size, model_config.num_hidden_layers |
| 726 | + ) |
| 727 | + else: |
| 728 | + cur_tensor, cur_name = None, None |
| 729 | + |
| 730 | + # pp broadcast model tensor and name |
| 731 | + cur_name = broadcast_str_from_megatron_pp(cur_name) |
| 732 | + broad_pp_tensor = broadcast_from_megatron_pp(cur_tensor) |
| 733 | + |
| 734 | +        # (xya): hack to fix the parameter name by stripping any leading "module." prefixes added by wrapper modules |
| 735 | + while cur_name.startswith("module."): |
| 736 | + cur_name = cur_name[len("module.") :] |
| 737 | + |
| 738 | + # tp all gather |
| 739 | + if tp_utils.is_tensor_parallel_param(broad_pp_tensor): |
| 740 | +            # gather the TP shards into newly allocated buffers, then concatenate them into the full parameter |
| 741 | + if all_gather_group_size <= 1: |
| 742 | + infer_params = [broad_pp_tensor] |
| 743 | + else: |
| 744 | + infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)] |
| 745 | + torch.distributed.all_gather( |
| 746 | + infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group() |
| 747 | + ) |
| 748 | + infer_params = default_tp_concat_fn( |
| 749 | + layer_name_mapping, cur_name, broad_pp_tensor, infer_params, model_config, convert_qkv_gate_up_by_simple_split |
| 750 | + ) |
| 751 | + else: |
| 752 | + infer_params = broad_pp_tensor |
| 753 | + |
| 754 | + |
| 755 | + if not isinstance(infer_params, list): |
| 756 | + infer_params = [infer_params] |
| 757 | + converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params) |
| 758 | + |
| 759 | + yield from zip(converted_names, converted_params) |
| 760 | + |
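
For reviewers who have not worked with Megatron's fused attention weights, here is a minimal, standalone sketch (not part of this diff, all sizes invented) of the row layout that the qkv branch of `default_tp_concat_fn` un-interleaves: within every tensor-parallel shard, each query group stores its query rows followed by one key block and one value block, and the function regroups them into contiguous q, k and v tensors.

```python
import torch

num_attention_heads = 8   # toy numbers, not tied to any real model
num_key_value_heads = 4   # GQA: two query heads share each kv head
head_dim = 3
hidden_size = 5
tp_size = 2               # pretend there are two tensor-parallel shards

num_q_per_kv = num_attention_heads // num_key_value_heads   # 2
groups_per_rank = num_key_value_heads // tp_size             # 2 query groups per shard

# Build fake per-rank fused QKV weights in the grouped layout: for every query group
# the rows are [q_0, ..., q_{num_q_per_kv-1}, k, v], stacked along dim 0.
shards = []
for rank in range(tp_size):
    rows = []
    for group in range(groups_per_rank):
        rows.append(torch.full((num_q_per_kv * head_dim, hidden_size), 100.0 * rank + 10 * group + 1))  # q rows
        rows.append(torch.full((head_dim, hidden_size), 100.0 * rank + 10 * group + 2))                 # k block
        rows.append(torch.full((head_dim, hidden_size), 100.0 * rank + 10 * group + 3))                 # v block
    shards.append(torch.cat(rows, dim=0))

# Re-group exactly as the qkv branch above does: chunk every shard per query group,
# split each chunk into q/k/v, then concatenate all q, all k and all v.
kv_size_per_tp = shards[0].shape[0] // (num_q_per_kv + 2)
split_size = [
    kv_size_per_tp * num_q_per_kv // groups_per_rank,
    kv_size_per_tp // groups_per_rank,
    kv_size_per_tp // groups_per_rank,
]
q_lst, k_lst, v_lst = [], [], []
for shard in shards:
    for chunk in shard.chunk(groups_per_rank):
        q, k, v = chunk.split(split_size)
        q_lst.append(q)
        k_lst.append(k)
        v_lst.append(v)

q = torch.cat(q_lst, dim=0)   # (num_attention_heads * head_dim, hidden_size)
k = torch.cat(k_lst, dim=0)   # (num_key_value_heads * head_dim, hidden_size)
v = torch.cat(v_lst, dim=0)
print(q.shape, k.shape, v.shape)
```

Running the sketch prints `torch.Size([24, 5]) torch.Size([12, 5]) torch.Size([12, 5])`, i.e. the full q, k and v projection shapes for the toy config.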
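
Below is a hypothetical consumer of `per_tensor_generator`, shown only to illustrate the intended call pattern. `actor_module`, `model_config`, `weight_converter` and `layer_name_mapping` are assumed to come from the surrounding verl Megatron worker setup, and `collect_hf_state_dict` itself is not part of this diff; it assumes it lives next to `per_tensor_generator` and runs inside an initialized Megatron process group.

```python
def collect_hf_state_dict(actor_module, model_config, weight_converter, layer_name_mapping):
    """Materialize the streamed (name, tensor) pairs into a CPU state dict."""
    state_dict = {}
    for name, tensor in per_tensor_generator(
        actor_module,
        model_config,
        weight_converter,
        layer_name_mapping,
        convert_qkv_gate_up_by_simple_split=True,
    ):
        # each tensor arrives already gathered across PP and TP ranks
        state_dict[name] = tensor.detach().cpu()
    return state_dict
```

Because the generator yields one fully gathered tensor at a time, a consumer like this never needs to hold more than the current parameter plus whatever it chooses to keep.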