Skip to content

Commit 3bbe99e

Browse files
authored
[Intel HPU] Enable dist sampler on intel hpu platform (#4445)
1 parent 4251ac5 commit 3bbe99e

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def forward_intel_hpu(
416416

417417
if next_tokens.shape[0] != max_batch:
418418
dim = next_tokens.shape[-1]
419-
tmp_tokens = paddle.full((max_batch, dim), -1, dtype=next_tokens.dtype)
419+
tmp_tokens = paddle.full((max_batch, dim), -1 if local_rank == 0 else 0, dtype=next_tokens.dtype)
420420
tmp_tokens = paddle.scatter(tmp_tokens, batch_ids, next_tokens[: batch_ids.shape[0], :])
421421
return tmp_tokens
422422

fastdeploy/worker/hpu_model_runner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from paddleformers.utils.log import logger
2525

2626
from fastdeploy.config import FDConfig
27+
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce_custom
2728
from fastdeploy.engine.request import Request
2829

2930
# from fastdeploy.spec_decode import MTPProposer, NgramProposer
@@ -944,7 +945,7 @@ def _dummy_run(
944945
if self.parallel_config.tensor_parallel_size > 1:
945946
dtype = sampled_token_ids.dtype
946947
sampled_token_ids = sampled_token_ids.to("float32")
947-
paddle.distributed.broadcast(sampled_token_ids, 0)
948+
tensor_model_parallel_all_reduce_custom(sampled_token_ids)
948949
sampled_token_ids = sampled_token_ids.to(dtype)
949950

950951
# 6. post process
@@ -1272,7 +1273,7 @@ class at the server level, which is too granular for ModelRunner.
12721273
if self.parallel_config.tensor_parallel_size > 1:
12731274
dtype = sampled_token_ids.dtype
12741275
sampled_token_ids = sampled_token_ids.to("float32")
1275-
paddle.distributed.broadcast(sampled_token_ids, 0)
1276+
tensor_model_parallel_all_reduce_custom(sampled_token_ids)
12761277
sampled_token_ids = sampled_token_ids.to(dtype)
12771278
if self.is_hpu_perf_breakdown_sync_mode:
12781279
sampled_token_ids.cpu()

0 commit comments

Comments
 (0)