9 changes: 5 additions & 4 deletions verl/workers/engine/__init__.py
@@ -50,17 +50,18 @@

 # Mindspeed must be imported before Megatron to ensure the related monkey patches take effect as expected
 try:
-    from .mindspeed import MindspeedEngineWithLMHead, MindSpeedLLMEngineWithLMHead
+    from .mindspeed import MindspeedEngineWithLMHead, MindspeedEngineWithValueHead, MindSpeedLLMEngineWithLMHead

-    __all__ += ["MindspeedEngineWithLMHead", "MindSpeedLLMEngineWithLMHead"]
+    __all__ += ["MindspeedEngineWithLMHead", "MindspeedEngineWithValueHead", "MindSpeedLLMEngineWithLMHead"]
 except ImportError:
     MindspeedEngineWithLMHead = None
+    MindspeedEngineWithValueHead = None
     MindSpeedLLMEngineWithLMHead = None

 try:
-    from .megatron import MegatronEngine, MegatronEngineWithLMHead
+    from .megatron import MegatronEngine, MegatronEngineWithLMHead, MegatronEngineWithValueHead

-    __all__ += ["MegatronEngine", "MegatronEngineWithLMHead"]
+    __all__ += ["MegatronEngine", "MegatronEngineWithLMHead", "MegatronEngineWithValueHead"]
 except ImportError:
     MegatronEngine = None
     MegatronEngineWithLMHead = None
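
The try/except fallbacks above export each engine class as None when its backend cannot be imported, so downstream code can feature-detect instead of catching ImportError itself. A minimal sketch of how a caller might guard on that convention; the create_value_engine helper is hypothetical and not part of this PR:

    from verl.workers.engine import MegatronEngineWithValueHead

    def create_value_engine(*args, **kwargs):
        # Hypothetical helper: the exported name is None when the Megatron backend is not
        # installed, so fail with a clear message instead of "NoneType is not callable".
        if MegatronEngineWithValueHead is None:
            raise RuntimeError(
                "The Megatron backend is unavailable; install megatron-core to use MegatronEngineWithValueHead."
            )
        return MegatronEngineWithValueHead(*args, **kwargs)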
4 changes: 2 additions & 2 deletions verl/workers/engine/megatron/__init__.py
@@ -21,9 +21,9 @@
 if not is_cuda_available and "TORCH_CUDA_ARCH_LIST" not in os.environ:
     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"

-from .transformer_impl import MegatronEngine, MegatronEngineWithLMHead  # noqa: E402
+from .transformer_impl import MegatronEngine, MegatronEngineWithLMHead, MegatronEngineWithValueHead  # noqa: E402

 if not is_cuda_available:
     del os.environ["TORCH_CUDA_ARCH_LIST"]

-__all__ = ["MegatronEngine", "MegatronEngineWithLMHead"]
+__all__ = ["MegatronEngine", "MegatronEngineWithLMHead", "MegatronEngineWithValueHead"]
4 changes: 2 additions & 2 deletions verl/workers/engine/mindspeed/__init__.py
@@ -12,6 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .transformer_impl import MindspeedEngineWithLMHead, MindSpeedLLMEngineWithLMHead
+from .transformer_impl import MindspeedEngineWithLMHead, MindspeedEngineWithValueHead, MindSpeedLLMEngineWithLMHead

-__all__ = ["MindspeedEngineWithLMHead", "MindSpeedLLMEngineWithLMHead"]
+__all__ = ["MindspeedEngineWithLMHead", "MindspeedEngineWithValueHead", "MindSpeedLLMEngineWithLMHead"]
26 changes: 25 additions & 1 deletion verl/workers/engine/mindspeed/transformer_impl.py
@@ -31,7 +31,7 @@
 )

 from ..base import EngineRegistry
-from ..megatron import MegatronEngineWithLMHead
+from ..megatron import MegatronEngineWithLMHead, MegatronEngineWithValueHead
 from .utils import (
     apply_patch,
     gpt_model_provider,
@@ -66,6 +66,30 @@ def _init_device_mesh(self):
             repatch(repatch_config)
         super()._init_device_mesh()


+@EngineRegistry.register(model_type="value_model", backend="megatron", device="npu")
+class MindspeedEngineWithValueHead(MegatronEngineWithValueHead):
+    def __init__(
+        self,
+        model_config: HFModelConfig,
+        engine_config: McoreEngineConfig,
+        optimizer_config: McoreOptimizerConfig,
+        checkpoint_config: CheckpointConfig,
+    ):
+        super().__init__(model_config, engine_config, optimizer_config, checkpoint_config)
Comment on lines +76 to +83
Contributor


Severity: high

The __init__ method in MindspeedEngineWithValueHead is redundant: it only forwards the exact same arguments to super().__init__ that the base class MegatronEngineWithValueHead (via its ancestor MegatronEngine) already accepts. In Python, such an override can simply be omitted, which improves clarity and maintainability; see the sketch after the _init_device_mesh method below.


+    def _init_device_mesh(self):
+        # repatch must happen before initialize_model_parallel so that
+        # initialize_model_parallel_cp_wrapper is in effect when the call is made.
+        # The initial MindSpeed patch pass sees context_parallel_size=1 (default) because
+        # verl passes CP size via hydra config rather than --context-parallel-size CLI arg,
+        # so the CP ring-rank initialization wrapper is not registered on the first pass.
+        if repatch is not None:
+            repatch_config = dict(self.engine_config.get("override_transformer_config", {}))
+            repatch_config.setdefault("use_flash_attn", True)
+            if self.engine_config.context_parallel_size > 1:
+                repatch_config["context_parallel_size"] = self.engine_config.context_parallel_size
+            repatch(repatch_config)
+        super()._init_device_mesh()
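
Following the reviewer's comment above, a minimal sketch of what the class could look like with the redundant __init__ dropped; this only illustrates the suggestion (the constructor is inherited unchanged from MegatronEngineWithValueHead) and is not the code merged in this PR:

    @EngineRegistry.register(model_type="value_model", backend="megatron", device="npu")
    class MindspeedEngineWithValueHead(MegatronEngineWithValueHead):
        # No __init__ override: the base class already accepts
        # (model_config, engine_config, optimizer_config, checkpoint_config).

        def _init_device_mesh(self):
            # Re-apply the MindSpeed patches with the real context_parallel_size
            # before model-parallel state is initialized, exactly as in the diff above.
            if repatch is not None:
                repatch_config = dict(self.engine_config.get("override_transformer_config", {}))
                repatch_config.setdefault("use_flash_attn", True)
                if self.engine_config.context_parallel_size > 1:
                    repatch_config["context_parallel_size"] = self.engine_config.context_parallel_size
                repatch(repatch_config)
            super()._init_device_mesh()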
xiazhahe marked this conversation as resolved.

@EngineRegistry.register(model_type="language_model", backend="mindspeed_llm", device="npu")
class MindSpeedLLMEngineWithLMHead(MegatronEngineWithLMHead):
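
The @EngineRegistry.register(model_type=..., backend=..., device=...) decorators in this diff key each engine class by model type, backend, and device so that an implementation can be looked up at runtime. verl's actual registry interface is not shown here, so the toy sketch below (ToyRegistry and all of its names are hypothetical) only illustrates the general decorator-registration pattern:

    from typing import Callable, Dict, Tuple, Type

    class ToyRegistry:
        # Maps (model_type, backend, device) -> engine class; a stand-in for the
        # real EngineRegistry, whose interface this diff does not show.
        _engines: Dict[Tuple[str, str, str], Type] = {}

        @classmethod
        def register(cls, model_type: str, backend: str, device: str) -> Callable[[Type], Type]:
            def decorator(engine_cls: Type) -> Type:
                cls._engines[(model_type, backend, device)] = engine_cls
                return engine_cls
            return decorator

        @classmethod
        def resolve(cls, model_type: str, backend: str, device: str) -> Type:
            return cls._engines[(model_type, backend, device)]

    @ToyRegistry.register(model_type="value_model", backend="megatron", device="npu")
    class DummyValueEngine:
        pass

    assert ToyRegistry.resolve("value_model", "megatron", "npu") is DummyValueEngine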