diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 802cbd3d77..9c0e58723e 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -17,6 +17,7 @@ _get_sinc_resample_kernel, _stretch_waveform, ) +from torch.utils.cpp_extension import ROCM_HOME __all__ = [] @@ -495,7 +496,11 @@ def forward(self, melspec: Tensor) -> Tensor: if self.n_mels != n_mels: raise ValueError("Expected an input with {} mel bins. Found: {}".format(self.n_mels, n_mels)) - specgram = torch.relu(torch.linalg.lstsq(self.fb.transpose(-1, -2)[None], melspec, driver=self.driver).solution) + if ROCM_HOME is not None: + solution = torch.linalg.pinv(self.fb.transpose(-1, -2)[None]) @ melspec + else: + solution = torch.linalg.lstsq(self.fb.transpose(-1, -2)[None], melspec, driver=self.driver).solution + specgram = torch.relu(solution) # unpack batch specgram = specgram.view(shape[:-2] + (freq, time))