Skip to content

Commit e5095b5

Browse files
author
luoyuan.luo
committed
Vectorise group_concurrent_contiguous in NumPy
1 parent 9858113 commit e5095b5

File tree

1 file changed

+9
-18
lines changed
  • python/sglang/srt/disaggregation/mooncake

1 file changed

+9
-18
lines changed

python/sglang/srt/disaggregation/mooncake/conn.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,16 @@
3737
def group_concurrent_contiguous(
3838
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
3939
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
40-
src_groups = []
41-
dst_groups = []
42-
current_src = [src_indices[0]]
43-
current_dst = [dst_indices[0]]
44-
45-
for i in range(1, len(src_indices)):
46-
src_contiguous = src_indices[i] == src_indices[i - 1] + 1
47-
dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
48-
if src_contiguous and dst_contiguous:
49-
current_src.append(src_indices[i])
50-
current_dst.append(dst_indices[i])
51-
else:
52-
src_groups.append(current_src)
53-
dst_groups.append(current_dst)
54-
current_src = [src_indices[i]]
55-
current_dst = [dst_indices[i]]
40+
"""Vectorised NumPy implementation."""
41+
if src_indices.size == 0:
42+
return [], []
43+
44+
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
45+
src_groups = np.split(src_indices, brk)
46+
dst_groups = np.split(dst_indices, brk)
5647

57-
src_groups.append(current_src)
58-
dst_groups.append(current_dst)
48+
src_groups = [g.tolist() for g in src_groups]
49+
dst_groups = [g.tolist() for g in dst_groups]
5950

6051
return src_groups, dst_groups
6152

0 commit comments

Comments
 (0)