Re-enable and fix most JIT tests

This commit is contained in:
Adam Paszke
2017-10-25 21:26:13 +02:00
committed by Soumith Chintala
parent 61afb0d519
commit fa0f3cf98a
22 changed files with 190 additions and 129 deletions

View File

@ -16,28 +16,31 @@ graph(%1 : Double(10, 3, 224, 224)
%16 : Double(1000, 4096)
%17 : Double(1000)) {
%19 : Double(10, 64, 55, 55), %20 : Handle = CppOp[ConvForward](%1, %2, %3), uses = [[%21.i0], []];
%22 : Double(10, 64, 55, 55), %23 : Handle = ^Threshold(0, 0, True)(%19), uses = [[%24.i0], []];
%25 : Double(10, 64, 27, 27), %26 : Long(10, 64, 27, 27), %27 : Handle = ^MaxPool2d(3, 2, 0, 1, False)(%22), uses = [[%28.i0], [], []];
%29 : Double(10, 192, 27, 27), %30 : Handle = CppOp[ConvForward](%25, %4, %5), uses = [[%31.i0], []];
%32 : Double(10, 192, 27, 27), %33 : Handle = ^Threshold(0, 0, True)(%29), uses = [[%34.i0], []];
%35 : Double(10, 192, 13, 13), %36 : Long(10, 192, 13, 13), %37 : Handle = ^MaxPool2d(3, 2, 0, 1, False)(%32), uses = [[%38.i0], [], []];
%39 : Double(10, 384, 13, 13), %40 : Handle = CppOp[ConvForward](%35, %6, %7), uses = [[%41.i0], []];
%42 : Double(10, 384, 13, 13), %43 : Handle = ^Threshold(0, 0, True)(%39), uses = [[%44.i0], []];
%45 : Double(10, 256, 13, 13), %46 : Handle = CppOp[ConvForward](%42, %8, %9), uses = [[%47.i0], []];
%48 : Double(10, 256, 13, 13), %49 : Handle = ^Threshold(0, 0, True)(%45), uses = [[%50.i0], []];
%51 : Double(10, 256, 13, 13), %52 : Handle = CppOp[ConvForward](%48, %10, %11), uses = [[%53.i0], []];
%54 : Double(10, 256, 13, 13), %55 : Handle = ^Threshold(0, 0, True)(%51), uses = [[%56.i0], []];
%57 : Double(10, 256, 6, 6), %58 : Long(10, 256, 6, 6), %59 : Handle = ^MaxPool2d(3, 2, 0, 1, False)(%54), uses = [[%60.i0], [], []];
%61 : Double(10, 9216), %62 : Handle = ^View((10, 9216))(%57), uses = [[%63.i0], []];
%64 : Double(10, 9216), %65 : Handle = ^Dropout(0.5, True, False)(%61), uses = [[%68.i1], []];
%67 : Double(9216!, 4096!) = ^Transpose(0, 1)(%12), uses = [[%68.i2]];
%69 : Double(10, 4096), %70 : Handle = ^Addmm(1, 1, False)(%13, %64, %67), uses = [[%71.i0], []];
%72 : Double(10, 4096), %73 : Handle = ^Threshold(0, 0, True)(%69), uses = [[%74.i0], []];
%75 : Double(10, 4096), %76 : Handle = ^Dropout(0.5, True, False)(%72), uses = [[%79.i1], []];
%78 : Double(4096!, 4096!) = ^Transpose(0, 1)(%14), uses = [[%79.i2]];
%80 : Double(10, 4096), %81 : Handle = ^Addmm(1, 1, False)(%15, %75, %78), uses = [[%82.i0], []];
%83 : Double(10, 4096), %84 : Handle = ^Threshold(0, 0, True)(%80), uses = [[%87.i1], []];
%86 : Double(4096!, 1000!) = ^Transpose(0, 1)(%16), uses = [[%87.i2]];
%88 : Double(10, 1000), %89 : Handle = ^Addmm(1, 1, False)(%17, %83, %86), uses = [[%0.i0], []];
return (%88);
%22 : Double(10, 64, 55, 55) = threshold[threshold={0}, value={0}, inplace=1](%19), uses = [[%23.i0]];
%24 : Double(10, 64, 27, 27), %25 : Long(10, 64, 27, 27) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%22), uses = [[%26.i0], []];
%27 : Double(10, 192, 27, 27), %28 : Handle = CppOp[ConvForward](%24, %4, %5), uses = [[%29.i0], []];
%30 : Double(10, 192, 27, 27) = threshold[threshold={0}, value={0}, inplace=1](%27), uses = [[%31.i0]];
%32 : Double(10, 192, 13, 13), %33 : Long(10, 192, 13, 13) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%30), uses = [[%34.i0], []];
%35 : Double(10, 384, 13, 13), %36 : Handle = CppOp[ConvForward](%32, %6, %7), uses = [[%37.i0], []];
%38 : Double(10, 384, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%35), uses = [[%39.i0]];
%40 : Double(10, 256, 13, 13), %41 : Handle = CppOp[ConvForward](%38, %8, %9), uses = [[%42.i0], []];
%43 : Double(10, 256, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%40), uses = [[%44.i0]];
%45 : Double(10, 256, 13, 13), %46 : Handle = CppOp[ConvForward](%43, %10, %11), uses = [[%47.i0], []];
%48 : Double(10, 256, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%45), uses = [[%49.i0]];
%50 : Double(10, 256, 6, 6), %51 : Long(10, 256, 6, 6) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%48), uses = [[%52.i0], []];
%53 : Double(10, 9216) = view[size=[10, 9216]](%50), uses = [[%54.i0]];
%55 : Double(10, 9216), %56 : Handle = ^Dropout(0.5, True, False)(%53), uses = [[%61.i1], []];
%58 : Double(9216!, 4096!) = t(%12), uses = [[%61.i2]];
%60 : Double(10!, 4096) = expand[size=[10, 4096]](%13), uses = [[%61.i0]];
%62 : Double(10, 4096) = addmm[beta={1}, alpha={1}](%60, %55, %58), uses = [[%63.i0]];
%64 : Double(10, 4096) = threshold[threshold={0}, value={0}, inplace=1](%62), uses = [[%65.i0]];
%66 : Double(10, 4096), %67 : Handle = ^Dropout(0.5, True, False)(%64), uses = [[%72.i1], []];
%69 : Double(4096!, 4096!) = t(%14), uses = [[%72.i2]];
%71 : Double(10!, 4096) = expand[size=[10, 4096]](%15), uses = [[%72.i0]];
%73 : Double(10, 4096) = addmm[beta={1}, alpha={1}](%71, %66, %69), uses = [[%74.i0]];
%75 : Double(10, 4096) = threshold[threshold={0}, value={0}, inplace=1](%73), uses = [[%80.i1]];
%77 : Double(4096!, 1000!) = t(%16), uses = [[%80.i2]];
%79 : Double(10!, 1000) = expand[size=[10, 1000]](%17), uses = [[%80.i0]];
%81 : Double(10, 1000) = addmm[beta={1}, alpha={1}](%79, %75, %77), uses = [[%0.i0]];
return (%81);
}

