@@ -59,7 +59,7 @@ def setUpClass(cls):
59
59
"""Setup once for all tests: Load model and prepare test data."""
60
60
61
61
# Get environment variables (if set), otherwise use default values
62
- mimi_weight = os .getenv ("MIMI_WEIGHT" , None )
62
+ cls . mimi_weight = os .getenv ("MIMI_WEIGHT" , None )
63
63
hf_repo = os .getenv ("HF_REPO" , loaders .DEFAULT_REPO )
64
64
device = "cuda" if torch .cuda .device_count () else "cpu"
65
65
@@ -75,15 +75,15 @@ def seed_all(seed):
75
75
76
76
seed_all (42424242 )
77
77
78
- if mimi_weight is None :
78
+ if cls . mimi_weight is None :
79
79
try :
80
- mimi_weight = hf_hub_download (hf_repo , loaders .MIMI_NAME )
80
+ cls . mimi_weight = hf_hub_download (hf_repo , loaders .MIMI_NAME )
81
81
except :
82
- mimi_weight = hf_hub_download (
82
+ cls . mimi_weight = hf_hub_download (
83
83
hf_repo , loaders .MIMI_NAME , proxies = proxies
84
84
)
85
85
86
- cls .mimi = loaders .get_mimi (mimi_weight , device )
86
+ cls .mimi = loaders .get_mimi (cls . mimi_weight , device )
87
87
cls .device = device
88
88
cls .sample_pcm , cls .sample_sr = read_mp3_from_url (
89
89
"https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3"
@@ -182,8 +182,8 @@ def forward(self, x):
182
182
return out
183
183
184
184
emb_input = torch .rand (1 , 1 , 512 , device = "cpu" )
185
-
186
- mimi_decode = MimiDecode (self . mimi )
185
+ mimi_cpu = loaders . get_mimi ( self . mimi_weight , "cpu" )
186
+ mimi_decode = MimiDecode (mimi_cpu )
187
187
mimi_decode .eval ()
188
188
mimi_decode (emb_input )
189
189
@@ -225,7 +225,9 @@ def forward(self, x):
225
225
# Compare results
226
226
sqnr = compute_sqnr (eager_res , res [0 ])
227
227
print (f"SQNR: { sqnr } " )
228
- torch .testing .assert_close (eager_res , res [0 ], atol = 4e-3 , rtol = 1e-3 )
228
+ # Don't check for exact equality, but check that the SQNR is high enough
229
+ # torch.testing.assert_close(eager_res, res[0], atol=4e-3, rtol=1e-3)
230
+ self .assertGreater (sqnr , 25.0 )
229
231
230
232
231
233
if __name__ == "__main__" :
0 commit comments