Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

Re-enable and fix most JIT tests

Committed by: Soumith Chintala
Parent: 61afb0d519
Commit: fa0f3cf98a
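The expect-file changes below are the printed traces produced by the pattern used throughout test_jit.py in this diff: trace a function, lint the IR, lower it with the ONNX pass, and compare str(trace) against the stored expect file. A minimal sketch of that pattern, assuming the pre-0.4 tracing API that appears in this diff (torch.jit.trace returning a trace object plus outputs, and the torch._C._jit_pass_* entry points); this is illustrative only, not current PyTorch API:

    import torch
    from torch.autograd import Variable

    def f(x, y):
        # same kind of toy function the tests trace
        return torch.sigmoid(torch.tanh(x * (x + y)))

    x = Variable(torch.randn(4, 4), requires_grad=True)
    y = Variable(torch.randn(4, 4), requires_grad=True)

    # Record a trace without derivatives, then run the IR passes the tests exercise.
    trace, out = torch.jit.trace(f, (x, y), nderivs=0)
    torch._C._jit_pass_lint(trace)   # check IR invariants
    torch._C._jit_pass_onnx(trace)   # lower ^PythonOp-style nodes to ATen-style nodes (threshold, addmm, ...)
    torch._C._jit_pass_lint(trace)

    # The *.expect files changed below store exactly this kind of textual dump.
    print(str(trace))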
@@ -16,28 +16,31 @@ graph(%1 : Double(10, 3, 224, 224)
%16 : Double(1000, 4096)
%17 : Double(1000)) {
%19 : Double(10, 64, 55, 55), %20 : Handle = CppOp[ConvForward](%1, %2, %3), uses = [[%21.i0], []];
%22 : Double(10, 64, 55, 55), %23 : Handle = ^Threshold(0, 0, True)(%19), uses = [[%24.i0], []];
%25 : Double(10, 64, 27, 27), %26 : Long(10, 64, 27, 27), %27 : Handle = ^MaxPool2d(3, 2, 0, 1, False)(%22), uses = [[%28.i0], [], []];
%29 : Double(10, 192, 27, 27), %30 : Handle = CppOp[ConvForward](%25, %4, %5), uses = [[%31.i0], []];
%32 : Double(10, 192, 27, 27), %33 : Handle = ^Threshold(0, 0, True)(%29), uses = [[%34.i0], []];
%35 : Double(10, 192, 13, 13), %36 : Long(10, 192, 13, 13), %37 : Handle = ^MaxPool2d(3, 2, 0, 1, False)(%32), uses = [[%38.i0], [], []];
%39 : Double(10, 384, 13, 13), %40 : Handle = CppOp[ConvForward](%35, %6, %7), uses = [[%41.i0], []];
%42 : Double(10, 384, 13, 13), %43 : Handle = ^Threshold(0, 0, True)(%39), uses = [[%44.i0], []];
%45 : Double(10, 256, 13, 13), %46 : Handle = CppOp[ConvForward](%42, %8, %9), uses = [[%47.i0], []];
%48 : Double(10, 256, 13, 13), %49 : Handle = ^Threshold(0, 0, True)(%45), uses = [[%50.i0], []];
%51 : Double(10, 256, 13, 13), %52 : Handle = CppOp[ConvForward](%48, %10, %11), uses = [[%53.i0], []];
%54 : Double(10, 256, 13, 13), %55 : Handle = ^Threshold(0, 0, True)(%51), uses = [[%56.i0], []];
%57 : Double(10, 256, 6, 6), %58 : Long(10, 256, 6, 6), %59 : Handle = ^MaxPool2d(3, 2, 0, 1, False)(%54), uses = [[%60.i0], [], []];
%61 : Double(10, 9216), %62 : Handle = ^View((10, 9216))(%57), uses = [[%63.i0], []];
%64 : Double(10, 9216), %65 : Handle = ^Dropout(0.5, True, False)(%61), uses = [[%68.i1], []];
%67 : Double(9216!, 4096!) = ^Transpose(0, 1)(%12), uses = [[%68.i2]];
%69 : Double(10, 4096), %70 : Handle = ^Addmm(1, 1, False)(%13, %64, %67), uses = [[%71.i0], []];
%72 : Double(10, 4096), %73 : Handle = ^Threshold(0, 0, True)(%69), uses = [[%74.i0], []];
%75 : Double(10, 4096), %76 : Handle = ^Dropout(0.5, True, False)(%72), uses = [[%79.i1], []];
%78 : Double(4096!, 4096!) = ^Transpose(0, 1)(%14), uses = [[%79.i2]];
%80 : Double(10, 4096), %81 : Handle = ^Addmm(1, 1, False)(%15, %75, %78), uses = [[%82.i0], []];
%83 : Double(10, 4096), %84 : Handle = ^Threshold(0, 0, True)(%80), uses = [[%87.i1], []];
%86 : Double(4096!, 1000!) = ^Transpose(0, 1)(%16), uses = [[%87.i2]];
%88 : Double(10, 1000), %89 : Handle = ^Addmm(1, 1, False)(%17, %83, %86), uses = [[%0.i0], []];
return (%88);
%22 : Double(10, 64, 55, 55) = threshold[threshold={0}, value={0}, inplace=1](%19), uses = [[%23.i0]];
%24 : Double(10, 64, 27, 27), %25 : Long(10, 64, 27, 27) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%22), uses = [[%26.i0], []];
%27 : Double(10, 192, 27, 27), %28 : Handle = CppOp[ConvForward](%24, %4, %5), uses = [[%29.i0], []];
%30 : Double(10, 192, 27, 27) = threshold[threshold={0}, value={0}, inplace=1](%27), uses = [[%31.i0]];
%32 : Double(10, 192, 13, 13), %33 : Long(10, 192, 13, 13) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%30), uses = [[%34.i0], []];
%35 : Double(10, 384, 13, 13), %36 : Handle = CppOp[ConvForward](%32, %6, %7), uses = [[%37.i0], []];
%38 : Double(10, 384, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%35), uses = [[%39.i0]];
%40 : Double(10, 256, 13, 13), %41 : Handle = CppOp[ConvForward](%38, %8, %9), uses = [[%42.i0], []];
%43 : Double(10, 256, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%40), uses = [[%44.i0]];
%45 : Double(10, 256, 13, 13), %46 : Handle = CppOp[ConvForward](%43, %10, %11), uses = [[%47.i0], []];
%48 : Double(10, 256, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%45), uses = [[%49.i0]];
%50 : Double(10, 256, 6, 6), %51 : Long(10, 256, 6, 6) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%48), uses = [[%52.i0], []];
%53 : Double(10, 9216) = view[size=[10, 9216]](%50), uses = [[%54.i0]];
%55 : Double(10, 9216), %56 : Handle = ^Dropout(0.5, True, False)(%53), uses = [[%61.i1], []];
%58 : Double(9216!, 4096!) = t(%12), uses = [[%61.i2]];
%60 : Double(10!, 4096) = expand[size=[10, 4096]](%13), uses = [[%61.i0]];
%62 : Double(10, 4096) = addmm[beta={1}, alpha={1}](%60, %55, %58), uses = [[%63.i0]];
%64 : Double(10, 4096) = threshold[threshold={0}, value={0}, inplace=1](%62), uses = [[%65.i0]];
%66 : Double(10, 4096), %67 : Handle = ^Dropout(0.5, True, False)(%64), uses = [[%72.i1], []];
%69 : Double(4096!, 4096!) = t(%14), uses = [[%72.i2]];
%71 : Double(10!, 4096) = expand[size=[10, 4096]](%15), uses = [[%72.i0]];
%73 : Double(10, 4096) = addmm[beta={1}, alpha={1}](%71, %66, %69), uses = [[%74.i0]];
%75 : Double(10, 4096) = threshold[threshold={0}, value={0}, inplace=1](%73), uses = [[%80.i1]];
%77 : Double(4096!, 1000!) = t(%16), uses = [[%80.i2]];
%79 : Double(10!, 1000) = expand[size=[10, 1000]](%17), uses = [[%80.i0]];
%81 : Double(10, 1000) = addmm[beta={1}, alpha={1}](%79, %75, %77), uses = [[%0.i0]];
return (%81);
}
@@ -3,6 +3,6 @@ graph(%1 : Double(10, 10)
%4 : Double(10, 10!)) {
%3 : Double(10, 10) = ^MyFn()(%1), uses = [[%0.i0, %5.i0]];
---------------- stage 1 ----------------
%6 : Double(10, 10) = ^Mul()(%3, %4), uses = [[%0.i1]];
%6 : Double(10, 10) = mul(%3, %4), uses = [[%0.i1]];
return (%3, %6);
}

@@ -5,19 +5,19 @@ graph(%1 : Double(2, 2)
-------- stage 2 --------
%14 : Double(2, 2!)
%15 : Double(2, 2)) {
%4 : Double(2, 2) = ^MulConstant(2)(%2), uses = [[%5.i0, %10.i1, %16.i1]];
%6 : Double(2, 2) = ^Mul()(%4, %1), uses = [[%0.i0]];
%4 : Double(2, 2) = mul[other={2}](%2), uses = [[%5.i0, %10.i1, %16.i1]];
%6 : Double(2, 2) = mul(%4, %1), uses = [[%0.i0]];
---------------- stage 1 ----------------
%9 : Double(2, 2) = ^Mul()(%7, %1), uses = [[%12.i0]];
%11 : Double(2, 2) = ^Mul()(%7, %4), uses = [[%0.i1]];
%13 : Double(2, 2) = ^MulConstant(2)(%9), uses = [[%0.i2]];
%9 : Double(2, 2) = mul(%7, %1), uses = [[%12.i0]];
%11 : Double(2, 2) = mul(%7, %4), uses = [[%0.i1]];
%13 : Double(2, 2) = mul[other={2}](%9), uses = [[%0.i2]];
---------------- stage 2 ----------------
%17 : Double(2, 2) = ^Mul()(%14, %4), uses = [[%28.i0]];
%19 : Double(2, 2) = ^Mul()(%14, %7), uses = [[%22.i0]];
%21 : Double(2, 2) = ^MulConstant(2)(%15), uses = [[%24.i0, %26.i0]];
%23 : Double(2, 2) = ^MulConstant(2)(%19), uses = [[%0.i5]];
%25 : Double(2, 2) = ^Mul()(%21, %1), uses = [[%28.i1]];
%27 : Double(2, 2) = ^Mul()(%21, %7), uses = [[%0.i4]];
%17 : Double(2, 2) = mul(%14, %4), uses = [[%28.i0]];
%19 : Double(2, 2) = mul(%14, %7), uses = [[%22.i0]];
%21 : Double(2, 2) = mul[other={2}](%15), uses = [[%24.i0, %26.i0]];
%23 : Double(2, 2) = mul[other={2}](%19), uses = [[%0.i5]];
%25 : Double(2, 2) = mul(%21, %1), uses = [[%28.i1]];
%27 : Double(2, 2) = mul(%21, %7), uses = [[%0.i4]];
%29 : Double(2, 2) = CppOp[N5torch8autograd3AddE](%17, %25), uses = [[%0.i3]];
return (%6, %11, %13, %29, %27, %23);
}
@@ -1,9 +1,10 @@
graph(%1 : Double(3, 3)
%2 : Double(3, 3)
-------- stage 1 --------
%6 : Double(3, 3)) {
%4 : Double(3, 3), %5 : Handle = ^Cross()(%1, %2), uses = [[%0.i0], [%7.i1]];
%5 : Double(3, 3)) {
%4 : Double(3, 3) = cross[dim=-1](%1, %2), uses = [[%0.i0]];
---------------- stage 1 ----------------
%17 : Double(3, 3), %18 : Double(3, 3), %19 : Handle = CppOp[N5torch8autograd4EvalE](%6, %5), uses = [[%0.i1], [%0.i2], []];
return (%4, %17, %18);
%7 : Double(3, 3) = cross[dim=-1](%2, %5), uses = [[%0.i1]];
%9 : Double(3, 3) = cross[dim=-1](%5, %1), uses = [[%0.i2]];
return (%4, %7, %9);
}

@@ -1,10 +1,10 @@
graph(%1 : Double(2)
%2 : Double(2)) {
%3 : Double(2) = Add(%1, %2), uses = [%5.i0, %5.i1, %7.i1];
%5 : Double(2) = Mul(%3, %3), uses = [%7.i0];
%7 : Double(2) = Mul(%5, %3), uses = [%8.i0, %16.i0];
%8 : Double(2) = Tanh(%7), uses = [%10.i0, %10.i1];
%10 : Double(2) = Add(%8, %8), uses = [%16.i1];
%16 : Double(2) = Add(%7, %10), uses = [%0.i0];
return (%16);
%4 : Double(2) = add[alpha={1}](%1, %2), uses = [[%7.i0, %7.i1, %11.i1]];
%8 : Double(2) = mul(%4, %4), uses = [[%11.i0]];
%12 : Double(2) = mul(%8, %4), uses = [[%13.i0, %29.i0]];
%14 : Double(2) = tanh(%12), uses = [[%17.i0, %17.i1]];
%18 : Double(2) = add[alpha={1}](%14, %14), uses = [[%29.i1]];
%30 : Double(2) = add[alpha={1}](%12, %18), uses = [[%0.i0]];
return (%30);
}
@@ -1,7 +0,0 @@
graph(%1 : Double(4, 4)
%2 : Double(4, 4)) {
%3 : Double(4, 4) = Add(%1, %2), uses = [%4.i0];
%5 : Double(4!, 2), %6 : Double(4!, 2) = Split[split=[2, 2], axis=1](%3), uses = [[%7.i0], [%7.i1]];
%7 : Double(4, 2) = Mul(%5, %6), uses = [%0.i0];
return (%7);
}

@@ -1,6 +1,6 @@
graph(%1 : Double(1)) {
%3 : Double(1), %4 : Handle = ^Clone()(%1), uses = [[%5.i0], []];
%6 : Double(1) = ^AddConstant(2, True)(%3), uses = [[%7.i0]];
%8 : Double(1) = ^AddConstant(3, True)(%6), uses = [[%0.i0]];
return (%8);
%3 : Double(1) = clone(%1), uses = [[%4.i0]];
%5 : Double(1) = add[other={2}, alpha={1}](%3), uses = [[%6.i0]];
%7 : Double(1) = add[other={3}, alpha={1}](%5), uses = [[%0.i0]];
return (%7);
}
@@ -1,9 +1,9 @@
graph(%1 : UNKNOWN_TYPE
%2 : UNKNOWN_TYPE) {
%4 : Double(1) = Add[note=from_pyop, some_value=1](%1, %2), uses = [[%5.i1]];
%6 : Double(1) = Mul[note=from_pyop, some_value=0](%1, %4), uses = [[%7.i0]];
%8 : Double(1) = Tanh[note=from_pyop, some_value=0](%6), uses = [[%9.i0]];
%10 : Double(1) = Sigmoid[note=from_pyop, some_value=0](%8), uses = [[%0.i0]];
%4 : Double(1) = add[alpha={1}](%1, %2), uses = [[%5.i1]];
%6 : Double(1) = mul(%1, %4), uses = [[%7.i0]];
%8 : Double(1) = tanh(%6), uses = [[%9.i0]];
%10 : Double(1) = sigmoid(%8), uses = [[%0.i0]];
%11 : UNKNOWN_TYPE = TensorTest[a= 1 1 1 1 [ CPUDoubleTensor{2,2} ]](), uses = [];
return (%10);
}

@@ -1,8 +1,8 @@
graph(%1 : Double(1)
%2 : Double(1)) {
%3 : Double(1) = Add(%1, %2), uses = [%4.i1];
%4 : Double(1) = Mul(%1, %3), uses = [%5.i0];
%5 : Double(1) = Tanh(%4), uses = [%6.i0];
%6 : Double(1) = Sigmoid(%5), uses = [%0.i0];
return (%6);
%4 : Double(1) = add[alpha={1}](%1, %2), uses = [[%5.i1]];
%6 : Double(1) = mul(%1, %4), uses = [[%7.i0]];
%8 : Double(1) = tanh(%6), uses = [[%9.i0]];
%10 : Double(1) = sigmoid(%8), uses = [[%0.i0]];
return (%10);
}
@@ -33,7 +33,6 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
return hy, cy


@unittest.skip("JIT tests temporarily broken")
class TestJit(TestCase):
maxDiff = None

@@ -45,13 +44,10 @@ class TestJit(TestCase):
return torch.sigmoid(torch.tanh(x * (x + y)))

trace, z = torch.jit.trace(f, (x, y), nderivs=0)

torch._C._jit_pass_lint(trace)
torch._C._jit_pass_onnx(trace)
torch._C._jit_pass_lint(trace)

self.assertExpected(str(trace))

@unittest.skip("Fuser is broken")
@unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
def test_lstm_fusion(self):
input = Variable(torch.randn(3, 10).cuda())

@@ -61,12 +57,11 @@ class TestJit(TestCase):

trace, _ = torch.jit.trace(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_onnx(trace)
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_fuse(trace)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))

@unittest.skip("Fuser is broken")
@unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
def test_run_lstm_fusion(self):
input = Variable(torch.randn(3, 10).cuda())
@@ -80,6 +75,7 @@ class TestJit(TestCase):
z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters(), _assert_compiled=True)
self.assertEqual(z, z2)

@unittest.skip("Fuser is broken")
@unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
def test_fusion_distribute(self):
def f(x, y):

@@ -90,9 +86,6 @@ class TestJit(TestCase):
trace, _ = torch.jit.trace(f, (x, y), nderivs=0)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace), 'raw')
torch._C._jit_pass_onnx(trace)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace), 'onnx')
torch._C._jit_pass_fuse(trace)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))

@@ -107,8 +100,6 @@ class TestJit(TestCase):
z = (x + y) * (x + y) * (x + y) + t
torch._C._tracer_exit((z,))
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_onnx(trace)
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_cse(trace)

self.assertExpected(str(trace))
@@ -280,19 +271,38 @@ class TestJit(TestCase):
self.assertExpected(str(trace))

def test_inplace_flags(self):
class InplaceFn(Function):
@staticmethod
def forward(ctx, x):
ctx.mark_dirty(x)
return x.add_(1)

@staticmethod
def backward(ctx, go):
return go

class RegularFn(Function):
@staticmethod
def forward(ctx, x):
return x.add(1)

@staticmethod
def backward(ctx, go):
return go

x = Variable(torch.Tensor([0]), requires_grad=True)
trace = torch._C._tracer_enter((x,), 0)
y = x + 2
y.add_(2)
y.mul_(4)
y = y * 2
y = RegularFn.apply(x)
y = InplaceFn.apply(y)
y = InplaceFn.apply(y)
y = RegularFn.apply(y)
torch._C._tracer_exit((y,))
ops = [n for n in trace.graph().nodes() if n.kind() != 'Select']
for op in ops:
self.assertTrue(op.hasAttribute('__inplace'))
self.assertTrue(op.hasAttribute('inplace'))
inplace_flags = [False, True, True, False]
for op, is_inplace in zip(ops, inplace_flags):
self.assertEqual(op.i('__inplace'), is_inplace)
self.assertEqual(op.i('inplace'), is_inplace)

def test_inplace_check(self):
class MyInplaceFn(Function):

@@ -548,7 +558,6 @@ class TestJit(TestCase):
assert(n_.i("some_value") == len(node.scalar_args()))
else:
n_ = g2.createClone(node, lambda x: g_to_g2[x])
assert(n_.kindOf("Offset") == "i")

g_to_g2[node] = g2.appendNode(n_)
@@ -426,15 +426,12 @@ class Variable(_C._VariableBase):
def bernoulli(self):
return Bernoulli.apply(self)

def __add__(self, other):
return self.add(other)
__radd__ = __add__
__radd__ = __add__ = _C._VariableBase.add

def __iadd__(self, other):
return self.add_(other)

def __sub__(self, other):
return self.sub(other)
__sub__ = _C._VariableBase.sub

def __isub__(self, other):
return self.sub_(other)

@@ -442,9 +439,7 @@ class Variable(_C._VariableBase):
def __rsub__(self, other):
return -self + other

def __mul__(self, other):
return self.mul(other)
__rmul__ = __mul__
__rmul__ = __mul__ = _C._VariableBase.mul

def __imul__(self, other):
return self.mul_(other)

@@ -454,9 +449,7 @@ class Variable(_C._VariableBase):
return NotImplemented
return self.matmul(other)

def __div__(self, other):
return self.div(other)
__truediv__ = __div__
__truediv__ = __div__ = _C._VariableBase.div

def __rdiv__(self, other):
return self.reciprocal() * other

@@ -465,8 +458,7 @@ class Variable(_C._VariableBase):
def __idiv__(self, other):
return self.div_(other)

def __pow__(self, other):
return self.pow(other)
__pow__ = _C._VariableBase.pow

def __ipow__(self, other):
raise NotImplementedError("in-place pow not implemented")
@@ -11,6 +11,7 @@
#include "torch/csrc/autograd/python_engine.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/python_function.h"
#include "torch/csrc/jit/generated/aten_dispatch.h"
#ifdef WITH_CUDA
#include "torch/csrc/jit/fusion_compiler.h"
#endif

@@ -115,20 +116,28 @@ struct EmitNull : public Function {
};
};

// A hack that will let us implement some of the ops we care
// about before the major Python -> C++ Function migration
struct LambdaFunction : public Function {
LambdaFunction(const jit::TensorOp& op)
: LambdaFunction(op.num_inputs, op.op) {
this->name_ = op.name;
}

LambdaFunction(int num_inputs, std::function<variable_list(const variable_list&)> fn)
: fn(fn) {
: fn_(fn) {
this->is_executable = true;
this->num_inputs = num_inputs;
}

virtual variable_list apply(const variable_list& inputs) {
return fn(inputs);
virtual std::string name() override {
return name_.size() == 0 ? "LambdaFunction" : name_;
}

std::function<variable_list(const variable_list&)> fn;
virtual variable_list apply(const variable_list& inputs) override {
return fn_(inputs);
}

std::string name_;
std::function<variable_list(const variable_list&)> fn_;
};

// Wraps a PythonOp and dispatches calls to Functions implemented in Python
@@ -583,7 +592,7 @@ struct StageClosure {
IR_ELSEIF(Concat)
return std::make_shared<torch::autograd::Cat>(value->i(kaxis));
IR_ELSE()
throw std::runtime_error(std::string("unrecognized NodeKind: ") + symbolToString(node->kind()));
return std::make_shared<LambdaFunction>(getTensorOp(node));
IR_END()
}

@@ -671,7 +680,7 @@ struct StageClosure {
// Roots for a call to the engine. The list contains function in this order:
// [ apply input roots | prev stage input roots | constant factory ]
function_list roots;
std::vector<VariableFlags> var_flags;
std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>> var_flags;

// Output node
std::shared_ptr<Function> output;

@@ -703,15 +712,14 @@ struct MultiStageClosure {
};

AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc)
: AutogradClosure(desc, 0, {}) {}
: AutogradClosure(desc, 0) {}

// TODO: there's a lot processing involved in creating a new AutogradClosure instance,
// so it might be worth to keep a pool of unused instances (or at least their attrs)
// for all stages. We can't save saved_vars and saved_handles, but all callbacks
// can be made reusable.
AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage, FunctionFlags &&f)
: Function(std::move(f))
, desc(desc)
AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage)
: desc(desc)
, stage(stage) {
auto & stage_desc = desc->stages[stage];
@@ -777,10 +785,10 @@ variable_list AutogradClosure::apply(const variable_list& inputs) {

// Validate inputs
auto num_inputs = inputs.size();
if (num_inputs != stage_closure.var_flags.size())
if (num_inputs != stage_closure.var_flags.first.size())
throw std::runtime_error("AutogradClosure received an incorrect number of inputs");
for (std::size_t i = 0; i < num_inputs; ++i) {
auto & flags = stage_closure.var_flags[i];
auto & flags = stage_closure.var_flags.first[i];
if (!flags.verify(inputs[i]))
throw std::runtime_error("AutogradClosure received inputs with different flags");
}

@@ -797,16 +805,15 @@ variable_list AutogradClosure::apply(const variable_list& inputs) {
auto& engine = python::PythonEngine::getDefaultEngine();
engine.execute(stage_closure.roots, input_leaves, true, pre_callbacks, post_callbacks);

// See Note [Null-edge pruning]
auto relevant_inputs = filter(inputs, [](const Variable& var) { return var.defined() && var.requires_grad(); });
auto result = wrap_outputs(relevant_inputs, std::move(outputs), [this](FunctionFlags f) -> std::shared_ptr<Function> {
// Create the backward function lazily
auto make_grad_fn = [this]() -> std::shared_ptr<Function> {
if (this->stage == this->desc->stages.size() - 1) {
std::string msg = "JIT closure compiled only for ";
msg += std::to_string(this->stage);
msg += " backwards";
return std::make_shared<Error>(std::move(msg), std::move(f));
return std::make_shared<Error>(std::move(msg));
}
auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1, std::move(f)));
auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1));
// TODO: don't make a full copy of saved_* - copy only the things that bw needs
bw_fn->saved_vars = this->saved_vars;
bw_fn->saved_vars.insert(std::make_move_iterator(this->captured_vars.begin()),
@@ -824,7 +831,33 @@ variable_list AutogradClosure::apply(const variable_list& inputs) {
// was run, so it must have been executable).
bw_fn->is_executable = true;
return bw_fn;
});
};

// See Note [Null-edge pruning]
variable_list result;
auto num_outputs = outputs.size();
std::shared_ptr<Function> grad_fn;
JIT_ASSERT(outputs.size() == stage_closure.var_flags.second.size());
for (std::size_t i = 0; i < num_outputs; ++i) {
auto & flags = stage_closure.var_flags.second[i];
if (flags.requires_grad) {
if (!grad_fn) grad_fn = make_grad_fn();
result.push_back(make_variable(outputs[i], grad_fn));
} else {
result.push_back(make_variable(outputs[i], flags.requires_grad, flags.is_volatile));
}
}

// If we created grad_fn for any of the outputs, we also need to fill in next_functions
if (grad_fn) {
for (auto & input : inputs) {
if (!input.requires_grad()) continue;
grad_fn->next_functions.emplace_back(
input.grad_fn() ? input.grad_fn() : input.grad_accumulator(),
input.output_nr());
}
}

captured_vars.clear();
captured_handles.clear();
outputs.clear();
@@ -28,7 +28,7 @@ struct AutogradClosure : public Function {
virtual variable_list apply(const variable_list& inputs) override;

private:
AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage, FunctionFlags&& f);
AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage);

variable_list rewrapInputs(const variable_list& inputs);

@@ -711,7 +711,7 @@ static void _trace_create(PyObject* op_obj, THPFunction* bw_obj,
sel->inferTypeFrom(output.data());
tracer::setValueTrace(tracing_state, output, sel);
}
this_expr->i_(k__inplace, is_inplace);
this_expr->i_(kinplace, is_inplace);

// See definition in function.cpp.
THPObjectPtr passes_py_bool {PyObject_GetAttrString(op_obj, "is_traceable")};

@@ -64,7 +64,7 @@ _(perm) \
_(shape) \
_(axes) \
_(group) \
_(__inplace)
_(inplace)

enum BuiltinSymbol {
#define DEFINE_SYMBOL(s) \
@@ -177,7 +177,17 @@ void printAttributes(std::ostream & out, Node * n) {
case AttributeKind::t:
{
at::Tensor t = n->t(name);
if (t.numel() <= max_tensor_display_size) {
// 1-elem tensors are usually boxed scalars, so print them like it
if (t.numel() == 1) {
auto scalar = at::Scalar(t.view({})).local();
out << "{";
if (scalar.isFloatingPoint()) {
out << scalar.toDouble();
} else {
out << scalar.toLong();
}
out << "}";
} else if (t.numel() <= max_tensor_display_size) {
// TODO: This is awful code. Also it doesn't work on Windows.
std::ostringstream tensor_ss;
tensor_ss << t;

@@ -8,6 +8,16 @@

namespace torch { namespace jit {

namespace {

bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
return &lhs.type() == &rhs.type() && lhs.equal(rhs);
}

bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) {
if (lhs.size() != rhs.size()) return false;
return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
};

// Check whether two nodes have the same attributes in CSE.
@@ -24,6 +34,8 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {

auto lnames = lhs->attributeNames();
auto rnames = rhs->attributeNames();
std::sort(lnames.begin(), lnames.end());
std::sort(rnames.begin(), rnames.end());
if (lnames != rnames) return false;

for (auto name : lnames) {

@@ -40,8 +52,13 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
COMPARE_ATTRIBUTEVALUE(is)
COMPARE_ATTRIBUTEVALUE(s)
COMPARE_ATTRIBUTEVALUE(ss)
case AttributeKind::t:
if (!tensorEqual(lhs->t(name), rhs->t(name))) return false;
break;
case AttributeKind::ts:
if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false;
default:
// NB: Comparison of nodes with tensor(s) or graph(s) will return false.
// NB: Comparison of nodes with graph(s) will return false.
return false;
}

@@ -92,6 +109,8 @@ struct EqualNodeCSE {
}
};

} // anonymous namespace

// The function implements common subexpression elimination.
// Since the nodes are visited in topological order, one pass is enough.
void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
@@ -56,7 +56,7 @@ struct TraceEval : autograd::Eval {
setValueTrace(tracing_state, input, input_node);
input_node->inferTypeFrom(input.data());
}
tracing_state->var_flags.at(graph->stage()) = detail::getVarFlags(inputs);
tracing_state->var_flags.at(graph->stage()).first = detail::getVarFlags(inputs);
}

void exitTrace(const variable_list& inputs, const variable_list& outputs) {

@@ -200,7 +200,7 @@ inline std::shared_ptr<TracingState> enter(std::vector<TraceInput>&& trace_input
}
}
// TODO: this might not work with the way we handle buffers
state->var_flags[0] = detail::getVarFlags(inputs);
state->var_flags[0].first = detail::getVarFlags(inputs);
state->active = true;
state->inputs = inputs;
return state;

@@ -214,6 +214,7 @@ inline void _exit(const std::shared_ptr<TracingState>& state, const variable_lis
state->graph->registerOutput(getValueTrace(state, output, true));
}
state->active = false;
state->var_flags[state->graph->stage()].second = detail::getVarFlags(outputs);
}

// Marks a backwards subgraph that should be traced as the next stage.
@@ -64,7 +64,8 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
// TODO: Perhaps, turn this into an owning reference. The buffers
// are persistent, so this won't lead to a leak.
std::unordered_map<void*, Node*> buffer_map;
std::vector<std::vector<VariableFlags>> var_flags;
// A pair of (input_flags, output_flags) for each stage
std::vector<std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>>> var_flags;
std::vector<function_list> output_edges;

std::mutex mutex;

@@ -469,7 +469,6 @@ class TraceForKey(object):

# It's important to always run DCE, because backward can create a lot of unnecessary nodes
_run_pass(torch._C._jit_pass_dce, complete_trace)
_run_pass(torch._C._jit_pass_onnx, complete_trace)
_run_pass(_passes._check_inplace, complete_trace)
if self.optimize:
_run_pass(torch._C._jit_pass_fuse, complete_trace)

@@ -7,5 +7,5 @@ def _check_inplace(trace):
graph = trace.graph()
for node in graph.nodes():
if node.kind() == 'PythonOp':
if node.i('__inplace'):
if node.i('inplace'):
raise RuntimeError("inplace {} not supported in the JIT".format(node.pyname()))