View File

@ -3,6 +3,6 @@ graph(%1 : Double(10, 10)
%4 : Double(10, 10!)) {
%3 : Double(10, 10) = ^MyFn()(%1), uses = [[%0.i0, %5.i0]];
---------------- stage 1 ----------------
%6 : Double(10, 10) = ^Mul()(%3, %4), uses = [[%0.i1]];
%6 : Double(10, 10) = mul(%3, %4), uses = [[%0.i1]];
return (%3, %6);
}

View File

@ -5,19 +5,19 @@ graph(%1 : Double(2, 2)
-------- stage 2 --------
%14 : Double(2, 2!)
%15 : Double(2, 2)) {
%4 : Double(2, 2) = ^MulConstant(2)(%2), uses = [[%5.i0, %10.i1, %16.i1]];
%6 : Double(2, 2) = ^Mul()(%4, %1), uses = [[%0.i0]];
%4 : Double(2, 2) = mul[other={2}](%2), uses = [[%5.i0, %10.i1, %16.i1]];
%6 : Double(2, 2) = mul(%4, %1), uses = [[%0.i0]];
---------------- stage 1 ----------------
%9 : Double(2, 2) = ^Mul()(%7, %1), uses = [[%12.i0]];
%11 : Double(2, 2) = ^Mul()(%7, %4), uses = [[%0.i1]];
%13 : Double(2, 2) = ^MulConstant(2)(%9), uses = [[%0.i2]];
%9 : Double(2, 2) = mul(%7, %1), uses = [[%12.i0]];
%11 : Double(2, 2) = mul(%7, %4), uses = [[%0.i1]];
%13 : Double(2, 2) = mul[other={2}](%9), uses = [[%0.i2]];
---------------- stage 2 ----------------
%17 : Double(2, 2) = ^Mul()(%14, %4), uses = [[%28.i0]];
%19 : Double(2, 2) = ^Mul()(%14, %7), uses = [[%22.i0]];
%21 : Double(2, 2) = ^MulConstant(2)(%15), uses = [[%24.i0, %26.i0]];
%23 : Double(2, 2) = ^MulConstant(2)(%19), uses = [[%0.i5]];
%25 : Double(2, 2) = ^Mul()(%21, %1), uses = [[%28.i1]];
%27 : Double(2, 2) = ^Mul()(%21, %7), uses = [[%0.i4]];
%17 : Double(2, 2) = mul(%14, %4), uses = [[%28.i0]];
%19 : Double(2, 2) = mul(%14, %7), uses = [[%22.i0]];
%21 : Double(2, 2) = mul[other={2}](%15), uses = [[%24.i0, %26.i0]];
%23 : Double(2, 2) = mul[other={2}](%19), uses = [[%0.i5]];
%25 : Double(2, 2) = mul(%21, %1), uses = [[%28.i1]];
%27 : Double(2, 2) = mul(%21, %7), uses = [[%0.i4]];
%29 : Double(2, 2) = CppOp[N5torch8autograd3AddE](%17, %25), uses = [[%0.i3]];
return (%6, %11, %13, %29, %27, %23);
}

