# Copyright (c) The DeepSpeed Contributors
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
This is a slimmed-down version of parallel_state.py (mpu) from Megatron-DeepSpeed.
"""

from deepspeed import comm as dist

# Sequence parallel groups to handle both data and sequence parallelisms.
# These groups are used to reduce gradients and shard parameters and optimizer states for ZeRO.
_SEQUENCE_PARALLEL_GROUP = None
_SEQUENCE_DATA_PARALLEL_GROUP = None

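# Illustrative example (sizes assumed for this comment, not from the original
# file): with world_size=8 and sequence_parallel_size=2,
# initialize_sequence_parallel() below creates four sequence parallel groups
# with ranks [0,1], [2,3], [4,5], [6,7], plus a single sequence data parallel
# group spanning all eight ranks.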

def initialize_sequence_parallel(sequence_parallel_size: int) -> None:
    """Initialize sequence parallel groups."""

    assert dist.is_initialized(), "deepspeed.comm must be initialized before creating sequence parallel groups"
    world_size: int = dist.get_world_size()

    if world_size < sequence_parallel_size:
        raise RuntimeError(f"world_size ({world_size}) is less than sequence_parallel_size ({sequence_parallel_size})")

    if sequence_parallel_size <= 1:
        raise ValueError(f"sequence_parallel_size must be greater than 1, got {sequence_parallel_size}")

    if world_size % sequence_parallel_size != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by sequence_parallel_size ({sequence_parallel_size})")

    data_parallel_size: int = world_size // sequence_parallel_size
    sequence_data_parallel_size: int = sequence_parallel_size * data_parallel_size
    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
    num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size
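    # Note: because data_parallel_size == world_size // sequence_parallel_size,
    # sequence_data_parallel_size equals world_size, so exactly one sequence
    # data parallel group (spanning all ranks) is created below.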

    rank = dist.get_rank()

    # Build the sequence parallel groups.
    global _SEQUENCE_PARALLEL_GROUP
    assert _SEQUENCE_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            _SEQUENCE_PARALLEL_GROUP = group

    # Build the sequence data parallel groups.
    global _SEQUENCE_DATA_PARALLEL_GROUP
    assert _SEQUENCE_DATA_PARALLEL_GROUP is None, "sequence data parallel group is already initialized"
    for i in range(num_sequence_data_parallel_groups):
        ranks = range(i * sequence_data_parallel_size, (i + 1) * sequence_data_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            _SEQUENCE_DATA_PARALLEL_GROUP = group

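# Minimal usage sketch (illustrative, not part of the original module). Every
# rank must call initialize_sequence_parallel() with the same size after the
# global process group exists, since dist.new_group() is collective and all
# ranks must take part in creating every group:
#
#     import deepspeed
#     deepspeed.init_distributed()
#     initialize_sequence_parallel(sequence_parallel_size=2)
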

def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    assert _SEQUENCE_PARALLEL_GROUP is not None, "sequence parallel group is not initialized"
    return _SEQUENCE_PARALLEL_GROUP


def get_sequence_data_parallel_group():
    """Get the sequence data parallel group the caller rank belongs to."""
    assert _SEQUENCE_DATA_PARALLEL_GROUP is not None, "sequence data parallel group is not initialized"
    return _SEQUENCE_DATA_PARALLEL_GROUP


def get_sequence_parallel_world_size():
    """Return world size for the sequence parallel group."""
    return dist.get_world_size(group=get_sequence_parallel_group())


def get_sequence_data_parallel_world_size():
    """Return world size for the sequence data parallel group."""
    return dist.get_world_size(group=get_sequence_data_parallel_group())


def get_sequence_parallel_rank():
    """Return the caller's rank within the sequence parallel group."""
    return dist.get_rank(group=get_sequence_parallel_group())


def get_sequence_data_parallel_rank():
    """Return the caller's rank within the sequence data parallel group."""
    return dist.get_rank(group=get_sequence_data_parallel_group())
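

# A small self-contained demo (an assumed addition, not part of the original
# module): it assumes this file is launched as a script on 2 or more ranks,
# e.g. via the deepspeed or torchrun launcher, and prints each rank's
# position in the groups created above.
if __name__ == "__main__":
    import deepspeed

    deepspeed.init_distributed()
    initialize_sequence_parallel(sequence_parallel_size=2)
    print(f"global rank {dist.get_rank()}: "
          f"sequence parallel rank {get_sequence_parallel_rank()} of {get_sequence_parallel_world_size()}, "
          f"sequence data parallel rank {get_sequence_data_parallel_rank()} of {get_sequence_data_parallel_world_size()}")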