File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -985,7 +985,10 @@ def forward(self, x):
985
985
# save quantized state_dict
986
986
api (model )
987
987
988
- torch .save (model .state_dict (), "test.pth" )
988
+ # unique filename to avoid collision in parallel tests
989
+ ckpt_name = f"{ api .__name__ } _{ test_device } _{ test_dtype } _test.pth"
990
+
991
+ torch .save (model .state_dict (), ckpt_name )
989
992
# get quantized reference
990
993
model_qc = torch .compile (model , mode = "max-autotune" )
991
994
ref_q = model_qc (x ).detach ()
@@ -998,8 +1001,8 @@ def forward(self, x):
998
1001
api (model )
999
1002
1000
1003
# load quantized state_dict
1001
- state_dict = torch .load ("test.pth" , mmap = True )
1002
- os .remove ("test.pth" )
1004
+ state_dict = torch .load (ckpt_name , mmap = True )
1005
+ os .remove (ckpt_name )
1003
1006
1004
1007
model .load_state_dict (state_dict , assign = True )
1005
1008
model = model .to (device = test_device , dtype = test_dtype ).eval ()
You can’t perform that action at this time.
0 commit comments