Skip to content

Commit 4c3c1fd

Browse files
committed
per-test sharding. avoid name collision
1 parent 60dbe86 commit 4c3c1fd

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

test/integration/test_integration.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,10 @@ def forward(self, x):
985985
# save quantized state_dict
986986
api(model)
987987

988-
torch.save(model.state_dict(), "test.pth")
988+
# unique filename to avoid collision in parallel tests
989+
ckpt_name = f"{api.__name__}_{test_device}_{test_dtype}_test.pth"
990+
991+
torch.save(model.state_dict(), ckpt_name)
989992
# get quantized reference
990993
model_qc = torch.compile(model, mode="max-autotune")
991994
ref_q = model_qc(x).detach()
@@ -998,8 +1001,8 @@ def forward(self, x):
9981001
api(model)
9991002

10001003
# load quantized state_dict
1001-
state_dict = torch.load("test.pth", mmap=True)
1002-
os.remove("test.pth")
1004+
state_dict = torch.load(ckpt_name, mmap=True)
1005+
os.remove(ckpt_name)
10031006

10041007
model.load_state_dict(state_dict, assign=True)
10051008
model = model.to(device=test_device, dtype=test_dtype).eval()

0 commit comments

Comments
 (0)