@@ -2,6 +2,7 @@

 import random

+import numpy as np
 import pytest
 import torch

@@ -101,6 +102,13 @@
         "device": ["xpu:0"],
         "kv_cache_dtype": KV_CACHE_DTYPE,
     },
+    "test_swap_blocks_batch": {
+        "direction": [("cpu", "xpu")],
+        "device": ["xpu:0"],
+    },
+    "test_swap_blocks_batch_h2d_mutation_race": {
+        "device": ["xpu:0"],
+    },
 }


@@ -948,3 +956,149 @@ def test_swap_blocks_mla( |
             msg=f"Block {src} from src should have been swapped to block "
             f"{dst} in dst_cache.",
         )
+
+
+# ---------------------------------------------------------------------------
+# swap_blocks_batch tests
+# ---------------------------------------------------------------------------
+
+
+def _build_batch_args(
+    src_cache: torch.Tensor,
+    dst_cache: torch.Tensor,
+    block_mapping: list[tuple[int, int]],
+    block_size_in_bytes: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Build (src_ptrs, dst_ptrs, sizes) tensors for swap_blocks_batch."""
+    n = len(block_mapping)
+    src_arr = np.empty(n, dtype=np.uint64)
+    dst_arr = np.empty(n, dtype=np.uint64)
+    sz_arr = np.full(n, block_size_in_bytes, dtype=np.uint64)
+
+    src_base = src_cache.data_ptr()
+    dst_base = dst_cache.data_ptr()
+    stride = src_cache.stride(0) * src_cache.element_size()
+
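+    # Block i of a cache starts at base_ptr + i * stride bytes; this assumes
+    # src_cache and dst_cache share the same dim-0 stride and element size.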
+    for i, (sb, db) in enumerate(block_mapping):
+        src_arr[i] = src_base + sb * stride
+        dst_arr[i] = dst_base + db * stride
+
+    return (torch.from_numpy(src_arr), torch.from_numpy(dst_arr),
+            torch.from_numpy(sz_arr))
+
+
+@pytest.mark.parametrize("direction", COPYING_DIRECTION)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_swap_blocks_batch(
+    direction: tuple[str, str],
+    device: str,
+) -> None:
+    """Test swap_blocks_batch for H2D, D2H and D2D directions."""
+    num_mappings = 64
+    num_heads = 8
+    head_size = 64
+    block_size = 8
+    num_blocks = 256
+    dtype = torch.bfloat16
+    seed = 0
+
+    seed_everything(seed)
+
+    src_device = device if direction[0] == "xpu" else "cpu"
+    dst_device = device if direction[1] == "xpu" else "cpu"
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
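+    # For same-device (D2D) swaps, keep the dst blocks disjoint from the src
+    # blocks so no block is both read and overwritten within the batch.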
+    if src_device == dst_device:
+        remaining = list(set(range(num_blocks)) - set(src_blocks))
+        dst_blocks = random.sample(remaining, num_mappings)
+    else:
+        dst_blocks = random.sample(range(num_blocks), num_mappings)
+    block_mapping = list(zip(src_blocks, dst_blocks))
+
+    src_key, src_val = create_kv_caches_with_random(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed,
+                                                    src_device)
+    dst_key, dst_val = create_kv_caches_with_random(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed,
+                                                    dst_device)
+
+    src_key_clone = src_key[0].clone()
+    src_val_clone = src_val[0].clone()
+
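+    # One cache block spans stride(0) elements along dim 0, i.e. it occupies
+    # stride(0) * element_size bytes.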
+    block_size_in_bytes = src_key[0].element_size() * src_key[0].stride(0)
+
+    # Build pointer/size tensors and issue the batched swap for the key and
+    # value caches.
+    for src_cache, dst_cache in [(src_key[0], dst_key[0]),
+                                 (src_val[0], dst_val[0])]:
+        sp, dp, sz = _build_batch_args(src_cache, dst_cache, block_mapping,
+                                       block_size_in_bytes)
+        ops.swap_blocks_batch(sp, dp, sz)
+
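+    # The copies are presumably enqueued asynchronously on the XPU stream, so
+    # synchronize before reading back the destination cache.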
+    torch.xpu.synchronize()
+
+    for sb, db in block_mapping:
+        torch.testing.assert_close(src_key_clone[sb].cpu(),
+                                   dst_key[0][db].cpu())
+        torch.testing.assert_close(src_val_clone[sb].cpu(),
+                                   dst_val[0][db].cpu())
+
+
+@torch.inference_mode()
+def test_swap_blocks_batch_h2d_mutation_race() -> None:
+    """Verify the staging buffer protects H2D batches from caller mutation."""
+    num_mappings = 256
+    num_heads = 8
+    head_size = 128
+    block_size = 32
+    num_blocks = 512
+    dtype = torch.bfloat16
+    seed = 0
+
+    seed_everything(seed)
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    dst_blocks = random.sample(range(num_blocks), num_mappings)
+    block_mapping = list(zip(src_blocks, dst_blocks))
+
+    # Source: pinned CPU memory
+    src_key, src_val = create_kv_caches_with_pinned(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed, "cpu")
+    assert src_key[0].is_pinned()
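+    # Pinned host memory enables asynchronous H2D copies, so swap_blocks_batch
+    # may return while transfers are still in flight; the op is expected to
+    # stage the source bytes so later host writes cannot corrupt the copy.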
+
+    # Destination: XPU
+    dst_key, dst_val = create_kv_caches_with_random(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed,
+                                                    "xpu:0")
+
+    src_key_clone = src_key[0].clone()
+    src_val_clone = src_val[0].clone()
+
+    block_size_in_bytes = src_key[0].element_size() * src_key[0].stride(0)
+
+    for src_cache, dst_cache in [(src_key[0], dst_key[0]),
+                                 (src_val[0], dst_val[0])]:
+        sp, dp, sz = _build_batch_args(src_cache, dst_cache, block_mapping,
+                                       block_size_in_bytes)
+        ops.swap_blocks_batch(sp, dp, sz)
+
+    # Immediately mutate the source; it should not affect the destination.
+    src_key[0].fill_(0)
+    src_val[0].fill_(0)
+
+    torch.xpu.synchronize()
+
+    for sb, db in block_mapping:
+        torch.testing.assert_close(
+            src_key_clone[sb].cpu(),
+            dst_key[0][db].cpu(),
+            msg=f"Key block {sb}→{db} corrupted by post-call mutation",
+        )
+        torch.testing.assert_close(
+            src_val_clone[sb].cpu(),
+            dst_val[0][db].cpu(),
+            msg=f"Value block {sb}→{db} corrupted by post-call mutation",
+        )