dropout symbolic_script should respect the training flag (#20760)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20760
ghimport-source-id: eb667c3549a03a2fc01ffa0a2d3bc7e3a29b78e0

Reviewed By: jamesr66a

Differential Revision: D15486511

Pulled By: suo

fbshipit-source-id: 56ae930a01b0f6f4305a2a745135d4529b4a1ca0
This commit is contained in:
Michael Suo
2019-05-23 18:06:19 -07:00
committed by Facebook Github Bot
parent bd53c8eb93
commit 62af37aa88
3 changed files with 79 additions and 1 deletion

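For context before the diffs: `symbolic_script.cpp` holds hand-written forward/backward formulas that JIT autodiff substitutes for certain ops in differentiable graphs, and the dropout entry previously applied the mask regardless of the module's train/eval mode. Below is a minimal illustration of the behavior this change guarantees; the snippet is not part of the PR, and the module name and tensor shapes are made up.

```python
import torch
import torch.nn.functional as F


class DropoutMod(torch.jit.ScriptModule):
    # Hypothetical minimal module (not from the PR); the PR's own test
    # drives dropout through a conv/bn stack instead.
    @torch.jit.script_method
    def forward(self, x):
        # self.training is toggled by .train()/.eval(); the scripted formula
        # must read it rather than assume training mode.
        return F.dropout(x, p=0.5, training=self.training)


m = DropoutMod()
m.eval()
# requires_grad=True so the JIT's differentiable path (where the
# symbolic_script formula for dropout lives) can actually be taken.
x = torch.randn(4, 4, requires_grad=True)
y = m(x)
y.sum().backward()
# In eval mode dropout is the identity: same values out, gradient of ones.
assert torch.equal(y, x)
assert torch.equal(x.grad, torch.ones_like(x))
```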

@@ -6827,6 +6827,80 @@ a")
# type: (Optional[int]) -> bool
return isinstance(x, int)
    def test_dropout_eval(self):
        class ScriptedConv2d(torch.jit.ScriptModule):
            def __init__(self, in_channels, out_channels, **kwargs):
                super(ScriptedConv2d, self).__init__()
                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
                self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

            @torch.jit.script_method
            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return F.relu(x, inplace=True)

        class ScriptMod(torch.jit.ScriptModule):
            def __init__(self):
                super(ScriptMod, self).__init__()
                self.Conv2d_1a_3x3 = ScriptedConv2d(3, 32, kernel_size=3, stride=2)

            @torch.jit.script_method
            def forward(self, x):
                x = self.Conv2d_1a_3x3(x)
                return F.dropout(x, training=self.training)

        class EagerConv2d(torch.nn.Module):
            def __init__(self, in_channels, out_channels, **kwargs):
                super(EagerConv2d, self).__init__()
                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
                self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return F.relu(x, inplace=True)

        class EagerMod(torch.nn.Module):
            def __init__(self):
                super(EagerMod, self).__init__()
                self.Conv2d_1a_3x3 = EagerConv2d(3, 32, kernel_size=3, stride=2)

            def forward(self, x):
                x = self.Conv2d_1a_3x3(x)
                return F.dropout(x, training=self.training)

        script_input = torch.rand(4, 3, 299, 299)
        eager_input = script_input.clone()

        with freeze_rng_state():
            script_mod = ScriptMod()
            script_mod.eval()
            script_output = script_mod(script_input)
        with freeze_rng_state():
            eager_mod = EagerMod()
            eager_mod.eval()
            eager_output = eager_mod(eager_input)
        self.assertEqual(script_output, eager_output)

        with freeze_rng_state():
            script_mod = ScriptMod()
            script_mod.train()
            script_output = script_mod(script_input)
        with freeze_rng_state():
            eager_mod = EagerMod()
            eager_mod.train()
            eager_output = eager_mod(eager_input)
        self.assertEqual(script_output, eager_output)
    def test_python_call(self):
        def pyfunc(a):
            return a * 3.0

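Both the eval and train comparisons above run under `freeze_rng_state()` so the scripted and eager modules draw identical dropout masks when training. A sketch of what such a context manager amounts to, assuming the helper in the test utilities simply saves and restores the global RNG state (details of the real helper may differ):

```python
import contextlib
import torch


@contextlib.contextmanager
def freeze_rng_state():
    # Snapshot the CPU (and CUDA, if available) RNG state, run the body,
    # then restore the snapshot so the next block replays the same stream.
    cpu_state = torch.get_rng_state()
    cuda_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    try:
        yield
    finally:
        if cuda_state is not None:
            torch.cuda.set_rng_state(cuda_state)
        torch.set_rng_state(cpu_state)
```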

@@ -292,7 +292,6 @@ class TestFuser(JitTestCase):
        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
        s = torch.jit.script(func, (a,))
        self.assertAllFused(s.graph_for(a,), except_for={'aten::div', 'prim::Constant'})
        c = s(a)
        c.sum().backward()
        graph = backward_graph(s)


@@ -1129,6 +1129,11 @@ const std::vector<std::string> functions = {
            mask.bernoulli_(p1m)
            res = mask * input / p1m

        if not train:
            p1m = 1.
            res = input
            mask = torch.ones_like(input)

        def backward(grad_output):
            use_cuda = grad_output.is_cuda
            if use_cuda:
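
The hunk above cuts off inside the formula. For orientation, here is a runnable Python sketch of the shape of the whole dropout entry with the new `if not train:` guard in place; it is a sketch only, not the exact string in `symbolic_script.cpp`, which additionally has a fused-CUDA forward/backward branch (the `use_cuda` lines visible above) and handles gradients for the non-tensor arguments.

```python
import torch


def dropout_formula(input, p: float, train: bool):
    # CPU-style forward: keep each element with probability (1 - p) and
    # rescale so the expected activation stays the same.
    p1m = 1. - p
    mask = torch.empty_like(input)
    mask.bernoulli_(p1m)
    res = mask * input / p1m

    # The new guard: in eval mode the op is the identity, so the result is
    # overwritten and the saved mask/scale describe an identity for backward.
    if not train:
        p1m = 1.
        res = input
        mask = torch.ones_like(input)

    def backward(grad_output):
        # Same keep-and-rescale applied to the incoming gradient; with the
        # eval-mode values above this is a plain pass-through.
        return grad_output * mask / p1m

    return res, backward
```

With `train=False` this returns `input` unchanged and a pass-through backward, which is exactly the behavior `test_dropout_eval` above checks against eager mode.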