[Inductor][CPU] disable bernoulli_p decomposition (#143460)

Fix https://github.com/pytorch/pytorch/issues/142853
`fallback_random=True` should cause RNG to match between compile/eager (by having compile fall back to eager for RNG ops), but the `bernoulli_p` decompose function is not fully consistent with the eager CPU implementation.
We remove the decomp and keep the version for` fallback_random=False`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143460
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
blzheng
2024-12-18 18:45:18 -08:00
committed by PyTorch MergeBot
parent fd8b217fcd
commit 288aa87383
4 changed files with 22 additions and 26 deletions

View File

@ -704,6 +704,7 @@ aten::bernoulli.Tensor
aten::bernoulli.Tensor_out
aten::bernoulli.float_out
aten::bernoulli.out
aten::bernoulli.p
aten::bernoulli_.Tensor
aten::bernoulli_.float
aten::bincount

View File

@ -4943,6 +4943,27 @@ class CPUReproTests(TestCase):
torch.compile(converted_model)(*example_batch)
check_metrics_vec_kernel_count(3)
def test_dropout(self):
class Model(nn.Module):
def __init__(self, dim):
super().__init__()
self.dropout = eval(f"nn.Dropout{dim}d(p=0.5)")
def forward(self, x):
torch.manual_seed(0)
x = self.dropout(x)
return x
for dim in [1, 2, 3]:
model = Model(dim)
torch.manual_seed(0)
shape = [1, 3] + [256] * dim
x = torch.randn(*shape)
output = model(x)
c_model = torch.compile(model)
c_output = c_model(x)
self.assertTrue(torch.allclose(output, c_output))
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@ -608,17 +608,6 @@ class TestDecomp(TestCase):
res = torch._decomp.decompositions.uniform(x, low=low, high=high)
self.assertEqual(ref, res)
def test_bernoulli_p(self, device):
p = 0.3
input_t = torch.rand(100, 100)
torch.manual_seed(123)
ref = torch.ops.aten.bernoulli.p(input_t, p)
torch.manual_seed(123)
res = torch._decomp.decompositions.bernoulli_p(input_t, p)
ref_p = ref.sum() / torch.prod(torch.tensor(ref.size()))
res_p = res.sum() / torch.prod(torch.tensor(res.size()))
self.assertEqual(ref_p, res_p, atol=0.06 * p, rtol=0.06)
def test_bernoulli_default(self, device):
p = 0.3
p_t = p * torch.ones(5, 5)

View File

@ -5117,21 +5117,6 @@ def bernoulli(
return p
@register_decomposition(aten.bernoulli.p)
def bernoulli_p(self, p, *, generator: Optional[torch.Generator] = None):
if generator is None:
raw_p = torch.rand(self.size(), dtype=torch.float32, device=self.device)
else:
raw_p = torch.rand(
self.size(),
generator=generator,
dtype=self.float32,
device=self.device,
)
p = (raw_p < p).to(self.dtype)
return p
def isin_default(elements, test_elements, *, invert=False):
if elements.numel() == 0:
return torch.empty_like(elements, dtype=torch.bool)