diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 883399f855cc..201fd6d989bc 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -704,6 +704,7 @@ aten::bernoulli.Tensor aten::bernoulli.Tensor_out aten::bernoulli.float_out aten::bernoulli.out +aten::bernoulli.p aten::bernoulli_.Tensor aten::bernoulli_.float aten::bincount diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 5f8e75bde483..67159a010305 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -4943,6 +4943,27 @@ class CPUReproTests(TestCase): torch.compile(converted_model)(*example_batch) check_metrics_vec_kernel_count(3) + def test_dropout(self): + class Model(nn.Module): + def __init__(self, dim): + super().__init__() + self.dropout = eval(f"nn.Dropout{dim}d(p=0.5)") + + def forward(self, x): + torch.manual_seed(0) + x = self.dropout(x) + return x + + for dim in [1, 2, 3]: + model = Model(dim) + torch.manual_seed(0) + shape = [1, 3] + [256] * dim + x = torch.randn(*shape) + output = model(x) + c_model = torch.compile(model) + c_output = c_model(x) + self.assertTrue(torch.allclose(output, c_output)) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/test_decomp.py b/test/test_decomp.py index 0177c50ca7d8..c0a0b66300ca 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -608,17 +608,6 @@ class TestDecomp(TestCase): res = torch._decomp.decompositions.uniform(x, low=low, high=high) self.assertEqual(ref, res) - def test_bernoulli_p(self, device): - p = 0.3 - input_t = torch.rand(100, 100) - torch.manual_seed(123) - ref = torch.ops.aten.bernoulli.p(input_t, p) - torch.manual_seed(123) - res = torch._decomp.decompositions.bernoulli_p(input_t, p) - ref_p = ref.sum() / torch.prod(torch.tensor(ref.size())) - res_p = res.sum() / torch.prod(torch.tensor(res.size())) - self.assertEqual(ref_p, res_p, atol=0.06 * p, rtol=0.06) - def test_bernoulli_default(self, device): p = 0.3 p_t = p * torch.ones(5, 5) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index bd6b1cf88bb0..e38281afe00f 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -5117,21 +5117,6 @@ def bernoulli( return p -@register_decomposition(aten.bernoulli.p) -def bernoulli_p(self, p, *, generator: Optional[torch.Generator] = None): - if generator is None: - raw_p = torch.rand(self.size(), dtype=torch.float32, device=self.device) - else: - raw_p = torch.rand( - self.size(), - generator=generator, - dtype=self.float32, - device=self.device, - ) - p = (raw_p < p).to(self.dtype) - return p - - def isin_default(elements, test_elements, *, invert=False): if elements.numel() == 0: return torch.empty_like(elements, dtype=torch.bool)