View File

@ -1,9 +1,10 @@
graph(%1 : Double(3, 3)
%2 : Double(3, 3)
-------- stage 1 --------
%6 : Double(3, 3)) {
%4 : Double(3, 3), %5 : Handle = ^Cross()(%1, %2), uses = [[%0.i0], [%7.i1]];
%5 : Double(3, 3)) {
%4 : Double(3, 3) = cross[dim=-1](%1, %2), uses = [[%0.i0]];
---------------- stage 1 ----------------
%17 : Double(3, 3), %18 : Double(3, 3), %19 : Handle = CppOp[N5torch8autograd4EvalE](%6, %5), uses = [[%0.i1], [%0.i2], []];
return (%4, %17, %18);
%7 : Double(3, 3) = cross[dim=-1](%2, %5), uses = [[%0.i1]];
%9 : Double(3, 3) = cross[dim=-1](%5, %1), uses = [[%0.i2]];
return (%4, %7, %9);
}

View File

@ -1,10 +1,10 @@
graph(%1 : Double(2)
%2 : Double(2)) {
%3 : Double(2) = Add(%1, %2), uses = [%5.i0, %5.i1, %7.i1];
%5 : Double(2) = Mul(%3, %3), uses = [%7.i0];
%7 : Double(2) = Mul(%5, %3), uses = [%8.i0, %16.i0];
%8 : Double(2) = Tanh(%7), uses = [%10.i0, %10.i1];
%10 : Double(2) = Add(%8, %8), uses = [%16.i1];
%16 : Double(2) = Add(%7, %10), uses = [%0.i0];
return (%16);
%4 : Double(2) = add[alpha={1}](%1, %2), uses = [[%7.i0, %7.i1, %11.i1]];
%8 : Double(2) = mul(%4, %4), uses = [[%11.i0]];
%12 : Double(2) = mul(%8, %4), uses = [[%13.i0, %29.i0]];
%14 : Double(2) = tanh(%12), uses = [[%17.i0, %17.i1]];
%18 : Double(2) = add[alpha={1}](%14, %14), uses = [[%29.i1]];
%30 : Double(2) = add[alpha={1}](%12, %18), uses = [[%0.i0]];
return (%30);
}

