mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Allow passing dicts as trace inputs. (#18092)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18092 Previously, tracing required all inputs to be either tensors, or tuples of tensor. Now, we allow users to pass dicts as well. Differential Revision: D14491795 fbshipit-source-id: 7a2df218e5d00f898d01fa5b9669f9d674280be3
This commit is contained in:
		
				
					committed by
					
						 Facebook Github Bot
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							9034b66f14
						
					
				
				
					commit
					593bb145ce
				
			| @ -287,6 +287,18 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) { | ||||
|       } | ||||
|     } | ||||
|     return static_cast<TypePtr>(TupleType::create(elements)); | ||||
|   } else if (t1->cast<DictType>() && t2->cast<DictType>()) { | ||||
|     auto dict1 = t1->cast<DictType>(); | ||||
|     auto dict2 = t2->cast<DictType>(); | ||||
|  | ||||
|     auto unified_key = unifyTypes(dict1->getKeyType(), dict2->getKeyType()); | ||||
|     auto unshaped_value1 = unshapedType(dict1->getValueType()); | ||||
|     auto unshaped_value2 = unshapedType(dict2->getValueType()); | ||||
|     auto unified_value = tryEitherIsTheSuperType(unshaped_value1, unshaped_value2); | ||||
|     if (!unified_key || !unified_value) { | ||||
|       return c10::nullopt; | ||||
|     } | ||||
|     return DictType::create(*unified_key, *unified_value); | ||||
|   } | ||||
|  | ||||
|   return c10::nullopt; | ||||
|  | ||||
| @ -60,9 +60,17 @@ variable_list get_grad_outputs(const variable_list& vars) { | ||||
| std::shared_ptr<Graph> trace( | ||||
|     const ADTestSpec& test, | ||||
|     const variable_list& vars_in) { | ||||
|   Stack input_vars = fmap<IValue>(vars_in); | ||||
|   std::vector<TypePtr> input_types; | ||||
|   input_types.reserve(input_vars.size()); | ||||
|   for (auto i = 0; i < input_vars.size(); i++) { | ||||
|     input_types.push_back(TensorType::get()); | ||||
|   } | ||||
|   auto input_typeptr = TupleType::create(std::move(input_types)); | ||||
|   std::shared_ptr<tracer::TracingState> state; | ||||
|   Stack trace_stack_in; | ||||
|   std::tie(state, trace_stack_in) = tracer::enter(fmap<IValue>(vars_in)); | ||||
|   std::tie(state, trace_stack_in) = | ||||
|       tracer::enter(tracer::TypedStack(input_vars, input_typeptr)); | ||||
|   variable_list trace_vars_in = fmap( | ||||
|       trace_stack_in, [](const IValue& v) { return Variable(v.toTensor()); }); | ||||
|   auto trace_vars_out = test(trace_vars_in); | ||||
|  | ||||
| @ -12,6 +12,7 @@ from contextlib import contextmanager | ||||
| from itertools import product, chain | ||||
| import torch.jit.frontend | ||||
| from torch.autograd import Variable, Function | ||||
| from torch.autograd.function import _nested_map | ||||
| from torch.onnx import OperatorExportTypes | ||||
| from torch._six import inf, PY2, builtins, StringIO | ||||
| from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ | ||||
| @ -19,7 +20,7 @@ from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ | ||||
|     freeze_rng_state, set_rng_seed, slowTest | ||||
| from common_nn import module_tests, new_module_tests, criterion_tests | ||||
| from textwrap import dedent | ||||
| from functools import wraps | ||||
| from functools import wraps, reduce | ||||
| import os | ||||
| import io | ||||
| import sys | ||||
| @ -535,9 +536,24 @@ class JitTestCase(TestCase): | ||||
|         if input_tensors is None: | ||||
|             input_tensors = reference_tensors | ||||
|  | ||||
|         def do_input_map(fn, input): | ||||
|             return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input) | ||||
|  | ||||
|         def flatten_inputs(inputs): | ||||
|             def input_reduce(input, fn, acc): | ||||
|                 if isinstance(input, torch.Tensor): | ||||
|                     fn(input, acc) | ||||
|                 elif isinstance(input, dict): | ||||
|                     reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc) | ||||
|                 else: | ||||
|                     reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc) | ||||
|                 return acc | ||||
|             return tuple(input_reduce(recording_inputs, lambda t, acc: acc.append(t), [])) | ||||
|  | ||||
|         nograd_inputs = reference_tensors | ||||
|         if inputs_require_grads: | ||||
|             recording_inputs = [t.clone().requires_grad_() for t in reference_tensors] | ||||
|             recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors) | ||||
|             flattened_recording_inputs = flatten_inputs(recording_inputs) | ||||
|         else: | ||||
|             recording_inputs = reference_tensors | ||||
|  | ||||
| @ -558,12 +574,12 @@ class JitTestCase(TestCase): | ||||
|         # test single grad case | ||||
|         outputs = func(*recording_inputs) | ||||
|         if inputs_require_grads: | ||||
|             grads = torch.autograd.grad(allSum(outputs), recording_inputs, | ||||
|             grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs, | ||||
|                                         allow_unused=allow_unused) | ||||
|  | ||||
|         outputs_ge = ge(*recording_inputs) | ||||
|         if inputs_require_grads: | ||||
|             grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs, | ||||
|             grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs, | ||||
|                                            allow_unused=allow_unused) | ||||
|         self.assertEqual(outputs, outputs_ge) | ||||
|         if inputs_require_grads: | ||||
| @ -574,25 +590,25 @@ class JitTestCase(TestCase): | ||||
|         outputs = func(*recording_inputs) | ||||
|         l1 = allSum(outputs) | ||||
|         if inputs_require_grads: | ||||
|             grads = torch.autograd.grad(l1, recording_inputs, create_graph=True, | ||||
|             grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True, | ||||
|                                         allow_unused=allow_unused) | ||||
|         if inputs_require_grads: | ||||
|             l2 = (allSum(grads) * l1) | ||||
|             grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused) | ||||
|             grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused) | ||||
|  | ||||
|         if inputs_require_grads: | ||||
|             recording_inputs = [Variable(t, requires_grad=True) | ||||
|                                 for t in reference_tensors] | ||||
|             recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors) | ||||
|             flattened_recording_inputs = flatten_inputs(recording_inputs) | ||||
|  | ||||
|         outputs_ge = ge(*recording_inputs) | ||||
|         l1_ge = allSum(outputs_ge) | ||||
|         if inputs_require_grads: | ||||
|             grads_ge = torch.autograd.grad( | ||||
|                 l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused) | ||||
|                 l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused) | ||||
|  | ||||
|         if inputs_require_grads: | ||||
|             l2_ge = (allSum(grads_ge) * l1_ge) | ||||
|             grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused) | ||||
|             grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused) | ||||
|  | ||||
|         self.assertEqual(outputs, outputs_ge) | ||||
|         if inputs_require_grads: | ||||
| @ -1580,8 +1596,6 @@ graph(%x : Tensor, | ||||
|  | ||||
|         self.checkTrace(fn, (torch.randn(2, 2),)) | ||||
|  | ||||
|     # TODO: implement | ||||
|     @unittest.expectedFailure | ||||
|     def test_input_flatten(self): | ||||
|         """Check that inputs to traced functions are flattened""" | ||||
|  | ||||
| @ -1592,6 +1606,66 @@ graph(%x : Tensor, | ||||
|         inputs = (torch.randn(1), (torch.randn(1), torch.randn(1))) | ||||
|         self.checkTrace(fn, inputs) | ||||
|  | ||||
|     def test_input_dict_empty(self): | ||||
|         def test(d): | ||||
|             pass | ||||
|  | ||||
|         with self.assertRaises(RuntimeError): | ||||
|             self.checkTrace(test, {}) | ||||
|  | ||||
|     def test_input_dict_flattens(self): | ||||
|         class Test(torch.nn.Module): | ||||
|             def forward(self, d): | ||||
|                 return d['x'] + d['y'] | ||||
|  | ||||
|         inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)} | ||||
|         module = torch.jit.trace(Test(), inputs) | ||||
|         FileCheck().check('aten::values').check('prim::ListUnpack').run(str(module.graph)) | ||||
|  | ||||
|     def test_input_dict_flattens_recursive(self): | ||||
|         class Test(torch.nn.Module): | ||||
|             def forward(self, d): | ||||
|                 # Use both to avoid getting optimized away | ||||
|                 a = d['x'][0] | ||||
|                 b, c = d['y'] | ||||
|                 return a + b | ||||
|  | ||||
|         inputs = {'x': (torch.rand(2, 2), torch.rand(2, 2)), 'y': (torch.ones(1, 1), torch.ones(2, 1))} | ||||
|         module = torch.jit.trace(Test(), inputs) | ||||
|         FileCheck().check('aten::values') \ | ||||
|                    .check('prim::ListUnpack') \ | ||||
|                    .check_count('prim::TupleUnpack', 2) \ | ||||
|                    .run(str(module.graph)) | ||||
|  | ||||
|     def test_input_dict_checkTrace_mut(self): | ||||
|         def test(d): | ||||
|             d['x'].tanh_() | ||||
|             return d['x'] | ||||
|         inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)} | ||||
|         self.checkTrace(test, (inputs,), inputs_require_grads=False) | ||||
|  | ||||
|     def test_input_dict_unify(self): | ||||
|         def test(d): | ||||
|             return d['int'], d['float'] | ||||
|         inputs = {'int': torch.ones((2, 2), dtype=torch.int32), | ||||
|                   'float': torch.ones((2, 2), dtype=torch.float32)} | ||||
|         self.checkTrace(test, (inputs,), inputs_require_grads=False) | ||||
|  | ||||
|     def test_input_tuple_of_dicts(self): | ||||
|         def test(t): | ||||
|             d = t[0] | ||||
|             return d['x']['y'] | ||||
|         inputs = {'x': {'y': torch.rand(2, 3)}} | ||||
|         self.checkTrace(test, ((inputs, inputs),), allow_unused=True) | ||||
|  | ||||
|     def test_input_dict_of_dicts(self): | ||||
|         def test(d): | ||||
|             return d['x']['y'] | ||||
|         nested_input = {'y': torch.rand(2, 3)} | ||||
|         unified_nested = {'y': torch.rand(3, 2)} | ||||
|         inputs = {'x': nested_input, 'force_unify': unified_nested} | ||||
|         self.checkTrace(test, (inputs,), allow_unused=True) | ||||
|  | ||||
|     # TODO: adapt to a GraphExecutor test | ||||
|     @unittest.skip("Need to instrument GraphExecutors a bit more") | ||||
|     def test_flags(self): | ||||
|  | ||||
| @ -255,6 +255,8 @@ def _nested_map(condition, fn, condition_msg=None): | ||||
|             return None | ||||
|         elif isinstance(obj, (list, tuple)): | ||||
|             return type(obj)(_map(x) for x in obj) | ||||
|         elif isinstance(obj, dict): | ||||
|             return {x : _map(obj[x]) for x in obj} | ||||
|         else: | ||||
|             raise ValueError("Auto nesting doesn't know how to process " | ||||
|                              "an input object of type " + torch.typename(obj) + | ||||
| @ -284,6 +286,11 @@ def _iter_filter(condition, allow_unknown=False, condition_msg=None, | ||||
|             for o in obj: | ||||
|                 for var in _iter(o): | ||||
|                     yield var | ||||
|         elif isinstance(obj, dict): | ||||
|             # We only accept primitive key types, so we needn't inspect them | ||||
|             for o in obj.values(): | ||||
|                 for var in _iter(o): | ||||
|                     yield var | ||||
|         elif allow_unknown: | ||||
|             yield obj | ||||
|         else: | ||||
|  | ||||
| @ -35,24 +35,85 @@ namespace jit { | ||||
| // that is confusing to display to the end user since it always reports | ||||
| // locations in libtorch code rather than user code. | ||||
|  | ||||
| inline IValue toIValue(py::handle input) { | ||||
| using tracer::TypedStack; | ||||
| struct TypedIValue : public std::pair<IValue, TypePtr> { | ||||
|   using pair::pair; | ||||
|  | ||||
|   IValue& ivalue() { | ||||
|     return this->first; | ||||
|   } | ||||
|   TypePtr& type() { | ||||
|     return this->second; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| inline TypedIValue toDictKeyIValue(py::handle key) { | ||||
|   if (py::isinstance<py::str>(key)) { | ||||
|     return TypedIValue(ConstantString::create(py::cast<std::string>(key)), | ||||
|                        StringType::create()); | ||||
|   } else if (PyLong_Check(key.ptr())) { | ||||
|     return TypedIValue(py::cast<int64_t>(key), IntType::create()); | ||||
|   } else if (PyFloat_Check(key.ptr())) { | ||||
|     return TypedIValue(py::cast<double>(key), FloatType::create()); | ||||
|   } else { | ||||
|     AT_ERROR("Dictionary inputs may only have string, int, or float keys"); | ||||
|   } | ||||
| } | ||||
|  | ||||
| inline TypedIValue toTypedIValue(py::handle input) { | ||||
|   if (THPVariable_Check(input.ptr())) { | ||||
|     auto ten = py::cast<at::Tensor>(input); | ||||
|     if (ten.is_sparse()) { | ||||
|       AT_ERROR("sparse tensors not supported"); | ||||
|     } | ||||
|     return ten; | ||||
|     return TypedIValue(ten, CompleteTensorType::create(ten)); | ||||
|   } else if (six::isTuple(input)) { | ||||
|     py::tuple input_tuple = py::cast<py::tuple>(input); | ||||
|     Stack s; | ||||
|     std::vector<TypePtr> t; | ||||
|     s.reserve(input_tuple.size()); | ||||
|     t.reserve(input_tuple.size()); | ||||
|     for (py::handle elem : input_tuple) { | ||||
|       s.push_back(toIValue(elem)); | ||||
|       auto info = toTypedIValue(elem); | ||||
|       s.push_back(info.first); | ||||
|       t.push_back(info.second); | ||||
|     } | ||||
|     return Tuple::create(s); | ||||
|     return TypedIValue(Tuple::create(s), TupleType::create(t)); | ||||
|   } else if (PyDict_Check(input.ptr())) { | ||||
|     // Check to make sure we can generate useful input/output types | ||||
|     auto dict = py::cast<py::dict>(input); | ||||
|     at::ivalue::UnorderedMap elems; | ||||
|  | ||||
|     size_t len = py::len(dict); | ||||
|     if (!len) { | ||||
|       AT_ERROR("Dictionary inputs must have entries."); | ||||
|     } | ||||
|     elems.reserve(len); | ||||
|  | ||||
|     TypePtr keyType = nullptr; | ||||
|     TypePtr valueType = nullptr; | ||||
|     for (auto entry : dict) { | ||||
|       auto keyInfo = toDictKeyIValue(entry.first); | ||||
|       auto valInfo = toTypedIValue(entry.second); | ||||
|       if (!keyType) { | ||||
|         keyType = keyInfo.second; | ||||
|         valueType = valInfo.second; | ||||
|       } else { | ||||
|         auto unifiedKey = unifyTypes(keyType, keyInfo.second); | ||||
|         auto unifiedValue = unifyTypes(valueType, valInfo.second); | ||||
|         if (!unifiedKey || !unifiedValue) { | ||||
|           AT_ERROR("Dictionary inputs to traced functions must have consistent type"); | ||||
|         } | ||||
|         keyType = *unifiedKey; | ||||
|         valueType = *unifiedValue; | ||||
|       } | ||||
|       elems.insert(std::make_pair(keyInfo.first, valInfo.first)); | ||||
|     } | ||||
|     return TypedIValue(at::ivalue::GenericDict::create(std::move(elems)), | ||||
|                        DictType::create(keyType, valueType)); | ||||
|   } else { | ||||
|     throw std::runtime_error(c10::str( | ||||
|         "Only tensors and (possibly nested) tuples of tensors are supported ", | ||||
|         "Only tensors and (possibly nested) tuples of tensors or dicts are supported ", | ||||
|         "as inputs or outputs of traced functions", | ||||
|         ", but instead got value of type ", | ||||
|         py::str(input.get_type().attr("__name__")), | ||||
| @ -62,10 +123,19 @@ inline IValue toIValue(py::handle input) { | ||||
|   } | ||||
| } | ||||
|  | ||||
| inline IValue toIValue(py::handle input) { | ||||
|     return toTypedIValue(input).ivalue(); | ||||
| } | ||||
|  | ||||
| inline Stack toStack(const py::tuple& inputs) { | ||||
|   return toIValue(inputs).toTuple()->elements(); | ||||
| } | ||||
|  | ||||
| inline TypedStack toTypedStack(const py::tuple& inputs) { | ||||
|   auto info = toTypedIValue(inputs); | ||||
|   return TypedStack(info.ivalue().toTuple()->elements(), info.type()->expect<TupleType>()); | ||||
| } | ||||
|  | ||||
| inline IValue toIValue( | ||||
|     py::handle obj, | ||||
|     const TypePtr& type, | ||||
|  | ||||
| @ -38,7 +38,7 @@ std::string getPythonInterpreterStackTrace() { | ||||
|  | ||||
| std::shared_ptr<torch::jit::Graph> createGraphByTracing( | ||||
|     const py::function& func, | ||||
|     Stack trace_inputs, | ||||
|     TypedStack trace_inputs, | ||||
|     const py::function& var_name_lookup_fn, | ||||
|     bool force_outplace, | ||||
|     const c10::optional<size_t>& num_real_inputs) { | ||||
| @ -145,7 +145,7 @@ void initPythonTracerBindings(PyObject* module) { | ||||
|  | ||||
|   m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); }); | ||||
|   m.def("_tracer_enter", [](py::args trace_inputs) { | ||||
|     return tracer::enter(toStack(trace_inputs)); | ||||
|     return tracer::enter(toTypedStack(trace_inputs)); | ||||
|   }); | ||||
|   m.def("_tracer_exit", [](py::tuple var_outputs) { | ||||
|     tracer::exit(toStack(var_outputs)); | ||||
|  | ||||
| @ -21,7 +21,7 @@ Node* preRecordPythonTrace( | ||||
|  | ||||
| std::shared_ptr<Graph> createGraphByTracing( | ||||
|     const py::function& func, | ||||
|     Stack inputs, | ||||
|     TypedStack inputs, | ||||
|     const py::function& var_name_lookup_fn, | ||||
|     bool force_outplace, | ||||
|     const c10::optional<size_t>& num_real_inputs = c10::nullopt); | ||||
|  | ||||
| @ -1540,15 +1540,33 @@ int dictKeys(Stack& stack) { | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| int dictValues(Stack& stack) { | ||||
|   auto dict = pop(stack).toGenericDictRef(); | ||||
|   std::vector<IValue> values; | ||||
| template <typename Elem> | ||||
| std::vector<Elem> makeListForDictValues(const c10::ivalue::UnorderedMap& dict) { | ||||
|   std::vector<Elem> values; | ||||
|   values.reserve(dict.size()); | ||||
|   for (auto item : dict) { | ||||
|     values.push_back(item.second); | ||||
|     values.push_back(item.second.to<Elem>()); | ||||
|   } | ||||
|   push(stack, IValue(values)); | ||||
|   return 0; | ||||
|   return values; | ||||
| } | ||||
|  | ||||
| Operation dictValues(const Node* n) { | ||||
|   auto outputType = n->output()->type()->expect<ListType>(); | ||||
|   return [=](Stack& stack) -> int { | ||||
|     auto dict = pop(stack).toGenericDictRef(); | ||||
|     if (outputType->getElementType()->isSubtypeOf(TensorType::get())) { | ||||
|       push(stack, makeListForDictValues<at::Tensor>(dict)); | ||||
|     } else if (outputType->getElementType() == IntType::get()) { | ||||
|       push(stack, makeListForDictValues<int64_t>(dict)); | ||||
|     } else if (outputType->getElementType() == FloatType::get()) { | ||||
|       push(stack, makeListForDictValues<double>(dict)); | ||||
|     } else if (outputType->getElementType() == BoolType::get()) { | ||||
|       push(stack, makeListForDictValues<bool>(dict)); | ||||
|     } else { | ||||
|       push(stack, makeListForDictValues<IValue>(dict)); | ||||
|     } | ||||
|     return 0; | ||||
|   }; | ||||
| } | ||||
|  | ||||
| int dictIndex(Stack& stack) { | ||||
|  | ||||
| @ -939,13 +939,19 @@ void initJitScriptBindings(PyObject* module) { | ||||
|             // this was ensured in python before calling this function | ||||
|             std::vector<Slot> parameters; | ||||
|             gatherParametersAndBuffers(parameters, *self); | ||||
|             Stack inputs = toStack(input_tuple); | ||||
|             for (const Slot& param : parameters) { | ||||
|               inputs.emplace_back(param.value()); | ||||
|             auto typed_inputs = toTypedStack(input_tuple); | ||||
|             if (parameters.size() > 0) { | ||||
|               auto inputs = typed_inputs.stack(); | ||||
|               auto input_types = typed_inputs.types()->elements().vec(); | ||||
|               for (const Slot& param : parameters) { | ||||
|                 inputs.emplace_back(param.value()); | ||||
|                 input_types.push_back(incompleteInferTypeFrom(param.value())); | ||||
|               } | ||||
|               typed_inputs = TypedStack(inputs, TupleType::create(input_types)); | ||||
|             } | ||||
|             auto graph = tracer::createGraphByTracing( | ||||
|                 func, | ||||
|                 inputs, | ||||
|                 typed_inputs, | ||||
|                 var_lookup_fn, | ||||
|                 force_outplace, | ||||
|                 input_tuple.size()); | ||||
|  | ||||
| @ -201,7 +201,7 @@ Value* getNestedOutputTrace( | ||||
| // Start tracing, treating 'inputs' as inputs to the trace, which can be | ||||
| // varied on subsequent invocations of the trace.  Any other variables | ||||
| // will be treated as constants. | ||||
| std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) { | ||||
| std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs) { | ||||
|   if (isTracing()) { | ||||
|     AT_ERROR("Tracing can't be nested"); | ||||
|   } | ||||
| @ -234,17 +234,34 @@ std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) { | ||||
|         elems[i] = add_input(elems[i], elem_types[i], elem_values[i]); | ||||
|       } | ||||
|       return Tuple::create(std::move(elems)); | ||||
|     } else if (auto dict_type = type->cast<DictType>()) { | ||||
|       auto elem_pairs = input.toGenericDict()->elements(); | ||||
|       auto unpack_to_list = state->graph->insert(aten::values, {value}); | ||||
|       auto list_unpack = state->graph->createListUnpack(unpack_to_list, elem_pairs.size()); | ||||
|       auto unpack_node = state->graph->insertNode(list_unpack); | ||||
|       auto elem_values = unpack_node->outputs(); | ||||
|  | ||||
|       AT_ASSERT(elem_pairs.size() == elem_values.size()); | ||||
|  | ||||
|       size_t i = 0; | ||||
|       for (const auto &pair : elem_pairs) { | ||||
|         elem_pairs[pair.first] = add_input(pair.second, dict_type->getValueType(), elem_values[i++]); | ||||
|       } | ||||
|  | ||||
|       return c10::ivalue::GenericDict::create(std::move(elem_pairs)); | ||||
|     } else { | ||||
|       AT_ERROR( | ||||
|           "Only tensors or tuples of tensors can be inputs to traced functions. Got ", | ||||
|           type); | ||||
|           "Only tensors or (possibly nested) dict or tuples of tensors can be " | ||||
|           "inputs to traced functions. Got ", type); | ||||
|     } | ||||
|   }; | ||||
|   for (IValue& input : inputs) { | ||||
|   size_t i = 0; | ||||
|   auto input_types = inputs.types()->elements(); | ||||
|   for (IValue& input : inputs.stack()) { | ||||
|     input = add_input( | ||||
|         input, incompleteInferTypeFrom(input), state->graph->addInput()); | ||||
|         input, input_types[i++], state->graph->addInput()); | ||||
|   } | ||||
|   return std::make_pair(state, inputs); | ||||
|   return std::make_pair(state, inputs.stack()); | ||||
| } | ||||
|  | ||||
| // Exit a trace, treating 'outputs' as the outputs of the trace.  These | ||||
|  | ||||
| @ -65,7 +65,30 @@ TORCH_API Value* getNestedOutputTrace( | ||||
|     const std::shared_ptr<TracingState>& state, | ||||
|     const IValue& iv); | ||||
|  | ||||
| TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs); | ||||
| struct TypedStack : public std::pair<Stack, TupleTypePtr> | ||||
| { | ||||
|   using pair::pair; | ||||
|  | ||||
|   // NB: The inherited default constructor gives nullptr for |type|, | ||||
|   //     so we provide a saner one. | ||||
|   TypedStack() | ||||
|     : pair({}, TupleType::create({})) | ||||
|   {} | ||||
|  | ||||
|   Stack& stack() { | ||||
|     return this->first; | ||||
|   } | ||||
|   TupleTypePtr& types() { | ||||
|     return this->second; | ||||
|   } | ||||
|   size_t size() { | ||||
|     auto s = stack().size(); | ||||
|     AT_ASSERT(s == types()->elements().size()); | ||||
|     return s; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs); | ||||
|  | ||||
| TORCH_API void exit(const Stack& outputs); | ||||
|  | ||||
|  | ||||
| @ -656,7 +656,7 @@ def trace(func, | ||||
|         return func | ||||
|     executor_options = {'optimize': bool(optimize)} | ||||
|     # Special case for common case of passing a single Tensor | ||||
|     if isinstance(example_inputs, torch.Tensor): | ||||
|     if isinstance(example_inputs, (torch.Tensor, dict)): | ||||
|         example_inputs = (example_inputs,) | ||||
|     # done primarily so that weird iterables fail here and not pybind11 code | ||||
|     elif not isinstance(example_inputs, tuple): | ||||
|  | ||||
		Reference in New Issue
	
	Block a user