mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
dropout symbolic_script should respect the training flag (#20760)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20760 ghimport-source-id: eb667c3549a03a2fc01ffa0a2d3bc7e3a29b78e0 Reviewed By: jamesr66a Differential Revision: D15486511 Pulled By: suo fbshipit-source-id: 56ae930a01b0f6f4305a2a745135d4529b4a1ca0
This commit is contained in:
committed by
Facebook Github Bot
parent
bd53c8eb93
commit
62af37aa88
@ -6827,6 +6827,80 @@ a")
|
||||
# type: (Optional[int]) -> bool
|
||||
return isinstance(x, int)
|
||||
|
||||
def test_dropout_eval(self):
|
||||
class ScriptedConv2d(torch.jit.ScriptModule):
|
||||
def __init__(self, in_channels, out_channels, **kwargs):
|
||||
super(ScriptedConv2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
|
||||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
x = self.bn(x)
|
||||
return F.relu(x, inplace=True)
|
||||
|
||||
class ScriptMod(torch.jit.ScriptModule):
|
||||
def __init__(self):
|
||||
super(ScriptMod, self).__init__()
|
||||
self.Conv2d_1a_3x3 = ScriptedConv2d(3, 32, kernel_size=3, stride=2)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x):
|
||||
x = self.Conv2d_1a_3x3(x)
|
||||
return x
|
||||
return F.dropout(x, training=self.training)
|
||||
|
||||
class EagerConv2d(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, **kwargs):
|
||||
super(EagerConv2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
|
||||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
x = self.bn(x)
|
||||
return F.relu(x, inplace=True)
|
||||
|
||||
class EagerMod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(EagerMod, self).__init__()
|
||||
self.Conv2d_1a_3x3 = EagerConv2d(3, 32, kernel_size=3, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.Conv2d_1a_3x3(x)
|
||||
return x
|
||||
return F.dropout(x, training=self.training)
|
||||
|
||||
script_input = torch.rand(4, 3, 299, 299)
|
||||
eager_input = script_input.clone()
|
||||
|
||||
with freeze_rng_state():
|
||||
script_mod = ScriptMod()
|
||||
script_mod.eval()
|
||||
script_output = script_mod(script_input)
|
||||
|
||||
with freeze_rng_state():
|
||||
eager_mod = EagerMod()
|
||||
eager_mod.eval()
|
||||
eager_output = eager_mod(eager_input)
|
||||
|
||||
self.assertEqual(script_output, eager_output)
|
||||
|
||||
with freeze_rng_state():
|
||||
script_mod = ScriptMod()
|
||||
script_mod.train()
|
||||
script_output = script_mod(script_input)
|
||||
|
||||
with freeze_rng_state():
|
||||
eager_mod = EagerMod()
|
||||
eager_mod.train()
|
||||
eager_output = eager_mod(eager_input)
|
||||
|
||||
self.assertEqual(script_output, eager_output)
|
||||
|
||||
def test_python_call(self):
|
||||
def pyfunc(a):
|
||||
return a * 3.0
|
||||
|
@ -292,7 +292,6 @@ class TestFuser(JitTestCase):
|
||||
|
||||
a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
|
||||
s = torch.jit.script(func, (a,))
|
||||
self.assertAllFused(s.graph_for(a,), except_for={'aten::div', 'prim::Constant'})
|
||||
c = s(a)
|
||||
c.sum().backward()
|
||||
graph = backward_graph(s)
|
||||
|
@ -1129,6 +1129,11 @@ const std::vector<std::string> functions = {
|
||||
mask.bernoulli_(p1m)
|
||||
res = mask * input / p1m
|
||||
|
||||
if not train:
|
||||
p1m = 1.
|
||||
res = input
|
||||
mask = torch.ones_like(input)
|
||||
|
||||
def backward(grad_output):
|
||||
use_cuda = grad_output.is_cuda
|
||||
if use_cuda:
|
||||
|
Reference in New Issue
Block a user