View File

@ -1,7 +0,0 @@
graph(%1 : Double(4, 4)
%2 : Double(4, 4)) {
%3 : Double(4, 4) = Add(%1, %2), uses = [%4.i0];
%5 : Double(4!, 2), %6 : Double(4!, 2) = Split[split=[2, 2], axis=1](%3), uses = [[%7.i0], [%7.i1]];
%7 : Double(4, 2) = Mul(%5, %6), uses = [%0.i0];
return (%7);
}

View File

@ -1,6 +1,6 @@
graph(%1 : Double(1)) {
%3 : Double(1), %4 : Handle = ^Clone()(%1), uses = [[%5.i0], []];
%6 : Double(1) = ^AddConstant(2, True)(%3), uses = [[%7.i0]];
%8 : Double(1) = ^AddConstant(3, True)(%6), uses = [[%0.i0]];
return (%8);
%3 : Double(1) = clone(%1), uses = [[%4.i0]];
%5 : Double(1) = add[other={2}, alpha={1}](%3), uses = [[%6.i0]];
%7 : Double(1) = add[other={3}, alpha={1}](%5), uses = [[%0.i0]];
return (%7);
}

View File

@ -1,9 +1,9 @@
graph(%1 : UNKNOWN_TYPE
%2 : UNKNOWN_TYPE) {
%4 : Double(1) = Add[note=from_pyop, some_value=1](%1, %2), uses = [[%5.i1]];
%6 : Double(1) = Mul[note=from_pyop, some_value=0](%1, %4), uses = [[%7.i0]];
%8 : Double(1) = Tanh[note=from_pyop, some_value=0](%6), uses = [[%9.i0]];
%10 : Double(1) = Sigmoid[note=from_pyop, some_value=0](%8), uses = [[%0.i0]];
%4 : Double(1) = add[alpha={1}](%1, %2), uses = [[%5.i1]];
%6 : Double(1) = mul(%1, %4), uses = [[%7.i0]];
%8 : Double(1) = tanh(%6), uses = [[%9.i0]];
%10 : Double(1) = sigmoid(%8), uses = [[%0.i0]];
%11 : UNKNOWN_TYPE = TensorTest[a= 1 1 1 1 [ CPUDoubleTensor{2,2} ]](), uses = [];
return (%10);
}

View File

@ -1,8 +1,8 @@
graph(%1 : Double(1)
%2 : Double(1)) {
%3 : Double(1) = Add(%1, %2), uses = [%4.i1];
%4 : Double(1) = Mul(%1, %3), uses = [%5.i0];
%5 : Double(1) = Tanh(%4), uses = [%6.i0];
%6 : Double(1) = Sigmoid(%5), uses = [%0.i0];
return (%6);
%4 : Double(1) = add[alpha={1}](%1, %2), uses = [[%5.i1]];
%6 : Double(1) = mul(%1, %4), uses = [[%7.i0]];
%8 : Double(1) = tanh(%6), uses = [[%9.i0]];
%10 : Double(1) = sigmoid(%8), uses = [[%0.i0]];
return (%10);
}

View File

