Mirror of https://github.com/pytorch/pytorch.git
Make test_jit more robust about compilation.
It's pretty easy to accidentally fail to actually compile a JITed region, which means we have accidentally been missing test coverage for a number of features. This adds a secret _assert_compiled kwarg, which raises an error if we don't actually hit the compiled codepath. This is not intended to be user visible; we have some other ideas for handling this case.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Committed by: Edward Z. Yang
Parent: 6dc67aef17
Commit: f709199c49
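For orientation before the diff: the pattern the new kwarg enables in the tests looks roughly like the condensed sketch below. It is adapted from the doit hunks further down, uses the old Variable API that was current at the time of this commit, and _assert_compiled is a private, test-only kwarg.

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)

@torch.jit.compile(nderivs=0)
def doit(x, y):
    return torch.sigmoid(torch.tanh(x * (x + y)))

z = doit(x, y)                           # first call traces and compiles
z2 = doit(x, y, _assert_compiled=True)   # second call must hit the compiled codepath
assert torch.equal(z.data, z2.data)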
@@ -18,6 +18,21 @@ except ImportError:
 skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


+def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
+    hx, cx = hidden
+    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
+
+    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+    ingate = F.sigmoid(ingate)
+    forgetgate = F.sigmoid(forgetgate)
+    cellgate = F.tanh(cellgate)
+    outgate = F.sigmoid(outgate)
+
+    cy = (forgetgate * cx) + (ingate * cellgate)
+    hy = outgate * F.tanh(cy)
+    return hy, cy
+
+
 class TestJit(TestCase):
     maxDiff = None
@@ -43,20 +58,6 @@ class TestJit(TestCase):
         cx = Variable(torch.randn(3, 20).cuda())
         module = nn.LSTMCell(10, 20).cuda()  # Just to allocate weights with correct sizes

-        def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
-            hx, cx = hidden
-            gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
-
-            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
-            ingate = F.sigmoid(ingate)
-            forgetgate = F.sigmoid(forgetgate)
-            cellgate = F.tanh(cellgate)
-            outgate = F.sigmoid(outgate)
-
-            cy = (forgetgate * cx) + (ingate * cellgate)
-            hy = outgate * F.tanh(cy)
-            return hy, cy
-
         trace, _ = torch.jit.trace(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
         torch._C._jit_pass_lint(trace)
         torch._C._jit_pass_onnx(trace)
@@ -65,6 +66,19 @@ class TestJit(TestCase):
         torch._C._jit_pass_lint(trace)
         self.assertExpected(str(trace))

+    @unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
+    def test_run_lstm_fusion(self):
+        input = Variable(torch.randn(3, 10).cuda())
+        hx = Variable(torch.randn(3, 20).cuda())
+        cx = Variable(torch.randn(3, 20).cuda())
+        module = nn.LSTMCell(10, 20).cuda()  # Just to allocate weights with correct sizes
+
+        CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCell)
+
+        z = CompiledLSTMCell(input, (hx, cx), *module.parameters())
+        z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters(), _assert_compiled=True)
+        self.assertEqual(z, z2)
+
     @unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
     def test_fusion_distribute(self):
         def f(x, y):
@@ -107,7 +121,7 @@ class TestJit(TestCase):
             return torch.sigmoid(torch.tanh(x * (x + y)))

         z = doit(x, y)
-        z2 = doit(x, y)
+        z2 = doit(x, y, _assert_compiled=True)
         self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
         self.assertEqual(z, z2)
@@ -115,12 +129,12 @@ class TestJit(TestCase):
         x = Variable(torch.Tensor([0.4]), requires_grad=True)
         y = Variable(torch.Tensor([0.7]), requires_grad=True)

-        @torch.jit.compile
+        @torch.jit.compile(nderivs=0)
         def doit(x, y):
             return torch.sigmoid(torch.tanh(x * (x + y)))

         z = doit(x, y)
-        z2 = doit(x, y)
+        z2 = doit(x, y, _assert_compiled=True)
         self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
         self.assertEqual(z, z2)
@@ -170,7 +184,7 @@ class TestJit(TestCase):
         lstm = MyLSTMCell(10, 20)

         out = lstm(input, (hx, cx))
-        out2 = lstm(input, (hx, cx))
+        out2 = lstm(input, (hx, cx), _assert_compiled=True)
         self.assertEqual(out, out2)

     def test_autograd_closure(self):
@@ -297,7 +311,7 @@ class TestJit(TestCase):
         x = Variable(torch.randn(5, 5))
         fn(x)  # trace
         with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
-            fn(x)  # create closure
+            fn(x, _assert_compiled=True)  # create closure

     def test_backward(self):
         a = Variable(torch.randn(2, 2), requires_grad=True)
@@ -358,7 +372,7 @@ class TestJit(TestCase):
         x.grad.data.zero_()

         # Run the trace
-        grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
+        grad_x, = torch.autograd.grad(fn(x, _assert_compiled=True), (x,), create_graph=True)
         grad_x.backward()

         self.assertEqual(x.grad.data, x_grad)
@@ -428,7 +442,7 @@ class TestJit(TestCase):

         recursive_sum(fn(x)).backward()
         self.assertTrue(fn.has_trace_for(x))
-        self.assertEqual(fn(x), expected_out)
+        self.assertEqual(fn(x, _assert_compiled=True), expected_out)

     def test_input_flatten(self):
         """Check that inputs to traced functions are flattened"""
@@ -444,7 +458,7 @@ class TestJit(TestCase):
         fn = torch.jit.compile(fn)
         fn(*x).backward()
         self.assertTrue(fn.has_trace_for(*x))
-        self.assertEqual(fn(*x), expected_out)
+        self.assertEqual(fn(*x, _assert_compiled=True), expected_out)

     def test_flags(self):
         x = Variable(torch.randn(2, 2))
@@ -487,6 +501,8 @@ class TestJit(TestCase):
             self.assertFalse(fn.has_trace_for(x, y))
             out = fn(x, y)
             self.assertTrue(fn.has_trace_for(x, y))
+            out2 = fn(x, y, _assert_compiled=True)
+            self.assertEqual(out, out2)

     def test_backward_flag_checks(self):
         x = Variable(torch.randn(1), requires_grad=True)
@@ -506,6 +522,8 @@ class TestJit(TestCase):
         grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
         grad_x.backward(Variable(torch.ones(1), requires_grad=True))

+        # TODO: Test executing this
+
     def test_python_ir(self):
         x = Variable(torch.Tensor([0.4]), requires_grad=True)
         y = Variable(torch.Tensor([0.7]), requires_grad=True)
@@ -559,7 +577,7 @@ class TestJit(TestCase):
         bn = MyBatchNorm2d(1)
         x = Variable(torch.randn(5, 1))
         z = bn(x)
-        z2 = bn(x)
+        z2 = bn(x, _assert_compiled=True)
         self.assertEqual(z, z2)

     def test_non_decorator_use_fails(self):
@@ -587,7 +605,7 @@ class TestJit(TestCase):
         # because we allocate a zero-filled new variable when we execute,
         # and then *fill* it with the result

-        r1 = clinear(clinear(input, weights), weights)
+        r1 = clinear(clinear(input, weights), weights, _assert_compiled=True)
         r2 = F.linear(F.linear(input, weights), weights)

         self.assertEqual(r1, r2)
@@ -614,7 +632,7 @@ class TestJit(TestCase):
         z, _ = model(x, y)
         z.sum().backward()

-        z, _ = model(x, y)
+        z, _ = model(x, y, _assert_compiled=True)
         z.sum().backward()

     @skipIfNoTorchVision
@@ -343,9 +343,11 @@ class _CompiledMixin(object):
     # but since the logic is so complicated, testing code wouldn't benefit much
     def __new_forward(self, *args, **kwargs):
         force_trace = kwargs.pop("_force_trace", False)
+        assert_compiled = kwargs.pop("_assert_compiled", False)
         if kwargs:
             raise TypeError("Unrecognized keyword arguments: {}".format(kwargs.keys()))
         if _JIT_DISABLE or not self.__enabled:
+            assert not assert_compiled
             with _time(self.__name, "unoptimized", self.__time):
                 # Call to the saved old forward function
                 return self.__old_forward(*args)
@@ -365,6 +367,7 @@ class _CompiledMixin(object):
             out_struct = ktrace.out_struct
         else:
             # No compiled trace available. Run it by hand.
+            assert not assert_compiled
             with _time(ktrace.name, "tracing", self.__time):
                 out_vars, out_struct = ktrace.add_trace(self.__old_forward,
                                                         args, in_vars, in_struct,
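The two _CompiledMixin hunks above carry the whole mechanism: the private kwarg is popped before the usual keyword-argument validation, and every path that does not dispatch to an existing compiled trace asserts that the caller did not demand one. A minimal standalone sketch of that pattern, with illustrative names that are not part of torch.jit:

def make_checked_forward(run_compiled, run_by_hand, has_compiled_trace):
    # Illustrative sketch only: mimics how __new_forward treats the private kwargs.
    def forward(*args, **kwargs):
        force_trace = kwargs.pop("_force_trace", False)
        assert_compiled = kwargs.pop("_assert_compiled", False)
        if kwargs:
            raise TypeError("Unrecognized keyword arguments: {}".format(list(kwargs.keys())))
        if has_compiled_trace(*args) and not force_trace:
            return run_compiled(*args)   # the compiled codepath
        # Fallbacks (tracing, running by hand, JIT disabled) trip the check.
        assert not assert_compiled
        return run_by_hand(*args)
    return forward

# Example: the check passes only when the "compiled" branch is taken.
forward = make_checked_forward(lambda x: x * 2, lambda x: x * 2, lambda x: True)
forward(3, _assert_compiled=True)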
@@ -613,10 +616,10 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
     saved_args = _clone_inputs(args)
     saved_state = copy.deepcopy(model.state_dict())

-    def run_fwd_bwd(args, force_trace=False):
+    def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
         in_vars, _ = _flatten(args, model.state_dict(keep_vars=True).values())
         # We use a special API to reset the trace and compile it from scratch.
-        out = model(*args, _force_trace=force_trace)
+        out = model(*args, _force_trace=force_trace, _assert_compiled=assert_compiled)
         if not isinstance(out, tuple):
             out = (out, )
         if loss_fn == torch.sum and len(out) != 1:
@@ -635,7 +638,7 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
     assert model.has_trace_for(*args)

     model.load_state_dict(saved_state)
-    compiled_outs, compiled_grads = run_fwd_bwd(args)
+    compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)

     _verify_equal(uncompiled_outs, compiled_outs)
     _verify_equal(uncompiled_grads, compiled_grads)
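With the verify() changes, the compiled forward/backward pass now also proves that the compiled trace was actually used instead of silently re-tracing. A hypothetical end-to-end call, assuming a module class wrapped with torch.jit.compile as in the tests (the compiled Linear here is an illustrative stand-in, not something from this diff):

import torch
import torch.nn as nn
from torch.autograd import Variable

CompiledLinear = torch.jit.compile(nderivs=1)(nn.Linear)  # hypothetical compiled module class
model = CompiledLinear(10, 10)
x = Variable(torch.randn(3, 10), requires_grad=True)

# Runs the model with and without the compiled trace, compares outputs and
# gradients, and (after this commit) asserts the second run hit the trace.
torch.jit.verify(model, (x,), loss_fn=torch.sum)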