Make test_jit more robust about compilation.

It's pretty easy to accidentally fail to compile
a JITed region, which means we have silently been
missing test coverage for a number of features.  This adds
a secret _assert_compiled kwarg, which raises an error
if we don't actually hit the compiled codepath.

This is not intended to be user visible; we have some other
ideas for handling this case.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Author: Edward Z. Yang
Date: 2017-10-13 19:57:51 -07:00
Committed by: Edward Z. Yang
Commit: f709199c49 (parent: 6dc67aef17)
2 changed files with 49 additions and 28 deletions
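To make the intended usage concrete, here is a small sketch adapted from the doit test changed below. _assert_compiled is a private testing hook rather than a public API, and the snippet assumes the Variable-based autograd used by this test suite:

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)

@torch.jit.compile(nderivs=0)
def doit(x, y):
    return torch.sigmoid(torch.tanh(x * (x + y)))

z = doit(x, y)                          # first call records and compiles the trace
z2 = doit(x, y, _assert_compiled=True)  # second call must hit the compiled codepath
assert torch.equal(z.data, z2.data)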

test/test_jit.py

@@ -18,6 +18,21 @@ except ImportError:
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
hx, cx = hidden
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = F.sigmoid(ingate)
forgetgate = F.sigmoid(forgetgate)
cellgate = F.tanh(cellgate)
outgate = F.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * F.tanh(cy)
return hy, cy
class TestJit(TestCase):
maxDiff = None
@@ -43,20 +58,6 @@ class TestJit(TestCase):
cx = Variable(torch.randn(3, 20).cuda())
module = nn.LSTMCell(10, 20).cuda() # Just to allocate weights with correct sizes
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
hx, cx = hidden
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = F.sigmoid(ingate)
forgetgate = F.sigmoid(forgetgate)
cellgate = F.tanh(cellgate)
outgate = F.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * F.tanh(cy)
return hy, cy
trace, _ = torch.jit.trace(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_onnx(trace)
@@ -65,6 +66,19 @@ class TestJit(TestCase):
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))
@unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
def test_run_lstm_fusion(self):
input = Variable(torch.randn(3, 10).cuda())
hx = Variable(torch.randn(3, 20).cuda())
cx = Variable(torch.randn(3, 20).cuda())
module = nn.LSTMCell(10, 20).cuda() # Just to allocate weights with correct sizes
CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCell)
z = CompiledLSTMCell(input, (hx, cx), *module.parameters())
z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters(), _assert_compiled=True)
self.assertEqual(z, z2)
@unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
def test_fusion_distribute(self):
def f(x, y):
@@ -107,7 +121,7 @@ class TestJit(TestCase):
return torch.sigmoid(torch.tanh(x * (x + y)))
z = doit(x, y)
z2 = doit(x, y)
z2 = doit(x, y, _assert_compiled=True)
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
self.assertEqual(z, z2)
@@ -115,12 +129,12 @@ class TestJit(TestCase):
x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)
@torch.jit.compile
@torch.jit.compile(nderivs=0)
def doit(x, y):
return torch.sigmoid(torch.tanh(x * (x + y)))
z = doit(x, y)
z2 = doit(x, y)
z2 = doit(x, y, _assert_compiled=True)
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
self.assertEqual(z, z2)
@@ -170,7 +184,7 @@ class TestJit(TestCase):
lstm = MyLSTMCell(10, 20)
out = lstm(input, (hx, cx))
out2 = lstm(input, (hx, cx))
out2 = lstm(input, (hx, cx), _assert_compiled=True)
self.assertEqual(out, out2)
def test_autograd_closure(self):
@@ -297,7 +311,7 @@ class TestJit(TestCase):
x = Variable(torch.randn(5, 5))
fn(x) # trace
with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
fn(x) # create closure
fn(x, _assert_compiled=True) # create closure
def test_backward(self):
a = Variable(torch.randn(2, 2), requires_grad=True)
@@ -358,7 +372,7 @@ class TestJit(TestCase):
x.grad.data.zero_()
# Run the trace
grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
grad_x, = torch.autograd.grad(fn(x, _assert_compiled=True), (x,), create_graph=True)
grad_x.backward()
self.assertEqual(x.grad.data, x_grad)
@@ -428,7 +442,7 @@ class TestJit(TestCase):
recursive_sum(fn(x)).backward()
self.assertTrue(fn.has_trace_for(x))
self.assertEqual(fn(x), expected_out)
self.assertEqual(fn(x, _assert_compiled=True), expected_out)
def test_input_flatten(self):
"""Check that inputs to traced functions are flattened"""
@@ -444,7 +458,7 @@ class TestJit(TestCase):
fn = torch.jit.compile(fn)
fn(*x).backward()
self.assertTrue(fn.has_trace_for(*x))
self.assertEqual(fn(*x), expected_out)
self.assertEqual(fn(*x, _assert_compiled=True), expected_out)
def test_flags(self):
x = Variable(torch.randn(2, 2))
@@ -487,6 +501,8 @@ class TestJit(TestCase):
self.assertFalse(fn.has_trace_for(x, y))
out = fn(x, y)
self.assertTrue(fn.has_trace_for(x, y))
out2 = fn(x, y, _assert_compiled=True)
self.assertEqual(out, out2)
def test_backward_flag_checks(self):
x = Variable(torch.randn(1), requires_grad=True)
@@ -506,6 +522,8 @@ class TestJit(TestCase):
grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
grad_x.backward(Variable(torch.ones(1), requires_grad=True))
# TODO: Test executing this
def test_python_ir(self):
x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.Tensor([0.7]), requires_grad=True)
@@ -559,7 +577,7 @@ class TestJit(TestCase):
bn = MyBatchNorm2d(1)
x = Variable(torch.randn(5, 1))
z = bn(x)
z2 = bn(x)
z2 = bn(x, _assert_compiled=True)
self.assertEqual(z, z2)
def test_non_decorator_use_fails(self):
@@ -587,7 +605,7 @@ class TestJit(TestCase):
# because we allocate a zero-filled new variable when we execute,
# and then *fill* it with the result
r1 = clinear(clinear(input, weights), weights)
r1 = clinear(clinear(input, weights), weights, _assert_compiled=True)
r2 = F.linear(F.linear(input, weights), weights)
self.assertEqual(r1, r2)
@@ -614,7 +632,7 @@ class TestJit(TestCase):
z, _ = model(x, y)
z.sum().backward()
z, _ = model(x, y)
z, _ = model(x, y, _assert_compiled=True)
z.sum().backward()
@skipIfNoTorchVision

torch/jit/__init__.py

@@ -343,9 +343,11 @@ class _CompiledMixin(object):
# but since the logic is so complicated, testing code wouldn't benefit much
def __new_forward(self, *args, **kwargs):
force_trace = kwargs.pop("_force_trace", False)
assert_compiled = kwargs.pop("_assert_compiled", False)
if kwargs:
raise TypeError("Unrecognized keyword arguments: {}".format(kwargs.keys()))
if _JIT_DISABLE or not self.__enabled:
assert not assert_compiled
with _time(self.__name, "unoptimized", self.__time):
# Call to the saved old forward function
return self.__old_forward(*args)
@@ -365,6 +367,7 @@ class _CompiledMixin(object):
out_struct = ktrace.out_struct
else:
# No compiled trace available. Run it by hand.
assert not assert_compiled
with _time(ktrace.name, "tracing", self.__time):
out_vars, out_struct = ktrace.add_trace(self.__old_forward,
args, in_vars, in_struct,
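Condensed into one place, the two hunks above implement roughly the following control flow; the class and attribute names here are simplified stand-ins for the real _CompiledMixin internals, not the actual implementation:

class _ForwardSketch(object):
    """Illustrative only: simplified version of the __new_forward dispatch."""
    def __init__(self, old_forward):
        self.enabled = True          # stands in for _JIT_DISABLE / self.__enabled
        self.compiled_trace = None   # stands in for the compiled ktrace closure lookup
        self.old_forward = old_forward

    def forward(self, *args, **kwargs):
        force_trace = kwargs.pop("_force_trace", False)
        assert_compiled = kwargs.pop("_assert_compiled", False)
        if kwargs:
            raise TypeError("Unrecognized keyword arguments: {}".format(kwargs.keys()))
        if not self.enabled:
            # JIT disabled: the compiled codepath cannot be hit.
            assert not assert_compiled
            return self.old_forward(*args)
        if self.compiled_trace is not None and not force_trace:
            # Only this branch counts as "compiled" for _assert_compiled purposes.
            return self.compiled_trace(*args)
        # No compiled trace available yet; tracing by hand is not the compiled path.
        assert not assert_compiled
        return self.old_forward(*args)  # (the real code records a trace here)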
@@ -613,10 +616,10 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
saved_args = _clone_inputs(args)
saved_state = copy.deepcopy(model.state_dict())
def run_fwd_bwd(args, force_trace=False):
def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
in_vars, _ = _flatten(args, model.state_dict(keep_vars=True).values())
# We use a special API to reset the trace and compile it from scratch.
out = model(*args, _force_trace=force_trace)
out = model(*args, _force_trace=force_trace, _assert_compiled=assert_compiled)
if not isinstance(out, tuple):
out = (out, )
if loss_fn == torch.sum and len(out) != 1:
@@ -635,7 +638,7 @@ def verify(model, args, loss_fn=torch.sum, devices=None):
assert model.has_trace_for(*args)
model.load_state_dict(saved_state)
compiled_outs, compiled_grads = run_fwd_bwd(args)
compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)
_verify_equal(uncompiled_outs, compiled_outs)
_verify_equal(uncompiled_grads, compiled_grads)
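For context, a hedged sketch of how verify would be driven after this change. MyLinear and its decorator form are assumptions mirroring the MyLSTMCell / MyBatchNorm2d pattern in the tests above, not something this diff confirms:

import torch
import torch.nn as nn
from torch.autograd import Variable

# Hypothetical compiled module, mirroring the compiled-module classes in test_jit.
@torch.jit.compile(nderivs=1)
class MyLinear(nn.Linear):
    pass

model = MyLinear(5, 5)
x = Variable(torch.randn(4, 5), requires_grad=True)

# verify() runs the model once uncompiled (forcing a fresh trace), then again on
# the compiled codepath -- now with _assert_compiled=True -- and checks that the
# outputs and gradients of the two runs agree.
torch.jit.verify(model, (x,), loss_fn=torch.sum)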