@ -33,7 +33,6 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
return hy, cy
@unittest.skip("JIT tests temporarily broken")
class TestJit(TestCase):
maxDiff = None
@ -45,13 +44,10 @@ class TestJit(TestCase):
return torch.sigmoid(torch.tanh(x * (x + y)))
trace, z = torch.jit.trace(f, (x, y), nderivs=0)
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_onnx(trace)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))
@unittest.skip("Fuser is broken")
@unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
def test_lstm_fusion(self):
input = Variable(torch.randn(3, 10).cuda())
@ -61,12 +57,11 @@ class TestJit(TestCase):
trace, _ = torch.jit.trace(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_onnx(trace)
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_fuse(trace)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))
@unittest.skip("Fuser is broken")
@unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
def test_run_lstm_fusion(self):
input = Variable(torch.randn(3, 10).cuda())
@ -80,6 +75,7 @@ class TestJit(TestCase):
z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters(), _assert_compiled=True)
self.assertEqual(z, z2)
@unittest.skip("Fuser is broken")
@unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
def test_fusion_distribute(self):
def f(x, y):
@ -90,9 +86,6 @@ class TestJit(TestCase):
trace, _ = torch.jit.trace(f, (x, y), nderivs=0)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace), 'raw')
torch._C._jit_pass_onnx(trace)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace), 'onnx')
torch._C._jit_pass_fuse(trace)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))
@ -107,8 +100,6 @@ class TestJit(TestCase):
z = (x + y) * (x + y) * (x + y) + t
torch._C._tracer_exit((z,))
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_onnx(trace)
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_cse(trace)
self.assertExpected(str(trace))
@ -280,19 +271,38 @@ class TestJit(TestCase):
self.assertExpected(str(trace))
def test_inplace_flags(self):
class InplaceFn(Function):
@staticmethod
def forward(ctx, x):
ctx.mark_dirty(x)
return x.add_(1)
@staticmethod
def backward(ctx, go):
return go
class RegularFn(Function):
@staticmethod
def forward(ctx, x):
return x.add(1)
@staticmethod
def backward(ctx, go):
return go
x = Variable(torch.Tensor([0]), requires_grad=True)
trace = torch._C._tracer_enter((x,), 0)
y = x + 2
y.add_(2)
y.mul_(4)
y = y * 2
y = RegularFn.apply(x)
y = InplaceFn.apply(y)
y = InplaceFn.apply(y)
y = RegularFn.apply(y)
torch._C._tracer_exit((y,))
ops = [n for n in trace.graph().nodes() if n.kind() != 'Select']
for op in ops:
self.assertTrue(op.hasAttribute('__inplace'))
self.assertTrue(op.hasAttribute('inplace'))
inplace_flags = [False, True, True, False]
for op, is_inplace in zip(ops, inplace_flags):
self.assertEqual(op.i('__inplace'), is_inplace)
self.assertEqual(op.i('inplace'), is_inplace)
def test_inplace_check(self):
class MyInplaceFn(Function):
@ -548,7 +558,6 @@ class TestJit(TestCase):
assert(n_.i("some_value") == len(node.scalar_args()))
else:
n_ = g2.createClone(node, lambda x: g_to_g2[x])
assert(n_.kindOf("Offset") == "i")
g_to_g2[node] = g2.appendNode(n_)

View File

@ -426,15 +426,12 @@ class Variable(_C._VariableBase):
def bernoulli(self):
return Bernoulli.apply(self)
def __add__(self, other):
return self.add(other)
__radd__ = __add__
__radd__ = __add__ = _C._VariableBase.add
def __iadd__(self, other):
return self.add_(other)
def __sub__(self, other):
return self.sub(other)
__sub__ = _C._VariableBase.sub
def __isub__(self, other):
return self.sub_(other)
@ -442,9 +439,7 @@ class Variable(_C._VariableBase):
def __rsub__(self, other):
return -self + other
def __mul__(self, other):
return self.mul(other)
__rmul__ = __mul__
__rmul__ = __mul__ = _C._VariableBase.mul
def __imul__(self, other):
return self.mul_(other)
@ -454,9 +449,7 @@ class Variable(_C._VariableBase):
return NotImplemented
return self.matmul(other)
def __div__(self, other):
return self.div(other)
__truediv__ = __div__
__truediv__ = __div__ = _C._VariableBase.div
def __rdiv__(self, other):
return self.reciprocal() * other
@ -465,8 +458,7 @@ class Variable(_C._VariableBase):
def __idiv__(self, other):
return self.div_(other)
def __pow__(self, other):
return self.pow(other)
__pow__ = _C._VariableBase.pow
def __ipow__(self, other):
raise NotImplementedError("in-place pow not implemented")

