Allow Tensor lists to show up in symbolic differentiable graphs. (#16784)

Summary:
This is done by flattening any tensor lists that appear as inputs/outputs of the
graph into the flat input/output list of the autograd graph.

This is less desirable than simply allowing IValues to exist in the
inputs/outputs of autograd::Function, but it is substantially less
intrusive.
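Concretely, the user-visible effect is that scripted functions whose inputs or outputs are `List[Tensor]` can now be covered by the autodiff path and propagate gradients. A minimal sketch of that behavior (it mirrors the `test_unbind` test added below; the typing syntax here is just for illustration):

```python
import torch
from typing import List

# A scripted function whose output is a List[Tensor]; with this change the
# differentiable subgraph can include such ops instead of falling back.
@torch.jit.script
def unbind_fn(x: torch.Tensor, dim: int) -> List[torch.Tensor]:
    return torch.unbind(x, dim)

x = torch.rand(2, 2, requires_grad=True)
outs = unbind_fn(x, 0)
outs_ref = torch.unbind(x, dim=0)

# Gradients flow through the (flattened) list outputs like plain tensors.
grad, = torch.autograd.grad(sum(t.sum() for t in outs), x)
grad_ref, = torch.autograd.grad(sum(t.sum() for t in outs_ref), x)
assert torch.allclose(grad, grad_ref)
```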

CaptureList describes, in a single class, the variables captured for backward.
UnpackInstructions describes how the flattened inputs to backward are re-packed into lists.
ailzhang
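A rough Python analogue of the bookkeeping these two helpers implement (the function names and the size-table encoding below are illustrative, not the actual C++ API):

```python
from typing import List, Union
import torch

Value = Union[torch.Tensor, List[torch.Tensor]]

def flatten(values: List[Value]):
    # Lists are flattened into individual tensors; a side table of sizes
    # records how to rebuild them (-1 marks a plain tensor).
    flat, sizes = [], []
    for v in values:
        if isinstance(v, list):
            sizes.append(len(v))
            flat.extend(v)
        else:
            sizes.append(-1)
            flat.append(v)
    return flat, sizes

def repack(flat: List[torch.Tensor], sizes: List[int]) -> List[Value]:
    # Replays the size table to restore the original Tensor / List[Tensor] shape.
    out, it = [], iter(flat)
    for s in sizes:
        out.append(next(it) if s < 0 else [next(it) for _ in range(s)])
    return out

vals = [torch.ones(2), [torch.zeros(1), torch.zeros(1)]]
flat, sizes = flatten(vals)   # -> 3 flat tensors, sizes == [-1, 2]
rebuilt = repack(flat, sizes)
assert isinstance(rebuilt[1], list) and len(rebuilt[1]) == 2
```

The real CaptureList additionally distinguishes tensor from non-tensor IValue captures and wraps tensors in SavedVariables, as the diff below shows.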

This PR is also part 2 of covering the Mask R-CNN & BERT AD formulas, following #16689.

Ops added in this PR:
```
cat
index
meshgrid
reshape
split
split_with_sizes
stack
unbind
```
I will also add a few perf numbers here.
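For a flavor of the formulas themselves: the new `cat` backward simply splits `grad_output` back by the inputs' sizes along `dim`. A quick eager-mode illustration of that identity (not part of this PR):

```python
import torch

dim = 0
tensors = [torch.randn(2, 3, requires_grad=True),
           torch.randn(4, 3, requires_grad=True)]
out = torch.cat(tensors, dim)
grad_output = torch.ones_like(out)

# Autograd's gradients for cat...
grads = torch.autograd.grad(out, tensors, grad_output)

# ...match grad_output split by each input's size along `dim`
# (torch.split with a list of sizes dispatches to split_with_sizes).
split_sizes = [t.size(dim) for t in tensors]
manual = torch.split(grad_output, split_sizes, dim)
for g, m in zip(grads, manual):
    assert torch.equal(g, m)
```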
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16784

Differential Revision: D14104063

Pulled By: ailzhang

fbshipit-source-id: 5ceadadfd67ccaac60c5fd6740786c5354e252b9
Zachary DeVito
2019-04-10 18:12:38 -07:00
committed by Facebook Github Bot
parent 612998f2ee
commit 1abbee0f8e
5 changed files with 418 additions and 88 deletions


@@ -181,11 +181,11 @@ def method_tests():
('view', (S,), (S,), '1d', (True,)),
('view', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
('view', (), (1,), 'scalar_to_1d', (True,)),
('reshape', (S, S, S), (S * S, S),),
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size'),
('reshape', (S,), (S,), '1d'),
('reshape', (), (dont_convert(()),), 'scalar_to_scalar'),
('reshape', (), (1,), 'scalar_to_1d'),
('reshape', (S, S, S), (S * S, S), '', (True,)),
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)),
('reshape', (S,), (S,), '1d', (True,)),
('reshape', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
('reshape', (), (1,), 'scalar_to_1d', (True,)),
('reshape_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
('reshape_as', (), (non_differentiable(torch.tensor(42.)),), 'scalar'),
('reshape_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
@@ -726,10 +726,15 @@ def method_tests():
('unsqueeze', (), (0,), 'scalar', (True,), [0]),
('chunk', (S, S, S), (2,), '', (True, 'prim::ConstantChunk')),
('chunk', (S, S, S), (S, 1), 'dim', (True, 'prim::ConstantChunk'), [1]),
('split', (S, S, S), (2,)),
('split', (S, S, S), (S, 1), 'dim', (), [1]),
('split', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'size_list'),
('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim', (), [1]),
('split', (S, S, S), (2,), '', (True,)),
('split', (S, S, S), (S, 1), 'dim', (True,), [1]),
('split', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'size_list',
(True, 'aten::split_with_sizes')),
('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim',
(True, 'aten::split_with_sizes'), [1]),
('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), '', (True,)),
('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3), 0],), 'size_0', (True, )),
('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'dim', (True, ), [1]),
('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', (), [0]),
('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', (), [0]),
('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', (), [0]),


@@ -253,6 +253,14 @@ def enable_cpu_fuser(fn):
return wrapper
# helper function to get sum of List[Tensor]
def _sum_of_list(tensorlist):
s = 0
for t in tensorlist:
s += t.sum()
return s
class JitTestCase(TestCase):
_do_cuda_memory_leak_check = True
_restored_warnings = False
@@ -3736,10 +3744,20 @@ a")
@torch.jit.script
def func2(x, y):
return torch.cat((x, x), y)
func2.debug_disable_autodiff_subgraph_inlining()
x = torch.rand([2, 2])
x = torch.rand([2, 2]).requires_grad_()
y = torch.tensor(1)
self.assertEqual(func2(x, y), torch.cat((x, x), y))
output = func2(x, y)
output_ref = torch.cat((x, x), y)
self.assertEqual(output, output_ref)
self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], [])
grad = torch.autograd.grad(output.sum(), x)
grad_ref = torch.autograd.grad(output_ref.sum(), x)
self.assertEqual(grad, grad_ref)
def test_cat_lifts(self):
@torch.jit.script
@@ -3757,6 +3775,73 @@ a")
for g in [foo.graph, foo2.graph, foo3.graph]:
FileCheck().check("int =").check("ListConstruct").check("aten::cat").run(str(g))
@unittest.skipIf(PY2, "Requires python 3")
def test_stack(self):
@torch.jit.script
def func(x):
return torch.stack((x, x), dim=1)
x = torch.rand(10, 10)
self.assertEqual(func(x), torch.stack((x, x), dim=1))
@torch.jit.script
def func2(x, y):
return torch.stack((x, y), dim=0)
func2.debug_disable_autodiff_subgraph_inlining()
x = torch.randn([2, 2]).requires_grad_()
y = torch.randn([2, 2]).requires_grad_()
output = func2(x, y)
output_ref = torch.stack((x, y), 0)
self.assertEqual(output, output_ref)
self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], [])
grads = torch.autograd.grad(output.sum(), (x, y))
grads_ref = torch.autograd.grad(output_ref.sum(), (x, y))
self.assertEqual(grads, grads_ref)
def test_unbind(self):
@torch.jit.script
def func(x, y):
# type: (Tensor, int) -> List[Tensor]
return torch.unbind(x, y)
func.debug_disable_autodiff_subgraph_inlining()
x = torch.rand([2, 2]).requires_grad_()
y = 0
outputs = func(x, y)
outputs_ref = torch.unbind(x, dim=y)
self.assertEqual(outputs, outputs_ref)
self.assertAutodiffNode(func.graph_for(x, y), True, ['aten::unbind'], [])
grad = torch.autograd.grad(_sum_of_list(outputs), x)
grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x)
self.assertEqual(grad, grad_ref)
def test_meshgrid(self):
@torch.jit.script
def func(a):
# type: (List[Tensor]) -> List[Tensor]
return torch.meshgrid(a)
func.debug_disable_autodiff_subgraph_inlining()
a = torch.tensor([1.0, 2, 3]).requires_grad_()
b = torch.tensor([1.0, 2, 3, 4]).requires_grad_()
inputs = [a, b]
outputs_ref = torch.meshgrid(inputs)
outputs = func(inputs)
self.assertEqual(outputs, outputs_ref)
self.assertAutodiffNode(func.graph_for(inputs), True, ['aten::meshgrid'], [])
grads = torch.autograd.grad(_sum_of_list(outputs), inputs)
grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs)
self.assertEqual(grads, grads_ref)
def test_list_literal(self):
def reassign():
x = [1]
@@ -11933,10 +12018,16 @@ EXCLUDE_SCRIPT = {
# chunk returns a list in scripting and we don't unpack the list,
# Thus it won't be replaced by ConstantChunk and run AD.
# It's explicitly checked in test_chunk_constant_script_ad
# Similarly for split: it's replaced by split_with_sizes in tracing,
# but there is no AD formula for aten::split(Tensor, int[], int)
# (the op registered in the JIT), so AD is not triggered in scripting.
EXCLUDE_SCRIPT_AD_CHECK = {
'test_chunk',
'test_chunk_dim',
'test_chunk_dim_neg0',
'test_split_size_list',
'test_split_size_list_dim',
'test_split_size_list_dim_neg0',
}
EXCLUDE_PYTHON_PRINT = {


@@ -75,53 +75,157 @@ struct ExecutionPlan {
std::shared_ptr<Graph> graph;
};
struct DifferentiableGraphBackward : public autograd::Function {
DifferentiableGraphBackward(GraphExecutor executor, size_t capture_size)
: executor(std::move(executor)) {
is_var_capture.reserve(capture_size);
var_captures.reserve(capture_size);
ivalue_captures.reserve(capture_size);
struct CaptureList {
CaptureList(size_t capture_size) {
capture_types_.reserve(capture_size);
var_captures_.reserve(capture_size); // var_captures_.size() might be greater than capture_size
ivalue_captures_.reserve(capture_size);
}
void captureTensor(const at::Tensor& tensor, bool is_output) {
var_captures_.emplace_back(Variable(tensor), is_output);
}
void capture(const IValue& val, bool is_output) {
if (val.isTensor()) {
capture_types_.emplace_back(CAPTURE_TENSOR);
captureTensor(val.toTensor(), is_output);
} else if (val.isTensorList()) {
// For TensorList, we have to flatten it to Tensors during saving and
// unflatten it back to TensorList when using it in backward apply().
// This is to avoid any implicit mutation of the TensorList that might
// happen between forward & backward.
capture_types_.emplace_back(CAPTURE_LIST);
const std::vector<at::Tensor>& tensors = val.toTensorListRef();
sizes_.push_back(tensors.size());
for (const at::Tensor& tensor: tensors) {
captureTensor(tensor, is_output);
}
} else {
capture_types_.emplace_back(CAPTURE_IVALUE);
ivalue_captures_.push_back(val);
}
}
size_t size() const {
return capture_types_.size();
}
void unpack(Stack & stack, const std::shared_ptr<autograd::Function>& saved_for) {
auto var_capture_it = var_captures_.begin();
auto ivalue_capture_it = ivalue_captures_.begin();
auto size_it = sizes_.begin();
for (Capture capture_type : capture_types_) {
switch(capture_type) {
case CAPTURE_TENSOR: {
stack.emplace_back(var_capture_it->unpack(saved_for));
++var_capture_it;
} break;
case CAPTURE_LIST: {
std::vector<at::Tensor> lst;
auto size = *size_it++;
for (size_t i = 0; i < size; i++) {
lst.emplace_back(var_capture_it->unpack(saved_for));
var_capture_it++;
}
stack.emplace_back(TensorList::create(std::move(lst)));
} break;
case CAPTURE_IVALUE: {
stack.push_back(*ivalue_capture_it++);
} break;
}
}
}
private:
enum Capture: uint8_t {
CAPTURE_TENSOR,
CAPTURE_LIST,
CAPTURE_IVALUE,
};
std::vector<Capture> capture_types_;
std::vector<autograd::SavedVariable> var_captures_;
std::vector<IValue> ivalue_captures_;
std::vector<size_t> sizes_;
};
// how do we turn a flattened list of tensors back into the ivalues that
// the DifferentiableGraphBackward expects
struct UnpackInstructions {
UnpackInstructions(size_t num_inputs) {
insts_.reserve(num_inputs);
}
void pushTensor() {
insts_.emplace_back(PUSH_TENSOR);
}
void pushTensorList(size_t size) {
insts_.emplace_back(PUSH_LIST);
sizes_.push_back(size);
}
void unpack(variable_list&& inputs, Stack& stack) {
auto input_it = std::make_move_iterator(inputs.begin());
auto sizes_it = sizes_.begin();
for(Inst inst : insts_) {
switch(inst) {
case PUSH_TENSOR: {
at::Tensor t = *input_it++;
stack.emplace_back(std::move(t));
} break;
case PUSH_LIST: {
std::vector<at::Tensor> lst(input_it, input_it + *sizes_it++);
stack.emplace_back(TensorList::create(std::move(lst)));
} break;
}
}
}
private:
enum Inst : uint8_t {
PUSH_TENSOR,
PUSH_LIST, // consumes one size
};
std::vector<Inst> insts_;
std::vector<size_t> sizes_;
};
struct DifferentiableGraphBackward : public autograd::Function {
DifferentiableGraphBackward(GraphExecutor executor, size_t input_size, size_t capture_size)
: executor(std::move(executor))
, captures_(capture_size)
, input_instructions_(input_size) {}
variable_list apply(variable_list&& inputs) override {
Stack stack;
stack.reserve(is_var_capture.size() + inputs.size());
stack.insert(
stack.end(),
std::make_move_iterator(inputs.begin()),
std::make_move_iterator(inputs.end()));
auto var_capture_it = var_captures.begin();
auto ivalue_capture_it = ivalue_captures.begin();
for (bool is_var : is_var_capture) {
if (is_var) {
stack.emplace_back(var_capture_it->unpack(this->shared_from_this()));
++var_capture_it;
} else {
stack.push_back(*ivalue_capture_it);
++ivalue_capture_it;
}
}
stack.reserve(captures_.size() + inputs.size());
input_instructions_.unpack(std::move(inputs), stack);
captures_.unpack(stack, shared_from_this());
executor.run(stack);
AT_ASSERT(stack.size() == num_outputs());
// NB: stack.size() == num_outputs() is not always true
// after we added TensorList support.
// Example: aten::stack(Tensor[] tensors, int) where
// tensors = [x, x]
// Here stack.size() == 1, since the backward graph produces a
// single TensorList IValue as its output.
// num_outputs() == 2, however, because it is the number of outputs of
// grad_fn (an autograd::Function). grad_fn's outputs are the
// grads with respect to the Tensors/Variables `x`, not the
// graph-input TensorList [x, x]. These two grads will
// be accumulated to x.grad later using autograd::InputBuffer.
variable_list outputs;
outputs.reserve(num_outputs());
for (size_t i = 0; i < num_outputs(); ++i) {
// Input grad can also be None even if it requires grad
// Example: `other` in expand_as(self, other)
if (should_compute_output(i) && !stack[i].isNone()) {
auto output = std::move(stack[i]).toTensor();
const auto& edge = next_edge(i);
if (output.defined()) {
outputs.emplace_back(std::move(output));
} else if (edge.is_valid()) {
outputs.emplace_back(
edge.function->input_metadata(edge.input_nr).zeros_like());
} else {
outputs.emplace_back();
size_t output_index = 0;
for (IValue& v : stack) {
if (v.isTensorList()) {
for(at::Tensor tensor : v.toTensorListRef()) {
produceOutput(output_index++, std::move(tensor), outputs);
}
} else if (v.isTensor()) {
produceOutput(output_index++, std::move(v).toTensor(), outputs);
} else {
// Input grad can also be None even if it requires grad
// Example: `other` in expand_as(self, other)
outputs.emplace_back();
}
}
@@ -129,24 +233,72 @@ struct DifferentiableGraphBackward : public autograd::Function {
}
void capture(const IValue& val, bool is_output) {
const bool is_tensor = val.isTensor();
is_var_capture.push_back(is_tensor);
if (is_tensor) {
var_captures.emplace_back(Variable(val.toTensor()), is_output);
captures_.capture(val, is_output);
}
void addOutputForTensor(const at::Tensor& tensor) {
auto v = Variable(tensor);
add_next_edge(
v.defined() ? v.gradient_edge() : autograd::Edge{});
}
void addOutputForIValue(const IValue& value) {
if (value.isTensorList()){
for(const at::Tensor& tensor : value.toTensorListRef()) {
addOutputForTensor(tensor);
}
} else {
ivalue_captures.push_back(val);
addOutputForTensor(value.toTensor());
}
}
void addInputVariable(Variable output) {
// NB: since our requires_grad setting is only a heuristic we might end
// up wanting to differentiate through integral tensors, which is
// generally a hard error in autograd.
if (at::isFloatingType(output.type().scalarType())) {
autograd::create_gradient_edge(output, shared_from_this());
output.set_requires_grad(true);
} else {
add_input_metadata(autograd::Function::undefined_input{});
}
}
void addInputIValue(const IValue& v) {
if (v.isTensorList()) {
const std::vector<at::Tensor>& tensors = v.toTensorListRef();
input_instructions_.pushTensorList(tensors.size());
for (const at::Tensor& tensor : tensors) {
addInputVariable(tensor);
}
} else if (v.isTensor()) {
input_instructions_.pushTensor();
addInputVariable(v.toTensor());
}
}
private:
void produceOutput(size_t i, at::Tensor output, variable_list& outputs) {
if (should_compute_output(i)) {
const auto& edge = next_edge(i);
if (output.defined()) {
outputs.emplace_back(std::move(output));
} else if (edge.is_valid()) {
outputs.emplace_back(
edge.function->input_metadata(edge.input_nr).zeros_like());
} else {
outputs.emplace_back();
}
} else {
outputs.emplace_back();
}
}
private:
friend struct ExecutionPlan;
GraphExecutor executor;
// INVARIANT: is_var_capture.size() == var_captures.size() +
// ivalue_captures.size()
std::vector<bool> is_var_capture;
std::vector<autograd::SavedVariable> var_captures;
std::vector<IValue> ivalue_captures;
CaptureList captures_;
UnpackInstructions input_instructions_;
};
// an optimized way of executing the subgraph computed directly on
@@ -166,6 +318,7 @@ struct DifferentiableGraphOp {
int operator()(Stack& stack) const {
auto grad_fn = std::make_shared<DifferentiableGraphBackward>(
grad_executor,
grad.df_input_vjps.size(),
grad.df_input_captured_inputs.size() +
grad.df_input_captured_outputs.size());
@@ -174,9 +327,7 @@ struct DifferentiableGraphOp {
// hook up the outputs of df to the gradient functions of the inputs that
// require gradients
for (auto idx : grad.df_output_vjps) {
auto v = Variable(inputs[idx].toTensor());
grad_fn->add_next_edge(
v.defined() ? v.gradient_edge() : autograd::Edge{});
grad_fn->addOutputForIValue(inputs[idx]);
}
captureInputs(*grad_fn, inputs);
}
@@ -194,24 +345,7 @@ struct DifferentiableGraphOp {
// this is currently intentionally not done here so we can get an idea of
// our perf before introducing overhead for correctness
for (auto idx : grad.df_input_vjps) {
// Note: we have to set this up in place, or we have to throw away and
// reallocate variables that were already created in wrapTensors. We
// should add an API for this.
// XXX: undefined tensor syntax in autograd
Variable output;
if (!outputs[idx].isNone()) {
output = outputs[idx].toTensor();
}
// NB: since our requires_grad setting is only a heuristic we might end
// up wanting to differentiate through integral tensors, which is
// generally a hard error in autograd.
if (at::isFloatingType(output.scalar_type())) {
autograd::create_gradient_edge(output, grad_fn);
output.set_requires_grad(true);
} else {
grad_fn->add_input_metadata(autograd::Function::undefined_input{});
}
grad_fn->addInputIValue(outputs[idx]);
}
captureOutputs(*grad_fn, outputs);
// drop the temporary outputs so that we return the same number of
@@ -225,6 +359,26 @@ struct DifferentiableGraphOp {
private:
friend GraphExecutor* detail::getGradExecutor(Operation& op);
void detach(at::Tensor& t) const {
if (t.defined()) {
t = autograd::as_variable_ref(t).detach();
}
}
void detach(IValue& v) const {
if(v.isTensor()) {
auto t = std::move(v).toTensor();
detach(t);
v = IValue{t};
} else if(v.isTensorList()) {
std::vector<at::Tensor> lst = v.toTensorListRef();
for(at::Tensor& t : lst) {
detach(t);
}
v = TensorList::create(std::move(lst));
}
}
void detachVariables(Stack& stack) const {
// It would be nice to use an ArrayRef here, but unfortunately those can
// only return const references, so we need to do a bunch of indexing
@@ -232,12 +386,7 @@ struct DifferentiableGraphOp {
const int64_t stack_size = stack.size();
const int64_t stack_offset = stack_size - num_inputs;
for (int64_t i = stack_offset; i < stack_size; ++i) {
auto& v = stack[i];
if (!v.isTensor())
continue;
auto t = std::move(v).toTensor();
v = IValue{t.defined() ? autograd::as_variable_ref(t).detach()
: std::move(t)};
detach(stack[i]);
}
}
// Capture (save) inputs that would be required to subsequently run backwards


@@ -18,7 +18,8 @@ void specializeAutogradZero(Graph& g) {
const auto& tp = input->type();
if (tp->isSubtypeOf(AutogradZeroTensorType::get())) {
state[input] = State::Zero;
} else if (tp->isSubtypeOf(TensorType::get())) {
} else if (tp->isSubtypeOf(TensorType::get())
|| tp->isSubtypeOf(ListType::ofTensors())) {
state[input] = State::Nonzero;
} else {
state[input] = State::Unknown;


@@ -6,7 +6,6 @@ namespace {
std::mutex lock;
const std::vector<std::string> functions = {
R"(
#### HELPER FUNCTIONS ###
#### PREFIX: AD_ ###
#### SCHEMA NOT SAVED IN CACHE ###
@@ -475,6 +474,91 @@ const std::vector<std::string> functions = {
return self * other, backward
def reshape(self,
shape: List[int]):
self_size = self.size()
def backward(grad_output):
grad_self = grad_output.reshape(self_size)
return grad_self, None
return torch.reshape(self, shape), backward
def split(self,
split_size: int,
dim: int):
def backward(grad_outputs: List[Tensor]):
grad_self = torch.cat(grad_outputs, dim)
return grad_self, None, None
return torch.split(self, split_size, dim), backward
def split_with_sizes(self,
split_sizes: List[int],
dim: int=0):
def backward(grad_outputs: List[Tensor]):
size = len(grad_outputs)
grad_self = torch.cat(grad_outputs, dim)
return grad_self, None, None
return torch.split_with_sizes(self, split_sizes, dim), backward
def stack(tensors: List[Tensor],
dim: int=0):
def backward(grad_output):
grad_tensors = torch.unbind(grad_output, dim)
return grad_tensors, None
return torch.stack(tensors, dim), backward
def unbind(self,
dim: int=0):
def backward(grad_outputs: List[Tensor]):
grad_self = torch.stack(grad_outputs, dim)
return grad_self, None
return torch.unbind(self, dim), backward
def cat(tensors: List[Tensor],
dim: int=0):
size = len(tensors)
split_sizes = [0] * size
for i in range(size):
if tensors[i].numel() > 0:
split_sizes[i] = tensors[i].size()[dim]
def backward(grad_output):
grad_tensors = torch.split_with_sizes(grad_output, split_sizes, dim)
return grad_tensors, None
return torch.cat(tensors, dim), backward
def index(self,
indices: List[Tensor]):
def backward(grad_output):
grad_self = torch.zeros_like(self).index_put_(indices, grad_output, True)
return grad_self, None
return torch.index(self, indices), backward
def meshgrid(tensors: List[Tensor]):
size = len(tensors)
sizes = [0] * size
for i in range(size):
if tensors[i].dim() != 0:
sizes[i] = tensors[i].size()[0]
def backward(grad_outputs: List[Tensor]):
grads_tensors = []
for i in range(size):
view_shape = [1] * size
if sizes[i] == 0:
view_shape[i] = 1
grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape(()))
else:
view_shape[i] = sizes[i]
grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape([sizes[i]]))
return grads_tensors
return torch.meshgrid(tensors), backward
def mv(self, vec):
def backward(grad_output):
return grad_output.ger(vec), self.t().mv(grad_output)