diff --git a/test/test_jit.py b/test/test_jit.py
index 7519830a4913..bd4790aeba33 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -6827,6 +6827,76 @@ a")
             # type: (Optional[int]) -> bool
             return isinstance(x, int)
 
+    def test_dropout_eval(self):
+        class ScriptedConv2d(torch.jit.ScriptModule):
+            def __init__(self, in_channels, out_channels, **kwargs):
+                super(ScriptedConv2d, self).__init__()
+                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+                self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                return F.relu(x, inplace=True)
+
+        class ScriptMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ScriptMod, self).__init__()
+                self.Conv2d_1a_3x3 = ScriptedConv2d(3, 32, kernel_size=3, stride=2)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                x = self.Conv2d_1a_3x3(x)
+                return F.dropout(x, training=self.training)
+
+        class EagerConv2d(torch.nn.Module):
+            def __init__(self, in_channels, out_channels, **kwargs):
+                super(EagerConv2d, self).__init__()
+                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+                self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                return F.relu(x, inplace=True)
+
+        class EagerMod(torch.nn.Module):
+            def __init__(self):
+                super(EagerMod, self).__init__()
+                self.Conv2d_1a_3x3 = EagerConv2d(3, 32, kernel_size=3, stride=2)
+
+            def forward(self, x):
+                x = self.Conv2d_1a_3x3(x)
+                return F.dropout(x, training=self.training)
+
+        script_input = torch.rand(4, 3, 299, 299)
+        eager_input = script_input.clone()
+
+        with freeze_rng_state():
+            script_mod = ScriptMod()
+            script_mod.eval()
+            script_output = script_mod(script_input)
+
+        with freeze_rng_state():
+            eager_mod = EagerMod()
+            eager_mod.eval()
+            eager_output = eager_mod(eager_input)
+
+        self.assertEqual(script_output, eager_output)
+
+        with freeze_rng_state():
+            script_mod = ScriptMod()
+            script_mod.train()
+            script_output = script_mod(script_input)
+
+        with freeze_rng_state():
+            eager_mod = EagerMod()
+            eager_mod.train()
+            eager_output = eager_mod(eager_input)
+
+        self.assertEqual(script_output, eager_output)
+
     def test_python_call(self):
         def pyfunc(a):
             return a * 3.0
diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py
index 62c7291dd5a7..1d363833cf98 100644
--- a/test/test_jit_fuser.py
+++ b/test/test_jit_fuser.py
@@ -292,7 +292,6 @@ class TestFuser(JitTestCase):
 
         a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
         s = torch.jit.script(func, (a,))
-        self.assertAllFused(s.graph_for(a,), except_for={'aten::div', 'prim::Constant'})
         c = s(a)
         c.sum().backward()
         graph = backward_graph(s)
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index 9d067f8b641c..1a2b27c058f0 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -1129,6 +1129,11 @@ const std::vector<std::string> functions = {
                 mask.bernoulli_(p1m)
                 res = mask * input / p1m
 
+            if not train:
+                p1m = 1.
+                res = input
+                mask = torch.ones_like(input)
+
             def backward(grad_output):
                 use_cuda = grad_output.is_cuda
                 if use_cuda:
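
Note on the symbolic_script.cpp hunk: the added "if not train:" branch makes the differentiable dropout formula a no-op in eval mode, matching eager-mode F.dropout(input, training=False). The snippet below is a minimal standalone sketch of that behaviour, not part of the patch; the helper name dropout_with_saved_mask is made up for illustration and mirrors only the CPU path of the AD formula above. With p1m = 1 and a ones mask, the forward returns the input unchanged and the backward's grad_output * mask / p1m reduces to grad_output.

    # Illustration only -- not part of the patch; helper name is hypothetical.
    import torch

    def dropout_with_saved_mask(input, p, train):
        # Training path: Bernoulli mask, scaled by 1/p1m (CPU path of the AD formula).
        p1m = 1. - p
        mask = torch.empty_like(input).bernoulli_(p1m)
        res = mask * input / p1m
        if not train:
            # Eval mode: identity forward, ones mask, so backward passes
            # grad_output through unscaled.
            p1m = 1.
            res = input
            mask = torch.ones_like(input)
        return res, mask, p1m

    x = torch.randn(4, 4)
    out, mask, p1m = dropout_with_saved_mask(x, p=0.5, train=False)
    assert torch.equal(out, x)                     # identity in eval mode
    assert torch.equal(mask, torch.ones_like(x))   # degenerate ones mask
    # Reference behaviour from eager mode:
    assert torch.equal(torch.nn.functional.dropout(x, p=0.5, training=False), x)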