View File

@ -11,6 +11,7 @@
#include "torch/csrc/autograd/python_engine.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/python_function.h"
#include "torch/csrc/jit/generated/aten_dispatch.h"
#ifdef WITH_CUDA
#include "torch/csrc/jit/fusion_compiler.h"
#endif
@ -115,20 +116,28 @@ struct EmitNull : public Function {
};
};
// A hack that will let us implement some of the ops we care
// about before the major Python -> C++ Function migration
struct LambdaFunction : public Function {
LambdaFunction(const jit::TensorOp& op)
: LambdaFunction(op.num_inputs, op.op) {
this->name_ = op.name;
}
LambdaFunction(int num_inputs, std::function<variable_list(const variable_list&)> fn)
: fn(fn) {
: fn_(fn) {
this->is_executable = true;
this->num_inputs = num_inputs;
}
virtual variable_list apply(const variable_list& inputs) {
return fn(inputs);
virtual std::string name() override {
return name_.size() == 0 ? "LambdaFunction" : name_;
}
std::function<variable_list(const variable_list&)> fn;
virtual variable_list apply(const variable_list& inputs) override {
return fn_(inputs);
}
std::string name_;
std::function<variable_list(const variable_list&)> fn_;
};
// Wraps a PythonOp and dispatches calls to Functions implemented in Python
@ -583,7 +592,7 @@ struct StageClosure {
IR_ELSEIF(Concat)
return std::make_shared<torch::autograd::Cat>(value->i(kaxis));
IR_ELSE()
throw std::runtime_error(std::string("unrecognized NodeKind: ") + symbolToString(node->kind()));
return std::make_shared<LambdaFunction>(getTensorOp(node));
IR_END()
}
@ -671,7 +680,7 @@ struct StageClosure {
// Roots for a call to the engine. The list contains function in this order:
// [ apply input roots | prev stage input roots | constant factory ]
function_list roots;
std::vector<VariableFlags> var_flags;
std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>> var_flags;
// Output node
std::shared_ptr<Function> output;
@ -703,15 +712,14 @@ struct MultiStageClosure {
};
AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc)
: AutogradClosure(desc, 0, {}) {}
: AutogradClosure(desc, 0) {}
// TODO: there's a lot processing involved in creating a new AutogradClosure instance,
// so it might be worth to keep a pool of unused instances (or at least their attrs)
// for all stages. We can't save saved_vars and saved_handles, but all callbacks
// can be made reusable.
AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage, FunctionFlags &&f)
: Function(std::move(f))
, desc(desc)
AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage)
: desc(desc)
, stage(stage) {
auto & stage_desc = desc->stages[stage];
@ -777,10 +785,10 @@ variable_list AutogradClosure::apply(const variable_list& inputs) {
// Validate inputs
auto num_inputs = inputs.size();
if (num_inputs != stage_closure.var_flags.size())
if (num_inputs != stage_closure.var_flags.first.size())
throw std::runtime_error("AutogradClosure received an incorrect number of inputs");
for (std::size_t i = 0; i < num_inputs; ++i) {
auto & flags = stage_closure.var_flags[i];
auto & flags = stage_closure.var_flags.first[i];
if (!flags.verify(inputs[i]))
throw std::runtime_error("AutogradClosure received inputs with different flags");
}
@ -797,16 +805,15 @@ variable_list AutogradClosure::apply(const variable_list& inputs) {
auto& engine = python::PythonEngine::getDefaultEngine();
engine.execute(stage_closure.roots, input_leaves, true, pre_callbacks, post_callbacks);
// See Note [Null-edge pruning]
auto relevant_inputs = filter(inputs, [](const Variable& var) { return var.defined() && var.requires_grad(); });
auto result = wrap_outputs(relevant_inputs, std::move(outputs), [this](FunctionFlags f) -> std::shared_ptr<Function> {
// Create the backward function lazily
auto make_grad_fn = [this]() -> std::shared_ptr<Function> {
if (this->stage == this->desc->stages.size() - 1) {
std::string msg = "JIT closure compiled only for ";
msg += std::to_string(this->stage);
msg += " backwards";
return std::make_shared<Error>(std::move(msg), std::move(f));
return std::make_shared<Error>(std::move(msg));
}
auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1, std::move(f)));
auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1));
// TODO: don't make a full copy of saved_* - copy only the things that bw needs
bw_fn->saved_vars = this->saved_vars;
bw_fn->saved_vars.insert(std::make_move_iterator(this->captured_vars.begin()),
@ -824,7 +831,33 @@ variable_list AutogradClosure::apply(const variable_list& inputs) {
// was run, so it must have been executable).
bw_fn->is_executable = true;
return bw_fn;
});
};
// See Note [Null-edge pruning]
variable_list result;
auto num_outputs = outputs.size();
std::shared_ptr<Function> grad_fn;
JIT_ASSERT(outputs.size() == stage_closure.var_flags.second.size());
for (std::size_t i = 0; i < num_outputs; ++i) {
auto & flags = stage_closure.var_flags.second[i];
if (flags.requires_grad) {
if (!grad_fn) grad_fn = make_grad_fn();
result.push_back(make_variable(outputs[i], grad_fn));
} else {
result.push_back(make_variable(outputs[i], flags.requires_grad, flags.is_volatile));
}
}
// If we created grad_fn for any of the outputs, we also need to fill in next_functions
if (grad_fn) {
for (auto & input : inputs) {
if (!input.requires_grad()) continue;
grad_fn->next_functions.emplace_back(
input.grad_fn() ? input.grad_fn() : input.grad_accumulator(),
input.output_nr());
}
}
captured_vars.clear();
captured_handles.clear();
outputs.clear();

View File

@ -28,7 +28,7 @@ struct AutogradClosure : public Function {
virtual variable_list apply(const variable_list& inputs) override;
private:
AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage, FunctionFlags&& f);
AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage);
variable_list rewrapInputs(const variable_list& inputs);

