Allow passing dicts as trace inputs. (#18092)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18092

Previously, tracing required all inputs to be either tensors or tuples of tensors. Now, we allow users to pass dicts as well.

Differential Revision: D14491795

fbshipit-source-id: 7a2df218e5d00f898d01fa5b9669f9d674280be3
committed by: Facebook Github Bot
parent: 9034b66f14
commit: 593bb145ce
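For context, a minimal sketch of the newly supported call pattern, based on the tests added below (module name, shapes, and keys are illustrative):

    import torch

    class AddXY(torch.nn.Module):
        def forward(self, d):
            # The traced module reads tensors out of its dict input.
            return d['x'] + d['y']

    # A dict is now a legal example input; previously only tensors and
    # (possibly nested) tuples of tensors were accepted.
    inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
    module = torch.jit.trace(AddXY(), inputs)
    print(module(inputs))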
@@ -287,6 +287,18 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
       }
     }
     return static_cast<TypePtr>(TupleType::create(elements));
+  } else if (t1->cast<DictType>() && t2->cast<DictType>()) {
+    auto dict1 = t1->cast<DictType>();
+    auto dict2 = t2->cast<DictType>();
+
+    auto unified_key = unifyTypes(dict1->getKeyType(), dict2->getKeyType());
+    auto unshaped_value1 = unshapedType(dict1->getValueType());
+    auto unshaped_value2 = unshapedType(dict2->getValueType());
+    auto unified_value = tryEitherIsTheSuperType(unshaped_value1, unshaped_value2);
+    if (!unified_key || !unified_value) {
+      return c10::nullopt;
+    }
+    return DictType::create(*unified_key, *unified_value);
   }
 
   return c10::nullopt;
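The user-visible effect of the new DictType branch: two dict entries whose tensor values differ in dtype or shape still unify to a single Dict type, so the dict is accepted as one input. A sketch mirroring test_input_dict_unify below (not itself part of this diff):

    import torch

    def test(d):
        return d['int'], d['float']

    # Values of differing dtype unify to a plain Tensor value type, so the
    # whole input is typed roughly as Dict[str, Tensor].
    inputs = {'int': torch.ones((2, 2), dtype=torch.int32),
              'float': torch.ones((2, 2), dtype=torch.float32)}
    traced = torch.jit.trace(test, (inputs,))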
@@ -60,9 +60,17 @@ variable_list get_grad_outputs(const variable_list& vars) {
 std::shared_ptr<Graph> trace(
     const ADTestSpec& test,
     const variable_list& vars_in) {
+  Stack input_vars = fmap<IValue>(vars_in);
+  std::vector<TypePtr> input_types;
+  input_types.reserve(input_vars.size());
+  for (auto i = 0; i < input_vars.size(); i++) {
+    input_types.push_back(TensorType::get());
+  }
+  auto input_typeptr = TupleType::create(std::move(input_types));
   std::shared_ptr<tracer::TracingState> state;
   Stack trace_stack_in;
-  std::tie(state, trace_stack_in) = tracer::enter(fmap<IValue>(vars_in));
+  std::tie(state, trace_stack_in) =
+      tracer::enter(tracer::TypedStack(input_vars, input_typeptr));
   variable_list trace_vars_in = fmap(
       trace_stack_in, [](const IValue& v) { return Variable(v.toTensor()); });
   auto trace_vars_out = test(trace_vars_in);
@@ -12,6 +12,7 @@ from contextlib import contextmanager
 from itertools import product, chain
 import torch.jit.frontend
 from torch.autograd import Variable, Function
+from torch.autograd.function import _nested_map
 from torch.onnx import OperatorExportTypes
 from torch._six import inf, PY2, builtins, StringIO
 from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
@@ -19,7 +20,7 @@ from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
     freeze_rng_state, set_rng_seed, slowTest
 from common_nn import module_tests, new_module_tests, criterion_tests
 from textwrap import dedent
-from functools import wraps
+from functools import wraps, reduce
 import os
 import io
 import sys
@@ -535,9 +536,24 @@ class JitTestCase(TestCase):
         if input_tensors is None:
             input_tensors = reference_tensors
 
+        def do_input_map(fn, input):
+            return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input)
+
+        def flatten_inputs(inputs):
+            def input_reduce(input, fn, acc):
+                if isinstance(input, torch.Tensor):
+                    fn(input, acc)
+                elif isinstance(input, dict):
+                    reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc)
+                else:
+                    reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc)
+                return acc
+            return tuple(input_reduce(recording_inputs, lambda t, acc: acc.append(t), []))
+
         nograd_inputs = reference_tensors
         if inputs_require_grads:
-            recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
+            recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)
         else:
            recording_inputs = reference_tensors
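What this flattening yields, illustrated on a nested structure. The standalone helper below is a sketch that mirrors the reduce-based traversal above, except it is applied to its own argument (the harness version closes over recording_inputs):

    import torch
    from functools import reduce

    def flatten_inputs(inputs):
        # Depth-first: tensors are collected, dicts recurse over their keys,
        # other iterables (tuples/lists) recurse over their values.
        def input_reduce(input, fn, acc):
            if isinstance(input, torch.Tensor):
                fn(input, acc)
            elif isinstance(input, dict):
                reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc)
            else:
                reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc)
            return acc
        return tuple(input_reduce(inputs, lambda t, acc: acc.append(t), []))

    nested = ({'x': torch.zeros(1), 'y': (torch.ones(1), torch.ones(1))},)
    assert len(flatten_inputs(nested)) == 3  # three leaf tensors

The flat tuple is what the autograd.grad calls below differentiate with respect to, since grad requires a flat sequence of tensors rather than nested dicts and tuples.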
@@ -558,12 +574,12 @@ class JitTestCase(TestCase):
         # test single grad case
         outputs = func(*recording_inputs)
         if inputs_require_grads:
-            grads = torch.autograd.grad(allSum(outputs), recording_inputs,
+            grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs,
                                         allow_unused=allow_unused)
 
         outputs_ge = ge(*recording_inputs)
         if inputs_require_grads:
-            grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
+            grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs,
                                            allow_unused=allow_unused)
         self.assertEqual(outputs, outputs_ge)
         if inputs_require_grads:
@@ -574,25 +590,25 @@ class JitTestCase(TestCase):
         outputs = func(*recording_inputs)
         l1 = allSum(outputs)
         if inputs_require_grads:
-            grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
+            grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True,
                                         allow_unused=allow_unused)
         if inputs_require_grads:
             l2 = (allSum(grads) * l1)
-            grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)
+            grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused)
 
         if inputs_require_grads:
-            recording_inputs = [Variable(t, requires_grad=True)
-                                for t in reference_tensors]
+            recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)
 
         outputs_ge = ge(*recording_inputs)
         l1_ge = allSum(outputs_ge)
         if inputs_require_grads:
             grads_ge = torch.autograd.grad(
-                l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)
+                l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused)
 
         if inputs_require_grads:
             l2_ge = (allSum(grads_ge) * l1_ge)
-            grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)
+            grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused)
 
         self.assertEqual(outputs, outputs_ge)
         if inputs_require_grads:
@@ -1580,8 +1596,6 @@ graph(%x : Tensor,
 
         self.checkTrace(fn, (torch.randn(2, 2),))
 
-    # TODO: implement
-    @unittest.expectedFailure
     def test_input_flatten(self):
         """Check that inputs to traced functions are flattened"""
 
@@ -1592,6 +1606,66 @@ graph(%x : Tensor,
         inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
         self.checkTrace(fn, inputs)
 
+    def test_input_dict_empty(self):
+        def test(d):
+            pass
+
+        with self.assertRaises(RuntimeError):
+            self.checkTrace(test, {})
+
+    def test_input_dict_flattens(self):
+        class Test(torch.nn.Module):
+            def forward(self, d):
+                return d['x'] + d['y']
+
+        inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
+        module = torch.jit.trace(Test(), inputs)
+        FileCheck().check('aten::values').check('prim::ListUnpack').run(str(module.graph))
+
+    def test_input_dict_flattens_recursive(self):
+        class Test(torch.nn.Module):
+            def forward(self, d):
+                # Use both to avoid getting optimized away
+                a = d['x'][0]
+                b, c = d['y']
+                return a + b
+
+        inputs = {'x': (torch.rand(2, 2), torch.rand(2, 2)), 'y': (torch.ones(1, 1), torch.ones(2, 1))}
+        module = torch.jit.trace(Test(), inputs)
+        FileCheck().check('aten::values') \
+                   .check('prim::ListUnpack') \
+                   .check_count('prim::TupleUnpack', 2) \
+                   .run(str(module.graph))
+
+    def test_input_dict_checkTrace_mut(self):
+        def test(d):
+            d['x'].tanh_()
+            return d['x']
+        inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
+        self.checkTrace(test, (inputs,), inputs_require_grads=False)
+
+    def test_input_dict_unify(self):
+        def test(d):
+            return d['int'], d['float']
+        inputs = {'int': torch.ones((2, 2), dtype=torch.int32),
+                  'float': torch.ones((2, 2), dtype=torch.float32)}
+        self.checkTrace(test, (inputs,), inputs_require_grads=False)
+
+    def test_input_tuple_of_dicts(self):
+        def test(t):
+            d = t[0]
+            return d['x']['y']
+        inputs = {'x': {'y': torch.rand(2, 3)}}
+        self.checkTrace(test, ((inputs, inputs),), allow_unused=True)
+
+    def test_input_dict_of_dicts(self):
+        def test(d):
+            return d['x']['y']
+        nested_input = {'y': torch.rand(2, 3)}
+        unified_nested = {'y': torch.rand(3, 2)}
+        inputs = {'x': nested_input, 'force_unify': unified_nested}
+        self.checkTrace(test, (inputs,), allow_unused=True)
+
     # TODO: adapt to a GraphExecutor test
     @unittest.skip("Need to instrument GraphExecutors a bit more")
     def test_flags(self):
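One operational note implied by the graphs these tests check: the aten::values plus fixed-arity prim::ListUnpack sequence bakes the example dict's key set into the trace, so a traced module should be called with dicts carrying the same keys. This is an inference from the graph structure, not a guarantee stated in this diff:

    import torch

    class Test(torch.nn.Module):
        def forward(self, d):
            return d['x'] + d['y']

    module = torch.jit.trace(Test(), {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)})
    ok = module({'x': torch.rand(3, 4), 'y': torch.rand(3, 4)})  # same key set: fine
    # module({'x': torch.rand(3, 4)})  # different key set: not covered by the trace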
@@ -255,6 +255,8 @@ def _nested_map(condition, fn, condition_msg=None):
             return None
         elif isinstance(obj, (list, tuple)):
             return type(obj)(_map(x) for x in obj)
+        elif isinstance(obj, dict):
+            return {x : _map(obj[x]) for x in obj}
         else:
             raise ValueError("Auto nesting doesn't know how to process "
                              "an input object of type " + torch.typename(obj) +
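With the dict clause above, _nested_map now maps a function over dict values while preserving keys. A sketch of the observable behavior (this is an internal API; the signature is as used elsewhere in this diff):

    import torch
    from torch.autograd.function import _nested_map

    # Build a mapper that doubles every tensor leaf, recursing through
    # tuples and (now) dicts.
    double = _nested_map(lambda t: isinstance(t, torch.Tensor), lambda t: t * 2)
    out = double({'a': torch.ones(2), 'b': (torch.ones(1), torch.ones(1))})
    # out == {'a': tensor([2., 2.]), 'b': (tensor([2.]), tensor([2.]))}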
@@ -284,6 +286,11 @@ def _iter_filter(condition, allow_unknown=False, condition_msg=None,
             for o in obj:
                 for var in _iter(o):
                     yield var
+        elif isinstance(obj, dict):
+            # We only accept primitive key types, so we needn't inspect them
+            for o in obj.values():
+                for var in _iter(o):
+                    yield var
         elif allow_unknown:
             yield obj
         else:
@@ -35,24 +35,85 @@ namespace jit {
 // that is confusing to display to the end user since it always reports
 // locations in libtorch code rather than user code.
 
-inline IValue toIValue(py::handle input) {
+using tracer::TypedStack;
+struct TypedIValue : public std::pair<IValue, TypePtr> {
+  using pair::pair;
+
+  IValue& ivalue() {
+    return this->first;
+  }
+  TypePtr& type() {
+    return this->second;
+  }
+};
+
+inline TypedIValue toDictKeyIValue(py::handle key) {
+  if (py::isinstance<py::str>(key)) {
+    return TypedIValue(ConstantString::create(py::cast<std::string>(key)),
+                       StringType::create());
+  } else if (PyLong_Check(key.ptr())) {
+    return TypedIValue(py::cast<int64_t>(key), IntType::create());
+  } else if (PyFloat_Check(key.ptr())) {
+    return TypedIValue(py::cast<double>(key), FloatType::create());
+  } else {
+    AT_ERROR("Dictionary inputs may only have string, int, or float keys");
+  }
+}
+
+inline TypedIValue toTypedIValue(py::handle input) {
   if (THPVariable_Check(input.ptr())) {
     auto ten = py::cast<at::Tensor>(input);
     if (ten.is_sparse()) {
       AT_ERROR("sparse tensors not supported");
     }
-    return ten;
+    return TypedIValue(ten, CompleteTensorType::create(ten));
   } else if (six::isTuple(input)) {
     py::tuple input_tuple = py::cast<py::tuple>(input);
     Stack s;
+    std::vector<TypePtr> t;
     s.reserve(input_tuple.size());
+    t.reserve(input_tuple.size());
     for (py::handle elem : input_tuple) {
-      s.push_back(toIValue(elem));
+      auto info = toTypedIValue(elem);
+      s.push_back(info.first);
+      t.push_back(info.second);
     }
-    return Tuple::create(s);
+    return TypedIValue(Tuple::create(s), TupleType::create(t));
+  } else if (PyDict_Check(input.ptr())) {
+    // Check to make sure we can generate useful input/output types
+    auto dict = py::cast<py::dict>(input);
+    at::ivalue::UnorderedMap elems;
+
+    size_t len = py::len(dict);
+    if (!len) {
+      AT_ERROR("Dictionary inputs must have entries.");
+    }
+    elems.reserve(len);
+
+    TypePtr keyType = nullptr;
+    TypePtr valueType = nullptr;
+    for (auto entry : dict) {
+      auto keyInfo = toDictKeyIValue(entry.first);
+      auto valInfo = toTypedIValue(entry.second);
+      if (!keyType) {
+        keyType = keyInfo.second;
+        valueType = valInfo.second;
+      } else {
+        auto unifiedKey = unifyTypes(keyType, keyInfo.second);
+        auto unifiedValue = unifyTypes(valueType, valInfo.second);
+        if (!unifiedKey || !unifiedValue) {
+          AT_ERROR("Dictionary inputs to traced functions must have consistent type");
+        }
+        keyType = *unifiedKey;
+        valueType = *unifiedValue;
+      }
+      elems.insert(std::make_pair(keyInfo.first, valInfo.first));
+    }
+    return TypedIValue(at::ivalue::GenericDict::create(std::move(elems)),
+                       DictType::create(keyType, valueType));
   } else {
     throw std::runtime_error(c10::str(
-        "Only tensors and (possibly nested) tuples of tensors are supported ",
+        "Only tensors and (possibly nested) tuples of tensors or dicts are supported ",
         "as inputs or outputs of traced functions",
         ", but instead got value of type ",
         py::str(input.get_type().attr("__name__")),
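toDictKeyIValue and the unifyTypes loop above give dict inputs two Python-visible constraints: keys must be str, int, or float, and every entry must unify to a single key type and value type. A sketch of the failure mode, using the error text from the AT_ERROR calls above (this should surface as a RuntimeError in Python):

    import torch

    def f(d):
        return d[1]

    # Mixed key types (int and str) cannot unify, so tracing should raise
    # "Dictionary inputs to traced functions must have consistent type".
    try:
        torch.jit.trace(f, ({1: torch.rand(2), 'one': torch.rand(2)},))
    except RuntimeError as e:
        print(e)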
@@ -62,10 +123,19 @@ inline IValue toIValue(py::handle input) {
   }
 }
 
+inline IValue toIValue(py::handle input) {
+  return toTypedIValue(input).ivalue();
+}
+
 inline Stack toStack(const py::tuple& inputs) {
   return toIValue(inputs).toTuple()->elements();
 }
 
+inline TypedStack toTypedStack(const py::tuple& inputs) {
+  auto info = toTypedIValue(inputs);
+  return TypedStack(info.ivalue().toTuple()->elements(), info.type()->expect<TupleType>());
+}
+
 inline IValue toIValue(
     py::handle obj,
     const TypePtr& type,
@@ -38,7 +38,7 @@ std::string getPythonInterpreterStackTrace() {
 
 std::shared_ptr<torch::jit::Graph> createGraphByTracing(
     const py::function& func,
-    Stack trace_inputs,
+    TypedStack trace_inputs,
     const py::function& var_name_lookup_fn,
     bool force_outplace,
     const c10::optional<size_t>& num_real_inputs) {
@@ -145,7 +145,7 @@ void initPythonTracerBindings(PyObject* module) {
 
   m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
   m.def("_tracer_enter", [](py::args trace_inputs) {
-    return tracer::enter(toStack(trace_inputs));
+    return tracer::enter(toTypedStack(trace_inputs));
   });
   m.def("_tracer_exit", [](py::tuple var_outputs) {
     tracer::exit(toStack(var_outputs));
@@ -21,7 +21,7 @@ Node* preRecordPythonTrace(
 
 std::shared_ptr<Graph> createGraphByTracing(
     const py::function& func,
-    Stack inputs,
+    TypedStack inputs,
     const py::function& var_name_lookup_fn,
     bool force_outplace,
     const c10::optional<size_t>& num_real_inputs = c10::nullopt);
@@ -1540,15 +1540,33 @@ int dictKeys(Stack& stack) {
   return 0;
 }
 
-int dictValues(Stack& stack) {
-  auto dict = pop(stack).toGenericDictRef();
-  std::vector<IValue> values;
+template <typename Elem>
+std::vector<Elem> makeListForDictValues(const c10::ivalue::UnorderedMap& dict) {
+  std::vector<Elem> values;
   values.reserve(dict.size());
   for (auto item : dict) {
-    values.push_back(item.second);
+    values.push_back(item.second.to<Elem>());
   }
-  push(stack, IValue(values));
-  return 0;
+  return values;
+}
+
+Operation dictValues(const Node* n) {
+  auto outputType = n->output()->type()->expect<ListType>();
+  return [=](Stack& stack) -> int {
+    auto dict = pop(stack).toGenericDictRef();
+    if (outputType->getElementType()->isSubtypeOf(TensorType::get())) {
+      push(stack, makeListForDictValues<at::Tensor>(dict));
+    } else if (outputType->getElementType() == IntType::get()) {
+      push(stack, makeListForDictValues<int64_t>(dict));
+    } else if (outputType->getElementType() == FloatType::get()) {
+      push(stack, makeListForDictValues<double>(dict));
+    } else if (outputType->getElementType() == BoolType::get()) {
+      push(stack, makeListForDictValues<bool>(dict));
+    } else {
+      push(stack, makeListForDictValues<IValue>(dict));
+    }
+    return 0;
+  };
+}
 
 int dictIndex(Stack& stack) {
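The rewritten dictValues operator picks a concrete list element type from the node's output type, so aten::values can produce, say, a List[int] rather than a generic list. A script-side sketch of the behavior this enables (assuming this build accepts Python 3 type annotations on script functions):

    import torch
    from typing import Dict, List

    @torch.jit.script
    def values_of(d: Dict[str, int]) -> List[int]:
        # d.values() lowers to aten::values, whose output list is typed
        # by the dict's value type.
        return d.values()

    print(values_of({'a': 1, 'b': 2}))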
@@ -939,13 +939,19 @@ void initJitScriptBindings(PyObject* module) {
           // this was ensured in python before calling this function
           std::vector<Slot> parameters;
           gatherParametersAndBuffers(parameters, *self);
-          Stack inputs = toStack(input_tuple);
-          for (const Slot& param : parameters) {
-            inputs.emplace_back(param.value());
+          auto typed_inputs = toTypedStack(input_tuple);
+          if (parameters.size() > 0) {
+            auto inputs = typed_inputs.stack();
+            auto input_types = typed_inputs.types()->elements().vec();
+            for (const Slot& param : parameters) {
+              inputs.emplace_back(param.value());
+              input_types.push_back(incompleteInferTypeFrom(param.value()));
+            }
+            typed_inputs = TypedStack(inputs, TupleType::create(input_types));
           }
           auto graph = tracer::createGraphByTracing(
               func,
-              inputs,
+              typed_inputs,
               var_lookup_fn,
               force_outplace,
               input_tuple.size());
@@ -201,7 +201,7 @@ Value* getNestedOutputTrace(
 // Start tracing, treating 'inputs' as inputs to the trace, which can be
 // varied on subsequent invocations of the trace. Any other variables
 // will be treated as constants.
-std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
+std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs) {
   if (isTracing()) {
     AT_ERROR("Tracing can't be nested");
   }
@@ -234,17 +234,34 @@ std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
         elems[i] = add_input(elems[i], elem_types[i], elem_values[i]);
       }
       return Tuple::create(std::move(elems));
+    } else if (auto dict_type = type->cast<DictType>()) {
+      auto elem_pairs = input.toGenericDict()->elements();
+      auto unpack_to_list = state->graph->insert(aten::values, {value});
+      auto list_unpack = state->graph->createListUnpack(unpack_to_list, elem_pairs.size());
+      auto unpack_node = state->graph->insertNode(list_unpack);
+      auto elem_values = unpack_node->outputs();
+
+      AT_ASSERT(elem_pairs.size() == elem_values.size());
+
+      size_t i = 0;
+      for (const auto &pair : elem_pairs) {
+        elem_pairs[pair.first] = add_input(pair.second, dict_type->getValueType(), elem_values[i++]);
+      }
+
+      return c10::ivalue::GenericDict::create(std::move(elem_pairs));
     } else {
       AT_ERROR(
-          "Only tensors or tuples of tensors can be inputs to traced functions. Got ",
-          type);
+          "Only tensors or (possibly nested) dict or tuples of tensors can be "
+          "inputs to traced functions. Got ", type);
     }
   };
-  for (IValue& input : inputs) {
+  size_t i = 0;
+  auto input_types = inputs.types()->elements();
+  for (IValue& input : inputs.stack()) {
     input = add_input(
-        input, incompleteInferTypeFrom(input), state->graph->addInput());
+        input, input_types[i++], state->graph->addInput());
   }
-  return std::make_pair(state, inputs);
+  return std::make_pair(state, inputs.stack());
 }
 
 // Exit a trace, treating 'outputs' as the outputs of the trace. These
@@ -65,7 +65,30 @@ TORCH_API Value* getNestedOutputTrace(
     const std::shared_ptr<TracingState>& state,
     const IValue& iv);
 
-TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs);
+struct TypedStack : public std::pair<Stack, TupleTypePtr>
+{
+  using pair::pair;
+
+  // NB: The inherited default constructor gives nullptr for |type|,
+  // so we provide a saner one.
+  TypedStack()
+    : pair({}, TupleType::create({}))
+  {}
+
+  Stack& stack() {
+    return this->first;
+  }
+  TupleTypePtr& types() {
+    return this->second;
+  }
+  size_t size() {
+    auto s = stack().size();
+    AT_ASSERT(s == types()->elements().size());
+    return s;
+  }
+};
+
+TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs);
 
 TORCH_API void exit(const Stack& outputs);
 
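For intuition: TypedStack simply pairs the flat stack of IValues with a TupleType describing it, with the invariant that the two agree in length. A rough Python analogue of the data structure, illustrative only and not an API:

    from typing import Any, List, Tuple

    class TypedStackSketch:
        """Pairs a flat stack of values with a parallel tuple of types."""
        def __init__(self, stack: List[Any] = None, types: Tuple[type, ...] = ()):
            self.stack = stack if stack is not None else []
            self.types = types

        def size(self) -> int:
            # Mirrors the C++ AT_ASSERT: values and types stay in lockstep.
            assert len(self.stack) == len(self.types)
            return len(self.stack)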
@@ -656,7 +656,7 @@ def trace(func,
         return func
     executor_options = {'optimize': bool(optimize)}
     # Special case for common case of passing a single Tensor
-    if isinstance(example_inputs, torch.Tensor):
+    if isinstance(example_inputs, (torch.Tensor, dict)):
         example_inputs = (example_inputs,)
     # done primarily so that weird iterables fail here and not pybind11 code
     elif not isinstance(example_inputs, tuple):
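Per the branch above, a bare dict, like a bare tensor, is treated as the single example input and wrapped in a one-element tuple; it is not interpreted as keyword arguments. Both spellings below are therefore equivalent:

    import torch

    def f(d):
        return d['x'] * 2

    inputs = {'x': torch.rand(2, 2)}
    t1 = torch.jit.trace(f, inputs)      # bare dict: auto-wrapped to (inputs,)
    t2 = torch.jit.trace(f, (inputs,))   # explicit one-element tuple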