Fix incorrect distribution of randperm with device mps (#104171)

Fixes #104170

As noted in the above issue, it seems that the code for randperm basically boils down to:
`torch.argsort(torch.rand(size, device="mps"), dim = 0)`

However, it seems that in the fused(?) PyTorch version, the tensor we were drawing `torch.rand(size, device="mps")` from was int64 with an inclusive(?) upper bound of 1. This caused everything to be sorted into two groups (depending on whether you drew 0 or 1), each monotonically ascending due to sort tie-breaking.
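The failure mode described above can be sketched in plain Python (a minimal simulation, not the actual MPS kernel; `buggy_randperm` is a hypothetical name for illustration):

```python
import random

def buggy_randperm(n):
    # Simulates the buggy path: "uniform" int64 keys drawn from {0, 1}
    # (an inclusive upper bound of 1), followed by a stable argsort.
    keys = [random.randint(0, 1) for _ in range(n)]
    # Python's sorted() is stable, mirroring the tie-breaking behavior:
    # indices with key 0 come out first in ascending order, then indices
    # with key 1, also ascending -- two sorted runs, not a uniform shuffle.
    return sorted(range(n), key=lambda i: keys[i])
```

Running this produces at most two monotonically increasing runs, which is exactly the non-uniform "permutation" reported in the issue.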

One way to fix this is to just generate the random tensor as float64s with an upper bound of 1.0 instead of int64s. An alternative is to just set the upper bound to the int64 maximum.

~I chose the float64 option basically on a coin flip b/c I couldn't tell the original contributor's intent (due to the mixed-up upper bound and type), but would be happy to change it to use int64 with the int64 maximum as the upper bound instead if that's better.~

Edit: on second thought, I don't like using floats from 0.0 to 1.0, as there are fewer of them in that range than there are int64s from 0 to the int64 maximum. I also suspect integer math might be faster, but I need to benchmark this tomorrow.
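A back-of-the-envelope count makes the "fewer floats" point concrete (assuming the float path would produce doubles of the form k / 2**53, as common uniform generators do; the actual MPS generator may differ):

```python
# Distinct sort keys available to each approach (under the assumption above).
distinct_float_keys = 2**53   # doubles of the form k / 2**53 in [0.0, 1.0)
distinct_int64_keys = 2**64   # full int64 range, min to max

# int64 keys offer 2**11 = 2048 times more distinct values, so ties
# (and the tie-breaking bias they cause in argsort) are rarer.
ratio = distinct_int64_keys // distinct_float_keys
print(ratio)  # 2048
```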
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104171
Approved by: https://github.com/malfet
Author: Peter Stefek
Date: 2023-06-27 00:36:15 +00:00
Committed by: PyTorch MergeBot
Parent: 994b98b78b
Commit: d8a2e7461b


```diff
@@ -412,8 +412,8 @@ Tensor& randperm_out_mps(int64_t n, c10::optional<Generator> generator, Tensor&
   };
   return mps::random_mps_impl<int64_t>(result,
-                                       0.0,
-                                       1.0,
+                                       std::numeric_limits<int64_t>::min(),
+                                       std::numeric_limits<int64_t>::max(),
                                        c10::nullopt,
                                        c10::nullopt,
                                        MPSGraphRandomDistributionUniform,
```