I think the PyTorch distributed API changed at some point after `2.0.0`, so our code probably doesn't work across all PyTorch versions right now.
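If we want to support both sides of the change, one option is to branch on the installed version. Here's a minimal sketch, assuming we gate on `torch.__version__`. `parse_torch_version` and `version_at_least` are hypothetical helpers, not part of PyTorch. We also haven't pinned down which release actually changed the API, so any threshold we pass in is a placeholder for now.

```python
import re


def parse_torch_version(version: str) -> tuple[int, int, int]:
    """Parse strings like '2.1.0', '2.1.0+cu118' or '2.2.0.dev20231010'
    into (major, minor, patch), ignoring local/dev suffixes.
    Hypothetical helper, not a PyTorch API."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", str(version))
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    # A missing patch component (e.g. '1.13') is treated as 0.
    return int(major), int(minor), int(patch or 0)


def version_at_least(version: str, minimum: tuple[int, int, int]) -> bool:
    """True if `version` is >= `minimum`, compared component-wise.
    Hypothetical helper; the cutoff we pass in is still a guess."""
    return parse_torch_version(version) >= minimum
```

Usage would be something like `if version_at_least(torch.__version__, (2, 1, 0)):` around the newer distributed call path, where `(2, 1, 0)` stands in for whichever release turns out to be the real cutoff. The parser is stdlib-only, so it works without pulling in `packaging`.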