Add a transformer that rewrites `x.transpose(1, 2)` as `torch.einsum('abc...->acb...', x)`, so that the transpose can then be fused with the rest of the einsums.
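A minimal sketch of one way such a rewrite could look, assuming the transformer works on Python source through the standard-library `ast` module; the real project may operate on a different representation (for example a traced graph). The names `transpose_equation`, `TransposeToEinsum`, and `rewrite_source` are hypothetical and exist only for this illustration. Only literal, non-negative dimensions are handled, because a negative dimension cannot be mapped to a letter without knowing the tensor rank.

```python
import ast
import string


def transpose_equation(dim0: int, dim1: int) -> str:
    """Build an einsum equation that swaps two non-negative dims.

    Trailing dims are covered by an ellipsis, so the tensor needs at
    least max(dim0, dim1) + 1 dimensions.
    """
    if dim0 < 0 or dim1 < 0:
        raise ValueError("negative dims require the tensor rank")
    n = max(dim0, dim1) + 1
    src = list(string.ascii_lowercase[:n])
    dst = src.copy()
    dst[dim0], dst[dim1] = dst[dim1], dst[dim0]
    return f"{''.join(src)}...->{''.join(dst)}..."


class TransposeToEinsum(ast.NodeTransformer):
    """Rewrite `<expr>.transpose(i, j)` into `torch.einsum(eq, <expr>)`."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        is_transpose = (
            isinstance(func, ast.Attribute)
            and func.attr == "transpose"
            and len(node.args) == 2
            and not node.keywords
            and all(
                isinstance(a, ast.Constant) and type(a.value) is int
                for a in node.args
            )
        )
        if not is_transpose:
            return node  # leave every other call untouched
        eq = transpose_equation(node.args[0].value, node.args[1].value)
        einsum = ast.Attribute(
            value=ast.Name(id="torch", ctx=ast.Load()),
            attr="einsum",
            ctx=ast.Load(),
        )
        new_call = ast.Call(
            func=einsum,
            args=[ast.Constant(value=eq), func.value],
            keywords=[],
        )
        return ast.copy_location(new_call, node)


def rewrite_source(source: str) -> str:
    """Apply the rewrite to a source string (needs Python 3.9+ for ast.unparse)."""
    tree = TransposeToEinsum().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


if __name__ == "__main__":
    print(rewrite_source("y = x.transpose(1, 2)"))
```

Note that a purely syntactic match cannot tell whether `x` is a `torch.Tensor`; any object with a two-argument `transpose` method would be rewritten, so a real pass would likely need type or tracing information before applying the substitution.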