View File

@ -711,7 +711,7 @@ static void _trace_create(PyObject* op_obj, THPFunction* bw_obj,
sel->inferTypeFrom(output.data());
tracer::setValueTrace(tracing_state, output, sel);
}
this_expr->i_(k__inplace, is_inplace);
this_expr->i_(kinplace, is_inplace);
// See definition in function.cpp.
THPObjectPtr passes_py_bool {PyObject_GetAttrString(op_obj, "is_traceable")};

View File

@ -64,7 +64,7 @@ _(perm) \
_(shape) \
_(axes) \
_(group) \
_(__inplace)
_(inplace)
enum BuiltinSymbol {
#define DEFINE_SYMBOL(s) \

View File

@ -177,7 +177,17 @@ void printAttributes(std::ostream & out, Node * n) {
case AttributeKind::t:
{
at::Tensor t = n->t(name);
if (t.numel() <= max_tensor_display_size) {
// 1-elem tensors are usually boxed scalars, so print them like it
if (t.numel() == 1) {
auto scalar = at::Scalar(t.view({})).local();
out << "{";
if (scalar.isFloatingPoint()) {
out << scalar.toDouble();
} else {
out << scalar.toLong();
}
out << "}";
} else if (t.numel() <= max_tensor_display_size) {
// TODO: This is awful code. Also it doesn't work on Windows.
std::ostringstream tensor_ss;
tensor_ss << t;

View File

@ -8,6 +8,16 @@
namespace torch { namespace jit {
namespace {
bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
return &lhs.type() == &rhs.type() && lhs.equal(rhs);
}
bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) {
if (lhs.size() != rhs.size()) return false;
return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
};
// Check whether two nodes have the same attributes in CSE.
@ -24,6 +34,8 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
auto lnames = lhs->attributeNames();
auto rnames = rhs->attributeNames();
std::sort(lnames.begin(), lnames.end());
std::sort(rnames.begin(), rnames.end());
if (lnames != rnames) return false;
for (auto name : lnames) {
@ -40,8 +52,13 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
COMPARE_ATTRIBUTEVALUE(is)
COMPARE_ATTRIBUTEVALUE(s)
COMPARE_ATTRIBUTEVALUE(ss)
case AttributeKind::t:
if (!tensorEqual(lhs->t(name), rhs->t(name))) return false;
break;
case AttributeKind::ts:
if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false;
default:
// NB: Comparison of nodes with tensor(s) or graph(s) will return false.
// NB: Comparison of nodes with graph(s) will return false.
return false;
}
@ -92,6 +109,8 @@ struct EqualNodeCSE {
}
};
} // anonymous namespace
// The function implements common subexpression elimination.
// Since the nodes are visited in topological order, one pass is enough.
void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {

View File

@ -56,7 +56,7 @@ struct TraceEval : autograd::Eval {
setValueTrace(tracing_state, input, input_node);
input_node->inferTypeFrom(input.data());
}
tracing_state->var_flags.at(graph->stage()) = detail::getVarFlags(inputs);
tracing_state->var_flags.at(graph->stage()).first = detail::getVarFlags(inputs);
}
void exitTrace(const variable_list& inputs, const variable_list& outputs) {

View File

@ -200,7 +200,7 @@ inline std::shared_ptr<TracingState> enter(std::vector<TraceInput>&& trace_input
}
}
// TODO: this might not work with the way we handle buffers
state->var_flags[0] = detail::getVarFlags(inputs);
state->var_flags[0].first = detail::getVarFlags(inputs);
state->active = true;
state->inputs = inputs;
return state;
@ -214,6 +214,7 @@ inline void _exit(const std::shared_ptr<TracingState>& state, const variable_lis
state->graph->registerOutput(getValueTrace(state, output, true));
}
state->active = false;
state->var_flags[state->graph->stage()].second = detail::getVarFlags(outputs);
}
// Marks a backwards subgraph that should be traced as the next stage.

View File

@ -64,7 +64,8 @@ struct TracingState : public std::enable_shared_from_this<TracingState> {
// TODO: Perhaps, turn this into an owning reference. The buffers
// are persistent, so this won't lead to a leak.
std::unordered_map<void*, Node*> buffer_map;
std::vector<std::vector<VariableFlags>> var_flags;
// A pair of (input_flags, output_flags) for each stage
std::vector<std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>>> var_flags;
std::vector<function_list> output_edges;
std::mutex mutex;

View File

@ -469,7 +469,6 @@ class TraceForKey(object):
# It's important to always run DCE, because backward can create a lot of unnecessary nodes
_run_pass(torch._C._jit_pass_dce, complete_trace)
_run_pass(torch._C._jit_pass_onnx, complete_trace)
_run_pass(_passes._check_inplace, complete_trace)
if self.optimize:
_run_pass(torch._C._jit_pass_fuse, complete_trace)

View File

@ -7,5 +7,5 @@ def _check_inplace(trace):
graph = trace.graph()
for node in graph.nodes():
if node.kind() == 'PythonOp':
if node.i('__inplace'):
if node.i('inplace'):
raise RuntimeError("inplace {} not supported in the JIT".format(node.pyname()))