Allow passing dicts as trace inputs. (#18092)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18092

Previously, tracing required all inputs to be either tensors or tuples of
tensors. Now, we allow users to pass dicts as well.
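
A rough usage sketch of the new behavior (the module name below is made up; it mirrors the test_input_dict_flattens test added in this change): torch.jit.trace now accepts a dict as the example input, wraps it in a one-element tuple, and flattens it at the trace boundary via aten::values followed by prim::ListUnpack; the traced module can then be called with dicts of the same structure.

import torch

# Hypothetical example module; any module taking a dict of tensors works.
class AddFromDict(torch.nn.Module):
    def forward(self, d):
        return d['x'] + d['y']

example = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
# The dict itself may be passed as example_inputs; trace() wraps it in a tuple.
traced = torch.jit.trace(AddFromDict(), example)
# The traced module expects a dict with the same keys.
out = traced({'x': torch.ones(3, 4), 'y': torch.ones(3, 4)})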

Differential Revision: D14491795

fbshipit-source-id: 7a2df218e5d00f898d01fa5b9669f9d674280be3
Authored by Eric Faust on 2019-04-18 23:48:59 -07:00
committed by Facebook Github Bot
parent 9034b66f14
commit 593bb145ce
12 changed files with 274 additions and 39 deletions

View File

@@ -287,6 +287,18 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
}
}
return static_cast<TypePtr>(TupleType::create(elements));
} else if (t1->cast<DictType>() && t2->cast<DictType>()) {
auto dict1 = t1->cast<DictType>();
auto dict2 = t2->cast<DictType>();
auto unified_key = unifyTypes(dict1->getKeyType(), dict2->getKeyType());
auto unshaped_value1 = unshapedType(dict1->getValueType());
auto unshaped_value2 = unshapedType(dict2->getValueType());
auto unified_value = tryEitherIsTheSuperType(unshaped_value1, unshaped_value2);
if (!unified_key || !unified_value) {
return c10::nullopt;
}
return DictType::create(*unified_key, *unified_value);
}
return c10::nullopt;

View File

@@ -60,9 +60,17 @@ variable_list get_grad_outputs(const variable_list& vars) {
std::shared_ptr<Graph> trace(
const ADTestSpec& test,
const variable_list& vars_in) {
Stack input_vars = fmap<IValue>(vars_in);
std::vector<TypePtr> input_types;
input_types.reserve(input_vars.size());
for (auto i = 0; i < input_vars.size(); i++) {
input_types.push_back(TensorType::get());
}
auto input_typeptr = TupleType::create(std::move(input_types));
std::shared_ptr<tracer::TracingState> state;
Stack trace_stack_in;
std::tie(state, trace_stack_in) = tracer::enter(fmap<IValue>(vars_in));
std::tie(state, trace_stack_in) =
tracer::enter(tracer::TypedStack(input_vars, input_typeptr));
variable_list trace_vars_in = fmap(
trace_stack_in, [](const IValue& v) { return Variable(v.toTensor()); });
auto trace_vars_out = test(trace_vars_in);

View File

@@ -12,6 +12,7 @@ from contextlib import contextmanager
from itertools import product, chain
import torch.jit.frontend
from torch.autograd import Variable, Function
from torch.autograd.function import _nested_map
from torch.onnx import OperatorExportTypes
from torch._six import inf, PY2, builtins, StringIO
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
@@ -19,7 +20,7 @@ from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
freeze_rng_state, set_rng_seed, slowTest
from common_nn import module_tests, new_module_tests, criterion_tests
from textwrap import dedent
from functools import wraps
from functools import wraps, reduce
import os
import io
import sys
@@ -535,9 +536,24 @@ class JitTestCase(TestCase):
if input_tensors is None:
input_tensors = reference_tensors
def do_input_map(fn, input):
return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input)
def flatten_inputs(inputs):
def input_reduce(input, fn, acc):
if isinstance(input, torch.Tensor):
fn(input, acc)
elif isinstance(input, dict):
reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc)
else:
reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc)
return acc
return tuple(input_reduce(inputs, lambda t, acc: acc.append(t), []))
nograd_inputs = reference_tensors
if inputs_require_grads:
recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors)
flattened_recording_inputs = flatten_inputs(recording_inputs)
else:
recording_inputs = reference_tensors
@@ -558,12 +574,12 @@
# test single grad case
outputs = func(*recording_inputs)
if inputs_require_grads:
grads = torch.autograd.grad(allSum(outputs), recording_inputs,
grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs,
allow_unused=allow_unused)
outputs_ge = ge(*recording_inputs)
if inputs_require_grads:
grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs,
allow_unused=allow_unused)
self.assertEqual(outputs, outputs_ge)
if inputs_require_grads:
@@ -574,25 +590,25 @@
outputs = func(*recording_inputs)
l1 = allSum(outputs)
if inputs_require_grads:
grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True,
allow_unused=allow_unused)
if inputs_require_grads:
l2 = (allSum(grads) * l1)
grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)
grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused)
if inputs_require_grads:
recording_inputs = [Variable(t, requires_grad=True)
for t in reference_tensors]
recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors)
flattened_recording_inputs = flatten_inputs(recording_inputs)
outputs_ge = ge(*recording_inputs)
l1_ge = allSum(outputs_ge)
if inputs_require_grads:
grads_ge = torch.autograd.grad(
l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)
l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused)
if inputs_require_grads:
l2_ge = (allSum(grads_ge) * l1_ge)
grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)
grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused)
self.assertEqual(outputs, outputs_ge)
if inputs_require_grads:
@@ -1580,8 +1596,6 @@ graph(%x : Tensor,
self.checkTrace(fn, (torch.randn(2, 2),))
# TODO: implement
@unittest.expectedFailure
def test_input_flatten(self):
"""Check that inputs to traced functions are flattened"""
@@ -1592,6 +1606,66 @@ graph(%x : Tensor,
inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
self.checkTrace(fn, inputs)
def test_input_dict_empty(self):
def test(d):
pass
with self.assertRaises(RuntimeError):
self.checkTrace(test, {})
def test_input_dict_flattens(self):
class Test(torch.nn.Module):
def forward(self, d):
return d['x'] + d['y']
inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
module = torch.jit.trace(Test(), inputs)
FileCheck().check('aten::values').check('prim::ListUnpack').run(str(module.graph))
def test_input_dict_flattens_recursive(self):
class Test(torch.nn.Module):
def forward(self, d):
# Use both to avoid getting optimized away
a = d['x'][0]
b, c = d['y']
return a + b
inputs = {'x': (torch.rand(2, 2), torch.rand(2, 2)), 'y': (torch.ones(1, 1), torch.ones(2, 1))}
module = torch.jit.trace(Test(), inputs)
FileCheck().check('aten::values') \
.check('prim::ListUnpack') \
.check_count('prim::TupleUnpack', 2) \
.run(str(module.graph))
def test_input_dict_checkTrace_mut(self):
def test(d):
d['x'].tanh_()
return d['x']
inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
self.checkTrace(test, (inputs,), inputs_require_grads=False)
def test_input_dict_unify(self):
def test(d):
return d['int'], d['float']
inputs = {'int': torch.ones((2, 2), dtype=torch.int32),
'float': torch.ones((2, 2), dtype=torch.float32)}
self.checkTrace(test, (inputs,), inputs_require_grads=False)
def test_input_tuple_of_dicts(self):
def test(t):
d = t[0]
return d['x']['y']
inputs = {'x': {'y': torch.rand(2, 3)}}
self.checkTrace(test, ((inputs, inputs),), allow_unused=True)
def test_input_dict_of_dicts(self):
def test(d):
return d['x']['y']
nested_input = {'y': torch.rand(2, 3)}
unified_nested = {'y': torch.rand(3, 2)}
inputs = {'x': nested_input, 'force_unify': unified_nested}
self.checkTrace(test, (inputs,), allow_unused=True)
# TODO: adapt to a GraphExecutor test
@unittest.skip("Need to instrument GraphExecutors a bit more")
def test_flags(self):

View File

@@ -255,6 +255,8 @@ def _nested_map(condition, fn, condition_msg=None):
return None
elif isinstance(obj, (list, tuple)):
return type(obj)(_map(x) for x in obj)
elif isinstance(obj, dict):
return {x : _map(obj[x]) for x in obj}
else:
raise ValueError("Auto nesting doesn't know how to process "
"an input object of type " + torch.typename(obj) +
@@ -284,6 +286,11 @@ def _iter_filter(condition, allow_unknown=False, condition_msg=None,
for o in obj:
for var in _iter(o):
yield var
elif isinstance(obj, dict):
# We only accept primitive key types, so we needn't inspect them
for o in obj.values():
for var in _iter(o):
yield var
elif allow_unknown:
yield obj
else:

View File

@@ -35,24 +35,85 @@ namespace jit {
// that is confusing to display to the end user since it always reports
// locations in libtorch code rather than user code.
inline IValue toIValue(py::handle input) {
using tracer::TypedStack;
struct TypedIValue : public std::pair<IValue, TypePtr> {
using pair::pair;
IValue& ivalue() {
return this->first;
}
TypePtr& type() {
return this->second;
}
};
inline TypedIValue toDictKeyIValue(py::handle key) {
if (py::isinstance<py::str>(key)) {
return TypedIValue(ConstantString::create(py::cast<std::string>(key)),
StringType::create());
} else if (PyLong_Check(key.ptr())) {
return TypedIValue(py::cast<int64_t>(key), IntType::create());
} else if (PyFloat_Check(key.ptr())) {
return TypedIValue(py::cast<double>(key), FloatType::create());
} else {
AT_ERROR("Dictionary inputs may only have string, int, or float keys");
}
}
inline TypedIValue toTypedIValue(py::handle input) {
if (THPVariable_Check(input.ptr())) {
auto ten = py::cast<at::Tensor>(input);
if (ten.is_sparse()) {
AT_ERROR("sparse tensors not supported");
}
return ten;
return TypedIValue(ten, CompleteTensorType::create(ten));
} else if (six::isTuple(input)) {
py::tuple input_tuple = py::cast<py::tuple>(input);
Stack s;
std::vector<TypePtr> t;
s.reserve(input_tuple.size());
t.reserve(input_tuple.size());
for (py::handle elem : input_tuple) {
s.push_back(toIValue(elem));
auto info = toTypedIValue(elem);
s.push_back(info.first);
t.push_back(info.second);
}
return Tuple::create(s);
return TypedIValue(Tuple::create(s), TupleType::create(t));
} else if (PyDict_Check(input.ptr())) {
// Check to make sure we can generate useful input/output types
auto dict = py::cast<py::dict>(input);
at::ivalue::UnorderedMap elems;
size_t len = py::len(dict);
if (!len) {
AT_ERROR("Dictionary inputs must have entries.");
}
elems.reserve(len);
TypePtr keyType = nullptr;
TypePtr valueType = nullptr;
for (auto entry : dict) {
auto keyInfo = toDictKeyIValue(entry.first);
auto valInfo = toTypedIValue(entry.second);
if (!keyType) {
keyType = keyInfo.second;
valueType = valInfo.second;
} else {
auto unifiedKey = unifyTypes(keyType, keyInfo.second);
auto unifiedValue = unifyTypes(valueType, valInfo.second);
if (!unifiedKey || !unifiedValue) {
AT_ERROR("Dictionary inputs to traced functions must have consistent type");
}
keyType = *unifiedKey;
valueType = *unifiedValue;
}
elems.insert(std::make_pair(keyInfo.first, valInfo.first));
}
return TypedIValue(at::ivalue::GenericDict::create(std::move(elems)),
DictType::create(keyType, valueType));
} else {
throw std::runtime_error(c10::str(
"Only tensors and (possibly nested) tuples of tensors are supported ",
"Only tensors and (possibly nested) tuples of tensors or dicts are supported ",
"as inputs or outputs of traced functions",
", but instead got value of type ",
py::str(input.get_type().attr("__name__")),
@@ -62,10 +123,19 @@ inline IValue toIValue(py::handle input) {
}
}
inline IValue toIValue(py::handle input) {
return toTypedIValue(input).ivalue();
}
inline Stack toStack(const py::tuple& inputs) {
return toIValue(inputs).toTuple()->elements();
}
inline TypedStack toTypedStack(const py::tuple& inputs) {
auto info = toTypedIValue(inputs);
return TypedStack(info.ivalue().toTuple()->elements(), info.type()->expect<TupleType>());
}
inline IValue toIValue(
py::handle obj,
const TypePtr& type,

View File

@@ -38,7 +38,7 @@ std::string getPythonInterpreterStackTrace() {
std::shared_ptr<torch::jit::Graph> createGraphByTracing(
const py::function& func,
Stack trace_inputs,
TypedStack trace_inputs,
const py::function& var_name_lookup_fn,
bool force_outplace,
const c10::optional<size_t>& num_real_inputs) {
@@ -145,7 +145,7 @@ void initPythonTracerBindings(PyObject* module) {
m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
m.def("_tracer_enter", [](py::args trace_inputs) {
return tracer::enter(toStack(trace_inputs));
return tracer::enter(toTypedStack(trace_inputs));
});
m.def("_tracer_exit", [](py::tuple var_outputs) {
tracer::exit(toStack(var_outputs));

View File

@@ -21,7 +21,7 @@ Node* preRecordPythonTrace(
std::shared_ptr<Graph> createGraphByTracing(
const py::function& func,
Stack inputs,
TypedStack inputs,
const py::function& var_name_lookup_fn,
bool force_outplace,
const c10::optional<size_t>& num_real_inputs = c10::nullopt);

View File

@@ -1540,15 +1540,33 @@ int dictKeys(Stack& stack) {
return 0;
}
int dictValues(Stack& stack) {
auto dict = pop(stack).toGenericDictRef();
std::vector<IValue> values;
template <typename Elem>
std::vector<Elem> makeListForDictValues(const c10::ivalue::UnorderedMap& dict) {
std::vector<Elem> values;
values.reserve(dict.size());
for (auto item : dict) {
values.push_back(item.second);
values.push_back(item.second.to<Elem>());
}
push(stack, IValue(values));
return 0;
return values;
}
Operation dictValues(const Node* n) {
auto outputType = n->output()->type()->expect<ListType>();
return [=](Stack& stack) -> int {
auto dict = pop(stack).toGenericDictRef();
if (outputType->getElementType()->isSubtypeOf(TensorType::get())) {
push(stack, makeListForDictValues<at::Tensor>(dict));
} else if (outputType->getElementType() == IntType::get()) {
push(stack, makeListForDictValues<int64_t>(dict));
} else if (outputType->getElementType() == FloatType::get()) {
push(stack, makeListForDictValues<double>(dict));
} else if (outputType->getElementType() == BoolType::get()) {
push(stack, makeListForDictValues<bool>(dict));
} else {
push(stack, makeListForDictValues<IValue>(dict));
}
return 0;
};
}
int dictIndex(Stack& stack) {

View File

@@ -939,13 +939,19 @@ void initJitScriptBindings(PyObject* module) {
// this was ensured in python before calling this function
std::vector<Slot> parameters;
gatherParametersAndBuffers(parameters, *self);
Stack inputs = toStack(input_tuple);
for (const Slot& param : parameters) {
inputs.emplace_back(param.value());
auto typed_inputs = toTypedStack(input_tuple);
if (parameters.size() > 0) {
auto inputs = typed_inputs.stack();
auto input_types = typed_inputs.types()->elements().vec();
for (const Slot& param : parameters) {
inputs.emplace_back(param.value());
input_types.push_back(incompleteInferTypeFrom(param.value()));
}
typed_inputs = TypedStack(inputs, TupleType::create(input_types));
}
auto graph = tracer::createGraphByTracing(
func,
inputs,
typed_inputs,
var_lookup_fn,
force_outplace,
input_tuple.size());

View File

@@ -201,7 +201,7 @@ Value* getNestedOutputTrace(
// Start tracing, treating 'inputs' as inputs to the trace, which can be
// varied on subsequent invocations of the trace. Any other variables
// will be treated as constants.
std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs) {
if (isTracing()) {
AT_ERROR("Tracing can't be nested");
}
@@ -234,17 +234,34 @@ std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
elems[i] = add_input(elems[i], elem_types[i], elem_values[i]);
}
return Tuple::create(std::move(elems));
} else if (auto dict_type = type->cast<DictType>()) {
auto elem_pairs = input.toGenericDict()->elements();
auto unpack_to_list = state->graph->insert(aten::values, {value});
auto list_unpack = state->graph->createListUnpack(unpack_to_list, elem_pairs.size());
auto unpack_node = state->graph->insertNode(list_unpack);
auto elem_values = unpack_node->outputs();
AT_ASSERT(elem_pairs.size() == elem_values.size());
size_t i = 0;
for (const auto &pair : elem_pairs) {
elem_pairs[pair.first] = add_input(pair.second, dict_type->getValueType(), elem_values[i++]);
}
return c10::ivalue::GenericDict::create(std::move(elem_pairs));
} else {
AT_ERROR(
"Only tensors or tuples of tensors can be inputs to traced functions. Got ",
type);
"Only tensors or (possibly nested) dict or tuples of tensors can be "
"inputs to traced functions. Got ", type);
}
};
for (IValue& input : inputs) {
size_t i = 0;
auto input_types = inputs.types()->elements();
for (IValue& input : inputs.stack()) {
input = add_input(
input, incompleteInferTypeFrom(input), state->graph->addInput());
input, input_types[i++], state->graph->addInput());
}
return std::make_pair(state, inputs);
return std::make_pair(state, inputs.stack());
}
// Exit a trace, treating 'outputs' as the outputs of the trace. These

View File

@@ -65,7 +65,30 @@ TORCH_API Value* getNestedOutputTrace(
const std::shared_ptr<TracingState>& state,
const IValue& iv);
TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs);
struct TypedStack : public std::pair<Stack, TupleTypePtr>
{
using pair::pair;
// NB: The inherited default constructor gives nullptr for |type|,
// so we provide a saner one.
TypedStack()
: pair({}, TupleType::create({}))
{}
Stack& stack() {
return this->first;
}
TupleTypePtr& types() {
return this->second;
}
size_t size() {
auto s = stack().size();
AT_ASSERT(s == types()->elements().size());
return s;
}
};
TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs);
TORCH_API void exit(const Stack& outputs);

View File

@@ -656,7 +656,7 @@ def trace(func,
return func
executor_options = {'optimize': bool(optimize)}
# Special case for common case of passing a single Tensor
if isinstance(example_inputs, torch.Tensor):
if isinstance(example_inputs, (torch.Tensor, dict)):
example_inputs = (example_inputs,)
# done primarily so that weird iterables fail here and not pybind11 code
elif not isinstance(example_inputs, tuple):