mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-01 04:54:55 +08:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			ciflow/tru
			...
			fca
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| f85a0b82eb | 
| @ -37,6 +37,16 @@ struct TORCH_API TensorGeometry { | ||||
|         has_symbolic_sizes_strides_( | ||||
|             t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {} | ||||
|  | ||||
|   explicit TensorGeometry( | ||||
|       std::vector<at::SymInt> sizes, | ||||
|       std::vector<at::SymInt> strides, | ||||
|       at::SymInt storage_offset) | ||||
|       : sizes_(std::move(sizes)), | ||||
|         strides_(std::move(strides)), | ||||
|         storage_offset_(std::move(storage_offset)) { | ||||
|     recompute(); | ||||
|   } | ||||
|  | ||||
|   // true if the tensor is contiguous | ||||
|   bool is_contiguous() const; | ||||
|  | ||||
|  | ||||
| @ -93,7 +93,9 @@ c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) { | ||||
|       case Tag::None: | ||||
|         return NoneType::get(); | ||||
|       case Tag::Tensor: | ||||
|         return TensorType::create(v.toTensor()); | ||||
|         return TensorType::get(); | ||||
|         // TODO(rzou): following errors | ||||
|         // return TensorType::create(v.toTensor()); | ||||
|       case Tag::Storage: | ||||
|         return StorageType::get(); | ||||
|       case Tag::Double: | ||||
|  | ||||
| @ -2075,6 +2075,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_id, m) { | ||||
|             cpp_sources=cpp_source, | ||||
|             functions="custom_op_backed_by_autograd_fn", | ||||
|             verbose=True, | ||||
|             extra_cflags=["-g", "-O0"], | ||||
|         ) | ||||
|  | ||||
|         def same_autograd_fn(): | ||||
| @ -2113,8 +2114,8 @@ TORCH_LIBRARY(test_autograd_cpp_node_id, m) { | ||||
|  | ||||
|         self.check_output_and_recompiles(different_autograd_fn, 2) | ||||
|  | ||||
|     @scoped_load_inline | ||||
|     def test_autograd_cpp_node_saved(self, load_inline): | ||||
|     @unittest.skip("Flaky, cache from test ordering affects test. #135369") | ||||
|     def test_autograd_cpp_node_saved(self): | ||||
|         cpp_source = """ | ||||
| struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { | ||||
|   static constexpr bool is_traceable = true; | ||||
| @ -2190,7 +2191,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved, m) { | ||||
|         self.check_output_and_recompiles(fn, 2) | ||||
|  | ||||
|     @scoped_load_inline | ||||
|     def test_autograd_cpp_node_saved_dynamic(self, load_inline): | ||||
|     def test_autograd_cpp_node_saved_dynamic(self): | ||||
|         cpp_source = """ | ||||
| struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { | ||||
|   static constexpr bool is_traceable = true; | ||||
|  | ||||
| @ -64,6 +64,9 @@ struct TORCH_API ${op} : public ${superclass} { | ||||
|   } | ||||
|   ${will_release_variables} | ||||
|   void compiled_args(CompiledNodeArgs& args) override; | ||||
|   ivalue_list get_state(); | ||||
|   ivalue_list retrieve_saved(SwapSavedVariables& saved) override; | ||||
|   functional_apply_t get_functional() override; | ||||
|   variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override; | ||||
|   ${saved_variables} | ||||
|   ${saved_list_sizes} | ||||
| @ -82,15 +85,22 @@ void will_release_variables() override { | ||||
|  | ||||
| FUNCTION_DEFINITION = CodeTemplate( | ||||
|     """\ | ||||
| variable_list ${op}::apply(variable_list&& grads) { | ||||
|   ${thread_lock} | ||||
|   ${asserts} | ||||
| static variable_list ${op}_apply_functional(variable_list&& grads, std::array<bool,${num_vars}> needs_input_grad ${unpacked_saved_vars_signature}) { | ||||
|   IndexRangeGenerator gen; | ||||
|   ${compute_index_ranges} | ||||
|   variable_list grad_inputs(gen.size()); | ||||
|   ${body} | ||||
|   return grad_inputs; | ||||
| } | ||||
|  | ||||
| variable_list ${op}::apply(variable_list&& grads) { | ||||
|   ${thread_lock} | ||||
|   ${asserts} | ||||
|   ${unpacks} | ||||
|   ${compute_needs_input_grad} | ||||
|   return ${op}_apply_functional(std::move(grads), grad_input_mask ${unpacked_saved_vars}); | ||||
| } | ||||
|  | ||||
| void ${op}::compiled_args(CompiledNodeArgs& args) { | ||||
|     ${compiled_args} | ||||
| } | ||||
| @ -100,6 +110,28 @@ variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVaria | ||||
|     ${apply_with_saved_after} | ||||
|     return result; | ||||
| } | ||||
| ivalue_list ${op}::get_state() { | ||||
|   SavedState saved_state; | ||||
|   ${unpacks} | ||||
|   ${get_state} | ||||
|   return saved_state.stack; | ||||
| } | ||||
| ivalue_list ${op}::retrieve_saved(SwapSavedVariables& saved) { | ||||
|   ${apply_with_saved_before} | ||||
|   auto state = get_state(); | ||||
|   ${apply_with_saved_after} | ||||
|   return state; | ||||
| } | ||||
|  | ||||
| functional_apply_t ${op}::get_functional() { | ||||
|   ${compute_needs_input_grad} | ||||
|   return [grad_input_mask](const variable_list& inputs, const std::vector<c10::IValue>& saved) { | ||||
|     SavedState state; | ||||
|     state.stack = saved; | ||||
|     ${saved_var_dequeues} | ||||
|     return ${op}_apply_functional(variable_list(inputs), grad_input_mask ${unpacked_saved_vars}); | ||||
|   }; | ||||
| } | ||||
| """ | ||||
| ) | ||||
|  | ||||
| @ -107,13 +139,23 @@ GRAD_INPUT_MASK = CodeTemplate( | ||||
|     """\ | ||||
|   auto grad_input_mask = std::array<bool, ${n}>{ | ||||
|     ${masks} | ||||
|   };\ | ||||
|   }; | ||||
| """ | ||||
| ) | ||||
|  | ||||
| COMPUTE_NEEDS_INPUT_GRAD = CodeTemplate( | ||||
|     """\ | ||||
| ${ix_ranges} | ||||
| auto grad_input_mask = std::array<bool, ${n}>{ | ||||
|   ${masks} | ||||
| };\ | ||||
| """ | ||||
| ) | ||||
|  | ||||
|  | ||||
| DERIVATIVE_SINGLE = CodeTemplate( | ||||
|     """\ | ||||
| if (task_should_compute_output({ ${name}_ix })) { | ||||
| if (needs_input_grad[std::get<0>(${name}_ix)]) { | ||||
|   auto grad_result = ${derivative}; | ||||
|   copy_range(grad_inputs, ${name}_ix, grad_result); | ||||
| } | ||||
| @ -126,7 +168,7 @@ if (task_should_compute_output({ ${name}_ix })) { | ||||
| # to each `Tensor`(s) of `self`, and the others. | ||||
| DERIVATIVE_SINGLE_FOREACH = CodeTemplate( | ||||
|     """\ | ||||
| if (task_should_compute_output({ ${name}_ix })) { | ||||
| if (needs_input_grad[std::get<0>(${name}_ix)]) { | ||||
|   std::vector<Tensor> grad_result; | ||||
|   grad_result.reserve(grads.size()); | ||||
|   for (const auto & i : c10::irange(grads.size())) { | ||||
| @ -143,7 +185,7 @@ if (task_should_compute_output({ ${name}_ix })) { | ||||
|  | ||||
| DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate( | ||||
|     """\ | ||||
|   if (task_should_compute_output({ ${name}_ix })) { | ||||
|   if (needs_input_grad[std::get<0>(${name}_ix)]) { | ||||
|     copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result)); | ||||
|   } | ||||
| """ | ||||
| @ -151,7 +193,7 @@ DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate( | ||||
|  | ||||
| DERIVATIVE_MULTI = CodeTemplate( | ||||
|     """\ | ||||
| if (task_should_compute_output({ ${idx_ranges} })) { | ||||
| if (${needs_input_grad}) { | ||||
|   ${grad_input_mask} | ||||
|   auto grad_result = ${derivative}; | ||||
|   ${copy_ranges} | ||||
| @ -552,10 +594,16 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str | ||||
|     apply_with_saved_before: list[str] = [] | ||||
|     apply_with_saved_after: list[str] = [] | ||||
|  | ||||
|     for arg in info.args_with_derivatives: | ||||
|     unpacked_saved_vars = [] | ||||
|     unpacked_saved_vars_ref_type = [] | ||||
|  | ||||
|     for idx, arg in enumerate(info.args_with_derivatives): | ||||
|         # compute_index_ranges.append(f"auto {arg.name}_ix = {idx};") | ||||
|         if arg.type in TENSOR_LIST_LIKE_CTYPES: | ||||
|             size = f"{arg.name}_size_" | ||||
|             saved_list_sizes.append(f"size_t {arg.name}_size_;") | ||||
|             unpacked_saved_vars.append(f"{arg.name}_size_") | ||||
|             unpacked_saved_vars_ref_type.append("size_t") | ||||
|         else: | ||||
|             size = "1" | ||||
|         compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});") | ||||
| @ -567,6 +615,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str | ||||
|         should_append_raw_getsetdef = False | ||||
|         visit_name = name | ||||
|         uses_cpp_saved_variable_cls = False | ||||
|         unpacked_ref_type = None | ||||
|  | ||||
|         if ( | ||||
|             type == BaseCType(tensorT) | ||||
| @ -591,6 +640,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str | ||||
|             ) | ||||
|             should_append_raw_getsetdef = True | ||||
|             visit_name = f"{name}_" | ||||
|             unpacked_ref_type = "Tensor&" | ||||
|         elif ( | ||||
|             type == BaseCType(tensorListT) | ||||
|             or type == BaseCType(iTensorListRefT) | ||||
| @ -630,6 +680,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str | ||||
|             ) | ||||
|             should_append_raw_getsetdef = True | ||||
|             visit_name = f"{name}_" | ||||
|             unpacked_ref_type = "std::vector<Tensor>&" | ||||
|         elif type == ListCType(OptionalCType(BaseCType(tensorT))): | ||||
|             uses_cpp_saved_variable_cls = True | ||||
|             saved_variables.append(f"std::vector<SavedVariable> {name}_;") | ||||
| @ -652,6 +703,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str | ||||
|             ) | ||||
|             should_append_raw_getsetdef = True | ||||
|             visit_name = f"{name}_" | ||||
|             unpacked_ref_type = "torch::List<std::optional<Tensor>>&" | ||||
|         elif type == BaseCType(intArrayRefT): | ||||
|             saved_variables.append(f"std::vector<int64_t> {name};") | ||||
|             getter_definitions.append( | ||||
| @ -733,6 +785,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str | ||||
|             elem=BaseCType(type=BaseCppType(ns="at", name="Scalar")) | ||||
|         ): | ||||
|             saved_variables.append(f"std::vector<at::Scalar> {name};") | ||||
|             unpacked_ref_type = "std::vector<at::Scalar>&" | ||||
|             saved_variables.append(f"bool {name}_released_ = false;") | ||||
|             # Just clear() is sufficient, we don't need to loop and clear each variable. | ||||
|             # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. | ||||
| @ -803,6 +856,14 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { | ||||
|         apply_with_saved_before.append(f"saved.before({visit_name});") | ||||
|         apply_with_saved_after.append(f"saved.after({visit_name});") | ||||
|  | ||||
|         if unpacked_ref_type is None: | ||||
|             # TODO(rzou): should reformulate in terms of type, then ref | ||||
|             unpacked_ref_type = f"{saved_variables[-1].split(' ')[0]}&" | ||||
|             if unpacked_ref_type.startswith("const "): | ||||
|                 unpacked_ref_type = unpacked_ref_type[6:] | ||||
|         unpacked_saved_vars.append(name) | ||||
|         unpacked_saved_vars_ref_type.append(unpacked_ref_type) | ||||
|  | ||||
|     for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)): | ||||
|         save_var(var, is_output=False) | ||||
|     for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)): | ||||
| @ -816,6 +877,8 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { | ||||
|         thread_lock = "" | ||||
|  | ||||
|     if uses_retain_variables(info): | ||||
|         unpacked_saved_vars.append("retain_variables") | ||||
|         unpacked_saved_vars_ref_type.append("bool") | ||||
|         will_release_variables = WILL_RELEASE_VARIABLES.substitute() | ||||
|     else: | ||||
|         will_release_variables = "" | ||||
| @ -834,9 +897,11 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { | ||||
|     def emit_derivative( | ||||
|         derivative: Derivative, | ||||
|         args_with_derivatives: Sequence[Binding], | ||||
|         num_grad_inputs: int, | ||||
|     ) -> tuple[bool, str]: | ||||
|         formula = derivative.formula | ||||
|         var_names = derivative.var_names | ||||
|  | ||||
|         if len(var_names) == 1: | ||||
|             checks_any_grad_defined = False | ||||
|             if "not_implemented" not in formula: | ||||
| @ -857,35 +922,45 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { | ||||
|                 derivative_template = DERIVATIVE_SINGLE | ||||
|             return ( | ||||
|                 checks_any_grad_defined, | ||||
|                 derivative_template.substitute(name=var_names[0], derivative=formula), | ||||
|                 derivative_template.substitute( | ||||
|                     name=var_names[0], | ||||
|                     derivative=formula, | ||||
|                     idx=num_grad_inputs, | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|         else: | ||||
|             if "grad_input_mask" in formula: | ||||
|                 masks = [ | ||||
|                     f"task_should_compute_output({{ {n}_ix }})," for n in var_names | ||||
|                     f"needs_input_grad[std::get<0>({n}_ix)]," for n in var_names | ||||
|                 ] | ||||
|                 grad_input_mask = GRAD_INPUT_MASK.substitute( | ||||
|                     masks=masks, n=len(var_names) | ||||
|                     n=len(var_names), | ||||
|                     masks=masks | ||||
|                 ) | ||||
|             else: | ||||
|                 grad_input_mask = "" | ||||
|             idx_ranges = ", ".join(f"{n}_ix" for n in var_names) | ||||
|             needs_input_grad = [f"needs_input_grad[std::get<0>({var_names[i]}_ix)]" for i in range(len(var_names))] | ||||
|             needs_input_grad = " || ".join(needs_input_grad) | ||||
|             # idx_ranges = ", ".join(f"{n}_ix" for n in var_names) | ||||
|             copy_ranges: list[str] = [] | ||||
|             for i, n in enumerate(var_names): | ||||
|                 copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i)) | ||||
|                 copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i, idx=num_grad_inputs + i)) | ||||
|             return False, DERIVATIVE_MULTI.substitute( | ||||
|                 idx_ranges=idx_ranges, | ||||
|                 needs_input_grad=needs_input_grad, | ||||
|                 copy_ranges=copy_ranges, | ||||
|                 derivative=formula, | ||||
|                 grad_input_mask=grad_input_mask, | ||||
|             ) | ||||
|  | ||||
|     body.extend(unpack) | ||||
|     num_grad_inputs = 0 | ||||
|  | ||||
|     need_any_grad_defined_var = False | ||||
|     for derivative in info.derivatives: | ||||
|     for idx, derivative in enumerate(info.derivatives): | ||||
|         checks_any_grad_defined, derivative_text = emit_derivative( | ||||
|             derivative, info.args_with_derivatives | ||||
|             derivative, info.args_with_derivatives, num_grad_inputs | ||||
|         ) | ||||
|         num_grad_inputs += len(derivative.var_names) | ||||
|         body.append(derivative_text) | ||||
|         need_any_grad_defined_var |= checks_any_grad_defined | ||||
|     # Since single-output derivative formulas need to check if grads are | ||||
| @ -896,6 +971,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { | ||||
|             "bool any_grad_defined = any_variable_defined(grads);", | ||||
|         ) | ||||
|  | ||||
|  | ||||
|     if info.name in UNTRACEABLE_FUNCTIONS: | ||||
|         superclass = "Node" | ||||
|     else: | ||||
| @ -906,8 +982,41 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { | ||||
|     ) | ||||
|     all_getter_definitions = "\n".join(getter_definitions) | ||||
|  | ||||
|     get_state = "\n".join( | ||||
|         f"saved_state.enqueue({name});" | ||||
|         for name in unpacked_saved_vars | ||||
|     ) | ||||
|     saved_var_dequeues = [] | ||||
|     for typ, name in zip(unpacked_saved_vars_ref_type, unpacked_saved_vars): | ||||
|         if typ.endswith("&"): | ||||
|             typ = typ[:-1] | ||||
|         saved_var_dequeues.append(f"{typ} {name};") | ||||
|         saved_var_dequeues.append(f"state.dequeue({name});") | ||||
|  | ||||
|     masks = [ | ||||
|         f"task_should_compute_output({n})," for n in range(num_grad_inputs) | ||||
|     ] | ||||
|     compute_needs_input_grad = COMPUTE_NEEDS_INPUT_GRAD.substitute( | ||||
|         ix_ranges="", | ||||
|         n=num_grad_inputs, | ||||
|         masks=masks); | ||||
|     if len(unpacked_saved_vars) > 0: | ||||
|         unpacked_saved_vars_signature = ", " + ",".join(f"{T} {x}" for T, x in zip(unpacked_saved_vars_ref_type, unpacked_saved_vars)) | ||||
|     else: | ||||
|         unpacked_saved_vars_signature = "" | ||||
|     if len(unpacked_saved_vars) > 0: | ||||
|         unpacked_saved_vars = ", " + ", ".join(unpacked_saved_vars) | ||||
|     else: | ||||
|         unpacked_saved_vars = "" | ||||
|  | ||||
|     return template.substitute( | ||||
|         unpacks="\n".join(unpack), | ||||
|         op=info.op, | ||||
|         saved_var_dequeues="\n".join(saved_var_dequeues), | ||||
|         unpacked_saved_vars=unpacked_saved_vars, | ||||
|         unpacked_saved_vars_signature=unpacked_saved_vars_signature, | ||||
|         compute_needs_input_grad=compute_needs_input_grad, | ||||
|         num_vars=num_grad_inputs, | ||||
|         compute_index_ranges=compute_index_ranges, | ||||
|         saved_variables=saved_variables, | ||||
|         release_variables=release_variables, | ||||
| @ -922,4 +1031,5 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { | ||||
|         compiled_args=compiled_args, | ||||
|         apply_with_saved_before=apply_with_saved_before, | ||||
|         apply_with_saved_after=apply_with_saved_after, | ||||
|         get_state=get_state, | ||||
|     ) | ||||
|  | ||||
| @ -26,6 +26,7 @@ from torch.fx.experimental.proxy_tensor import ( | ||||
|     PythonKeyTracer, | ||||
|     track_tensor_tree, | ||||
| ) | ||||
| import torch.utils._pytree as pytree | ||||
| from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv | ||||
| from torch.fx.traceback import preserve_node_meta, set_stack_trace | ||||
| from torch.utils._traceback import CapturedTraceback | ||||
| @ -54,6 +55,35 @@ def maybe_clone(x): | ||||
|         return clone_preserve_strides(x) | ||||
|     return x | ||||
|  | ||||
| counter = 0 | ||||
|  | ||||
| class OpNamespace: | ||||
|     def __init__(self): | ||||
|         self.next_id = {} | ||||
|  | ||||
|     def add(self, base_name, fn): | ||||
|         if base_name not in self.next_id: | ||||
|             self.next_id[base_name] = 0 | ||||
|         nid = self.next_id[base_name] | ||||
|         name = f"{base_name}_{nid}" | ||||
|         self.next_id[base_name] += 1 | ||||
|         result = Op(name, fn) | ||||
|         torch._dynamo.allow_in_graph(result) | ||||
|         setattr(self, name, result) | ||||
|         return result | ||||
|  | ||||
| class Op: | ||||
|     def __init__(self, name, fn): | ||||
|         self.fn = fn | ||||
|         self.__name__ = name | ||||
|         self.__module__ = "torch._dynamo.compiled_autograd.ops" | ||||
|  | ||||
|     def __call__(self, *args, **kwargs): | ||||
|         return self.fn(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| ops = OpNamespace() | ||||
|  | ||||
|  | ||||
| class AutogradCompilerInstance: | ||||
|     def __init__(self, compiler_fn) -> None: | ||||
| @ -70,6 +100,7 @@ class AutogradCompilerInstance: | ||||
|         self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic") | ||||
|         self.hooks_proxy: Optional[Proxy] = None | ||||
|         self.graph_placeholders = ["inputs", "sizes", "scalars", "hooks"] | ||||
|         self.old_inline_behavior = True | ||||
|  | ||||
|     def wrap_fake(self, x, source): | ||||
|         assert isinstance(x, torch.Tensor) | ||||
| @ -187,6 +218,55 @@ class AutogradCompilerInstance: | ||||
|             self.bind_tensors_to_proxies(grad_ins, proxies) | ||||
|         return tuple(grad_ins) | ||||
|  | ||||
|     def allocate_dummy(self, *examples): | ||||
|         with disable_proxy_modes_tracing(): | ||||
|             return torch.zeros(0) | ||||
|  | ||||
|     def apply_functional(self, fn, inputs, stack, num_outputs, debug_name): | ||||
|         if self.old_inline_behavior: | ||||
|             result = fn(inputs, *stack) | ||||
|             return result | ||||
|         # TODO: if the node is a python autograd.Function or a CompiledFunctionBackward, | ||||
|         # we should probably "plop" the subgraph into the graph instead | ||||
|         # of allow_in_graph the node through Dynamo. | ||||
|         proxy_inputs, proxy_stack = pytree.tree_map(lambda t: self.to_proxy(t) if isinstance(t, torch.Tensor) else t,  (inputs, stack)) | ||||
|         op = ops.add(debug_name, fn) | ||||
|         proxy_out = self.fx_tracer.create_proxy( | ||||
|             "call_function", | ||||
|             op, | ||||
|             args=(proxy_inputs, *proxy_stack), | ||||
|             kwargs={}) | ||||
|         result = [self.allocate_dummy(*inputs, *stack) for _ in range(num_outputs)] | ||||
|         self.bind_tensors_to_proxies(result, [proxy_out[i] for i in range(num_outputs)]) | ||||
|         return result | ||||
|  | ||||
|     def validate_outputs(self, fn, outputs, stack, _0, _1): | ||||
|         if self.old_inline_behavior: | ||||
|             return fn(outputs, *stack) | ||||
|         proxy_outputs, proxy_stack = pytree.tree_map(lambda t: self.to_proxy(t) if isinstance(t, torch.Tensor) else t, (outputs, stack)) | ||||
|         op = ops.add("validate_outputs", fn) | ||||
|         new_proxy_outputs = self.fx_tracer.create_proxy( | ||||
|             "call_function", | ||||
|             op, | ||||
|             args=(proxy_outputs, *proxy_stack), | ||||
|             kwargs={}) | ||||
|         self.bind_tensors_to_proxies(outputs, new_proxy_outputs) | ||||
|         return outputs | ||||
|  | ||||
|     def accumulate(self, old_var, new_var): | ||||
|         if self.old_inline_behavior: | ||||
|             return torch.add(old_var, new_var) | ||||
|         old_var_proxy = self.to_proxy(old_var) | ||||
|         new_var_proxy = self.to_proxy(new_var) | ||||
|         proxy_out = self.fx_tracer.create_proxy( | ||||
|             "call_function", | ||||
|             torch.add, | ||||
|             args=(old_var_proxy, new_var_proxy), | ||||
|             kwargs={}) | ||||
|         result = self.allocate_dummy(old_var) | ||||
|         self.bind_tensors_to_proxies([result], [proxy_out]) | ||||
|         return result | ||||
|  | ||||
|     def proxy_call_hook(self, hook, *args, **kwargs): | ||||
|         return self.fx_tracer.create_proxy( | ||||
|             "call_function", | ||||
| @ -710,8 +790,6 @@ class AutogradCompilerInstance: | ||||
|             return [self.to_proxy(x) for x in t] | ||||
|         if isinstance(t, tuple): | ||||
|             return tuple(self.to_proxy(x) for x in t) | ||||
|         # can it be torch.SymInt as the code used to imply? | ||||
|         assert isinstance(t, torch.Tensor) | ||||
|         proxy_tensor = fetch_object_proxy(self.fx_tracer, t) | ||||
|         assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor) | ||||
|         return proxy_tensor.proxy | ||||
|  | ||||
| @ -3273,6 +3273,7 @@ if torch.distributed.is_available(): | ||||
| MOD_INLINELIST = [ | ||||
|     "torch._decomp", | ||||
|     "torch._dynamo._trace_wrapped_higher_order_op", | ||||
|     "torch._dynamo.compiled_autograd.ops", | ||||
|     "torch._dynamo.comptime", | ||||
|     "torch._dynamo.polyfills", | ||||
|     "torch._functorch.autograd_function", | ||||
|  | ||||
| @ -530,9 +530,16 @@ variable_list AutogradContext::get_saved_variables() const { | ||||
|   variable_list saved; | ||||
|   saved.reserve(saved_variables_.size()); | ||||
|   auto ptr = grad_fn_.lock(); | ||||
|   TORCH_INTERNAL_ASSERT(ptr); | ||||
|   for (auto& var : saved_variables_) { | ||||
|     saved.push_back(var.unpack(ptr)); | ||||
|   // TORCH_INTERNAL_ASSERT(ptr); | ||||
|   // TODO(rzou): hacky, can do this in a more legit way | ||||
|   if (ptr) { | ||||
|     for (auto& var : saved_variables_) { | ||||
|       saved.push_back(var.unpack(ptr)); | ||||
|     } | ||||
|   } else { | ||||
|     for (auto& var : saved_variables_) { | ||||
|       saved.push_back(var.unpack()); | ||||
|     } | ||||
|   } | ||||
|   return saved; | ||||
| } | ||||
| @ -543,6 +550,7 @@ bool AutogradContext::needs_input_grad(size_t output_edge_index) const { | ||||
|   return ptr->task_should_compute_output(output_edge_index); | ||||
| } | ||||
|  | ||||
| // TODO(rzou): might segfault, need to make this functional | ||||
| bool AutogradContext::needs_input_grad( | ||||
|     std::initializer_list<IndexRange> idxs) const { | ||||
|   auto ptr = grad_fn_.lock(); | ||||
|  | ||||
| @ -241,6 +241,111 @@ struct CppNode : public Node { | ||||
|     saved.after(output_info_); | ||||
|     return results; | ||||
|   } | ||||
|  | ||||
|   functional_apply_t get_functional() override { | ||||
|     auto name = this->name(); | ||||
|  | ||||
|     // TODO(rzou): probably need to pre compute needs_input_grad | ||||
|     return [name](const variable_list& inputs, const std::vector<c10::IValue>& saved) { | ||||
|       SavedState state; | ||||
|       state.stack = saved; | ||||
|       auto ctx = AutogradContext(); | ||||
|       std::vector<VariableInfo> output_info; | ||||
|       std::vector<bool> is_variable_input; | ||||
|       state.dequeue(ctx.saved_data); | ||||
|       state.dequeue(ctx.saved_variables_); | ||||
|       state.dequeue(ctx.materialize_grads_); | ||||
|       state.dequeue(output_info); | ||||
|       state.dequeue(is_variable_input); | ||||
|  | ||||
|       // TODO(rzou): refactor to share code with CppNode<T>::apply | ||||
|       at::OptionalDeviceGuard _device_guard; | ||||
|       auto num_inputs = inputs.size(); | ||||
|       variable_list backward_inputs; | ||||
|       backward_inputs.reserve(num_inputs); | ||||
|       for (const auto i : c10::irange(num_inputs)) { | ||||
|         if (inputs[i].defined() || !ctx.materialize_grads_) { | ||||
|           backward_inputs.emplace_back(inputs[i]); | ||||
|         } else { | ||||
|           backward_inputs.emplace_back(output_info[i].zeros(_device_guard)); | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       auto outputs = T::backward(&ctx, inputs); | ||||
|  | ||||
|       const auto num_forward_inputs = | ||||
|           static_cast<int64_t>(is_variable_input.size()); | ||||
|       auto num_outputs = static_cast<int64_t>(outputs.size()); | ||||
|       // Returning too many results is ok, but only as long as they're all | ||||
|       // undefined. Truncate the result vector in that case. | ||||
|       if (num_outputs > num_forward_inputs) { | ||||
|         bool all_undef = true; | ||||
|         for (const auto i : c10::irange(num_forward_inputs, num_outputs)) { | ||||
|           all_undef &= (!outputs[i].defined()); | ||||
|         } | ||||
|         if (all_undef) { | ||||
|           outputs.resize(num_forward_inputs); | ||||
|           num_outputs = num_forward_inputs; | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       if (num_outputs != num_forward_inputs) { | ||||
|         std::string msg("function "); | ||||
|         msg += name + " returned an incorrect number of gradients (expected "; | ||||
|         msg += std::to_string(num_forward_inputs) + ", got "; | ||||
|         msg += std::to_string(num_outputs) + ")"; | ||||
|         throw std::runtime_error(msg); | ||||
|       } | ||||
|  | ||||
|       variable_list results; | ||||
|       results.reserve(num_outputs); | ||||
|       for (const auto i : c10::irange(num_outputs)) { | ||||
|         if (!is_variable_input[i]) { | ||||
|           if (outputs[i].defined()) { | ||||
|             std::string msg("function "); | ||||
|             msg += name + | ||||
|                 " returned a gradient different that is defined at position "; | ||||
|             msg += std::to_string(i + 1) + | ||||
|                 ", std the corresponding forward input was not a Variable"; | ||||
|             throw std::runtime_error(msg); | ||||
|           } | ||||
|           continue; | ||||
|         } | ||||
|         results.emplace_back(outputs[i]); | ||||
|       } | ||||
|       return results; | ||||
|     }; | ||||
|   } | ||||
|   ivalue_list retrieve_saved(SwapSavedVariables& saved) override { | ||||
|     saved.before(ctx_.saved_data); | ||||
|     TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty()); | ||||
|     TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty()); | ||||
|     saved.before(ctx_.saved_variables_); | ||||
|     TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty()); | ||||
|     saved.before(ctx_.materialize_grads_); | ||||
|     saved.before(ctx_.has_freed_buffers_); | ||||
|     saved.before(input_info_); | ||||
|     saved.before(output_info_); | ||||
|  | ||||
|     SavedState state; | ||||
|     state.enqueue(ctx_.saved_data); | ||||
|     state.enqueue(ctx_.saved_variables_, shared_from_this()); | ||||
|     state.enqueue(ctx_.materialize_grads_); | ||||
|     state.enqueue(output_info_); | ||||
|     state.enqueue(is_variable_input_); | ||||
|  | ||||
|     saved.after(ctx_.saved_data); | ||||
|     TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty()); | ||||
|     TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty()); | ||||
|     saved.after(ctx_.saved_variables_); | ||||
|     TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty()); | ||||
|     saved.after(ctx_.materialize_grads_); | ||||
|     saved.after(ctx_.has_freed_buffers_); | ||||
|     saved.after(input_info_); | ||||
|     saved.after(output_info_); | ||||
|  | ||||
|     return state.stack; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| struct ExtractVariables : IterArgs<ExtractVariables> { | ||||
|  | ||||
| @ -859,18 +859,38 @@ void validate_outputs( | ||||
|     const edge_list& edges, | ||||
|     variable_list& grads, | ||||
|     const std::function<std::string(const std::string&)>& format_error) { | ||||
|   if (grads.size() != edges.size()) { | ||||
|   // TODO(rzou): probably too many heap allocations here... | ||||
|   auto input_metadata = collect_input_metadata(edges); | ||||
|   validate_outputs(input_metadata, grads, format_error); | ||||
| } | ||||
|  | ||||
| std::vector<c10::optional<InputMetadata>> collect_input_metadata(const edge_list& edges) { | ||||
|   std::vector<c10::optional<InputMetadata>> input_metadata; | ||||
|   for (const auto& edge : edges) { | ||||
|     if (!edge.is_valid()) { | ||||
|       input_metadata.emplace_back(c10::nullopt); | ||||
|       continue; | ||||
|     } | ||||
|     input_metadata.emplace_back(edge.function->input_metadata(edge.input_nr)); | ||||
|   } | ||||
|   return input_metadata; | ||||
| } | ||||
|  | ||||
| void validate_outputs( | ||||
|     const std::vector<c10::optional<InputMetadata>>& input_metadata, | ||||
|     variable_list& grads, | ||||
|     const std::function<std::string(const std::string&)>& format_error) { | ||||
|   if (grads.size() != input_metadata.size()) { | ||||
|     std::stringstream ss; | ||||
|     ss << "invalid number of gradients - expected "; | ||||
|     ss << edges.size() << ", but got " << grads.size(); | ||||
|     ss << input_metadata.size() << ", but got " << grads.size(); | ||||
|     TORCH_CHECK(false, format_error(ss.str())); | ||||
|   } | ||||
|   for (const auto i : c10::irange(grads.size())) { | ||||
|     const auto& edge = edges[i]; | ||||
|     if (!edge.is_valid()) | ||||
|     if (!input_metadata[i].has_value()) { | ||||
|       continue; | ||||
|  | ||||
|     const auto& metadata = edge.function->input_metadata(edge.input_nr); | ||||
|     } | ||||
|     const auto& metadata = input_metadata[i].value(); | ||||
|     auto& grad = grads[i]; | ||||
|     if (!grad.defined()) { | ||||
|       // FIXME: TestJit.test_ge_optimized fails this assertion. | ||||
|  | ||||
| @ -43,6 +43,11 @@ TORCH_API void validate_outputs( | ||||
|     const edge_list& edges, | ||||
|     variable_list& grads, | ||||
|     const std::function<std::string(const std::string&)>& format_error); | ||||
| TORCH_API void validate_outputs( | ||||
|     const std::vector<c10::optional<InputMetadata>>& input_metadata, | ||||
|     variable_list& grads, | ||||
|     const std::function<std::string(const std::string&)>& format_error); | ||||
| TORCH_API std::vector<c10::optional<InputMetadata>> collect_input_metadata(const edge_list& edges); | ||||
|  | ||||
| struct NodeTask { | ||||
|   std::weak_ptr<GraphTask> base_; | ||||
|  | ||||
| @ -34,8 +34,12 @@ using tensor_list = std::vector<at::Tensor>; | ||||
| using variable_list = std::vector<Variable>; | ||||
| using edge_list = std::vector<Edge>; | ||||
| using saved_variable_list = std::vector<SavedVariable>; | ||||
| using ivalue_list = std::vector<c10::IValue>; | ||||
| using functional_apply_t = std::function< | ||||
|     variable_list(const variable_list&, const std::vector<c10::IValue>&)>; | ||||
| using IndexRange = std::pair<size_t, size_t>; | ||||
| using torch::dynamo::autograd::CompiledNodeArgs; | ||||
| using torch::dynamo::autograd::SavedState; | ||||
| using torch::dynamo::autograd::SwapSavedVariables; | ||||
|  | ||||
| // Custom deleter to prevent stack overflows. | ||||
| @ -604,6 +608,17 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> { | ||||
|         std::string("apply_with_saved not implemented: ") + name()); | ||||
|   } | ||||
|  | ||||
|   virtual ivalue_list retrieve_saved(SwapSavedVariables& saved) { | ||||
|     throw std::runtime_error( | ||||
|         std::string("retrieve_saved not implemented: ") + name()); | ||||
|   } | ||||
|   virtual std::function< | ||||
|       variable_list(const variable_list&, const std::vector<c10::IValue>&)> | ||||
|   get_functional() { | ||||
|     throw std::runtime_error( | ||||
|         std::string("get_functional not implemented: ") + name()); | ||||
|   } | ||||
|  | ||||
|  protected: | ||||
|   /// Performs the `Node`'s actual operation. | ||||
|   virtual variable_list apply(variable_list&& inputs) = 0; | ||||
|  | ||||
| @ -8,6 +8,7 @@ | ||||
| namespace torch::dynamo::autograd { | ||||
| class CompiledNodeArgs; | ||||
| class SwapSavedVariables; | ||||
| struct SavedState; | ||||
| } // namespace torch::dynamo::autograd | ||||
|  | ||||
| // A hook that's called on gradients | ||||
|  | ||||
| @ -103,4 +103,40 @@ variable_list AccumulateGrad::apply_with_saved( | ||||
|   return variable_list(); | ||||
| } | ||||
|  | ||||
| ivalue_list AccumulateGrad::retrieve_saved(SwapSavedVariables& saved) { | ||||
|   auto should_visit = variable.defined() && variable.requires_grad(); | ||||
|   if (should_visit) { | ||||
|     saved.before(variable); | ||||
|   } | ||||
|  | ||||
|   SavedState state; | ||||
|   state.enqueue(variable); | ||||
|  | ||||
|   if (should_visit) { | ||||
|     saved.after(variable); | ||||
|   } | ||||
|  | ||||
|   return state.stack; | ||||
| } | ||||
|  | ||||
| functional_apply_t AccumulateGrad::get_functional() { | ||||
|   return [](const variable_list& inputs, | ||||
|             const std::vector<c10::IValue>& saved) -> variable_list { | ||||
|     SavedState state; | ||||
|     state.stack = saved; | ||||
|     Variable foo; | ||||
|     state.dequeue(foo); | ||||
|     if (!(foo.defined() && foo.requires_grad()) || !inputs[0].defined()) { | ||||
|       return variable_list(); | ||||
|     } | ||||
|     // op is intentionally static | ||||
|     static auto op = c10::Dispatcher::singleton() | ||||
|                          .findSchemaOrThrow("inductor::accumulate_grad_", "") | ||||
|                          .typed<void(const at::Tensor&, const at::Tensor&)>(); | ||||
|     op.call(foo, inputs[0]); | ||||
|     // TODO(rzou): tensor_post_acc_grad_hooks | ||||
|     return variable_list(); | ||||
|   }; | ||||
| } | ||||
|  | ||||
| } // namespace torch::autograd | ||||
|  | ||||
| @ -267,6 +267,9 @@ struct TORCH_API AccumulateGrad : public Node { | ||||
|       const variable_list& inputs, | ||||
|       SwapSavedVariables& saved) override; | ||||
|  | ||||
|   ivalue_list retrieve_saved(SwapSavedVariables& saved) override; | ||||
|   functional_apply_t get_functional() override; | ||||
|  | ||||
|   Variable variable; | ||||
| }; | ||||
|  | ||||
|  | ||||
| @ -77,5 +77,22 @@ variable_list GraphRoot::apply_with_saved( | ||||
|   saved.after(outputs); | ||||
|   return result; | ||||
| } | ||||
| ivalue_list GraphRoot::retrieve_saved(SwapSavedVariables& saved) { | ||||
|   saved.before(outputs); | ||||
|   SavedState state; | ||||
|   state.enqueue(outputs); | ||||
|   saved.after(outputs); | ||||
|   return state.stack; | ||||
| } | ||||
| functional_apply_t GraphRoot::get_functional() { | ||||
|   return [](const variable_list& inputs, | ||||
|             const std::vector<c10::IValue>& saved) -> variable_list { | ||||
|     SavedState state; | ||||
|     state.stack = saved; | ||||
|     variable_list outputs; | ||||
|     state.dequeue(outputs); | ||||
|     return outputs; | ||||
|   }; | ||||
| } | ||||
|  | ||||
| } // namespace torch::autograd | ||||
|  | ||||
| @ -97,6 +97,8 @@ struct TORCH_API GraphRoot : public Node { | ||||
|   variable_list apply_with_saved( | ||||
|       const variable_list& inputs, | ||||
|       SwapSavedVariables& saved) override; | ||||
|   ivalue_list retrieve_saved(SwapSavedVariables& saved) override; | ||||
|   functional_apply_t get_functional() override; | ||||
|  | ||||
|   variable_list outputs; | ||||
| }; | ||||
|  | ||||
| @ -103,7 +103,7 @@ struct TORCH_API InputMetadata { | ||||
|   bool maybe_expandable_to(const at::Tensor& grad) const; | ||||
|  | ||||
|   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) | ||||
|   const at::TensorOptions options_; | ||||
|   at::TensorOptions options_; | ||||
|   MetadataShape shape_; | ||||
|   c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device()); | ||||
|   bool is_tensor_subclass_ = false; | ||||
|  | ||||
| @ -25,6 +25,7 @@ | ||||
| #include <torch/csrc/autograd/saved_variable.h> | ||||
| #include <torch/csrc/autograd/utils/wrap_outputs.h> | ||||
| #include <torch/csrc/dynamo/compiled_autograd.h> | ||||
| #include <torch/csrc/dynamo/python_compiled_autograd.h> | ||||
| #include <torch/csrc/jit/frontend/tracer.h> | ||||
| #include <torch/csrc/jit/ir/ir.h> | ||||
| #include <torch/csrc/jit/python/pybind_utils.h> | ||||
| @ -396,6 +397,99 @@ variable_list PyNode::apply_with_saved( | ||||
|   return result; | ||||
| } | ||||
|  | ||||
| ivalue_list PyNode::retrieve_saved(SwapSavedVariables& saved) { | ||||
|   auto f = (THPFunction*)obj; | ||||
|   saved.before(f->compiled_autograd_symints); | ||||
|   saved.before(f->saved_variables); | ||||
|   saved.before(f->needs_input_grad); | ||||
|   saved.before(f->materialize_non_diff_grads); | ||||
|   saved.before(f->output_info); | ||||
|   saved.before(f->input_info); | ||||
|  | ||||
|   SavedState state; | ||||
|   state.enqueue(f->compiled_autograd_symints); | ||||
|   state.enqueue(f->saved_variables, shared_from_this()); | ||||
|   // state.enqueue(f->needs_input_grad); | ||||
|   // state.enqueue(f->materialize_non_diff_grads); | ||||
|   // state.enqueue(f->output_info); | ||||
|   // state.enqueue(f->input_info); | ||||
|  | ||||
|   saved.after(f->compiled_autograd_symints); | ||||
|   saved.after(f->saved_variables); | ||||
|   saved.after(f->needs_input_grad); | ||||
|   saved.after(f->materialize_non_diff_grads); | ||||
|   saved.after(f->output_info); | ||||
|   saved.after(f->input_info); | ||||
|  | ||||
|   state.enqueue(f->compiled_autograd_symints); | ||||
|   state.enqueue(f->saved_variables, shared_from_this()); | ||||
|   // state.enqueue(f->needs_input_grad); | ||||
|   // state.enqueue(f->materialize_non_diff_grads); | ||||
|   // state.enqueue(f->output_info); | ||||
|   // state.enqueue(f->input_info); | ||||
|  | ||||
|   return state.stack; | ||||
| } | ||||
|  | ||||
| // TODO(rzou): compiled autograd needs special handling of the following. | ||||
| std::function< | ||||
|     variable_list(const variable_list&, const std::vector<c10::IValue>&)> | ||||
| PyNode::get_functional() { | ||||
|   auto node = std::static_pointer_cast<PyNode>(shared_from_this()); | ||||
|   // TODO(rzou): probably need to pre compute needs_input_grad | ||||
|   return | ||||
|       [node]( | ||||
|           const variable_list& inputs, const std::vector<c10::IValue>& saved) { | ||||
|         SavedState state; | ||||
|         state.stack = saved; | ||||
|  | ||||
|         auto f = (THPFunction*)node->obj; | ||||
|  | ||||
|         state.dequeue(f->compiled_autograd_symints); | ||||
|         state.dequeue(f->saved_variables); | ||||
|         // state.dequeue(f->needs_input_grad); | ||||
|         // state.dequeue(f->materialize_non_diff_grads); | ||||
|         // state.dequeue(f->output_info); | ||||
|         // state.dequeue(f->input_info); | ||||
|  | ||||
|         f->compiled_autograd_tracing = true; | ||||
|         variable_list result; | ||||
|         if (!node->compiled_autograd_should_lift()) { | ||||
|           if (node->_backward_state_idx.has_value()) { | ||||
|             PyObject* r = PyObject_CallMethod( | ||||
|                 torch::dynamo::autograd::current_py_compiler(), | ||||
|                 "bind_backward_state", | ||||
|                 "i", | ||||
|                 *node->_backward_state_idx); | ||||
|             if (r == nullptr) { | ||||
|               throw python_error(); | ||||
|             } | ||||
|             THPObjectPtr prior(f->compiled_autograd_backward_state); | ||||
|             f->compiled_autograd_backward_state = r; | ||||
|             result = node->apply(variable_list(inputs)); | ||||
|             Py_CLEAR(f->compiled_autograd_backward_state); | ||||
|             f->compiled_autograd_backward_state = prior.release(); | ||||
|           } else { | ||||
|             result = node->apply(variable_list(inputs)); | ||||
|           } | ||||
|         } else { | ||||
|           result = node->defer_to_dynamo( | ||||
|               variable_list(inputs), | ||||
|               torch::dynamo::autograd::current_py_compiler()); | ||||
|         } | ||||
|         f->compiled_autograd_tracing = false; | ||||
|  | ||||
|         state.dequeue(f->compiled_autograd_symints); | ||||
|         state.dequeue(f->saved_variables); | ||||
|         // state.dequeue(f->needs_input_grad); | ||||
|         // state.dequeue(f->materialize_non_diff_grads); | ||||
|         // state.dequeue(f->output_info); | ||||
|         // state.dequeue(f->input_info); | ||||
|  | ||||
|         return result; | ||||
|       }; | ||||
| } | ||||
|  | ||||
| PyObject* PyNode::to_py_args( | ||||
|     const variable_list& inputs, | ||||
|     at::OptionalDeviceGuard* device_guard) { | ||||
|  | ||||
| @ -70,6 +70,11 @@ struct PyNode : public Node { | ||||
|       Py_DECREF(obj); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   std::function< | ||||
|       variable_list(const variable_list&, const std::vector<c10::IValue>&)> | ||||
|   get_functional() override; | ||||
|   ivalue_list retrieve_saved(SwapSavedVariables& saved) override;  | ||||
| }; | ||||
|  | ||||
| /** | ||||
|  | ||||
| @ -898,6 +898,321 @@ class SwapSavedVariables { | ||||
|   StashedVars<at::IValue> stashed_ivalues; | ||||
| }; | ||||
|  | ||||
| struct SavedState { | ||||
|   std::vector<at::IValue> stack; | ||||
|   int64_t idx = 0; | ||||
|  | ||||
|   void enqueue( | ||||
|       const SavedVariable& sv, | ||||
|       const std::shared_ptr<Node>& saved_for) { | ||||
|     stack.emplace_back(sv.unpack(saved_for)); | ||||
|   } | ||||
|   void dequeue(SavedVariable& sv) { | ||||
|     sv = SavedVariable(stack[idx++].toTensor(), /*is_output*/ true); | ||||
|   } | ||||
|  | ||||
|   void enqueue( | ||||
|       const std::vector<SavedVariable>& sv, | ||||
|       const std::shared_ptr<Node>& saved_for) { | ||||
|     enqueue(static_cast<int64_t>(sv.size())); | ||||
|     for (const auto& v : sv) { | ||||
|       enqueue(v, saved_for); | ||||
|     } | ||||
|   } | ||||
|   void dequeue(std::vector<SavedVariable>& sv) { | ||||
|     int64_t size = 0; | ||||
|     dequeue(size); | ||||
|     sv.clear(); | ||||
|     for (int64_t idx = 0; idx < size; idx++) { | ||||
|       sv.emplace_back(); | ||||
|       dequeue(sv.back()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   /* | ||||
|   void enqueue(const PyObject*& t) { | ||||
|     enqueue_ivalue(t); | ||||
|   } | ||||
|   void dequeue(PyObject*& t) { | ||||
|     dequeue_ivalue(t); | ||||
|   } | ||||
|   */ | ||||
|  | ||||
|   void enqueue(const VariableInfo& t) { | ||||
|     enqueue(t.layout); | ||||
|     enqueue(t.device); | ||||
|     enqueue(t.scalar_type); | ||||
|     enqueue(t.size); | ||||
|     enqueue(t.requires_grad); | ||||
|     enqueue(t.is_empty); | ||||
|   } | ||||
|   void dequeue(VariableInfo& t) { | ||||
|     dequeue(t.layout); | ||||
|     dequeue(t.device); | ||||
|     dequeue(t.scalar_type); | ||||
|     dequeue(t.size); | ||||
|     dequeue(t.requires_grad); | ||||
|     dequeue(t.is_empty); | ||||
|   } | ||||
|  | ||||
|   void enqueue(size_t t) { | ||||
|     enqueue(static_cast<int64_t>(t)); | ||||
|   } | ||||
|   void dequeue(size_t& t) { | ||||
|     int64_t tmp = 0; | ||||
|     dequeue(tmp); | ||||
|     t = static_cast<size_t>(tmp); | ||||
|   } | ||||
|  | ||||
|   // TODO: probably wildly inefficient | ||||
|   template <class T> | ||||
|   void enqueue(const c10::List<T> t) { | ||||
|     enqueue(t.vec()); | ||||
|   } | ||||
|   template <class T> | ||||
|   void dequeue(c10::List<T>& t) { | ||||
|     std::vector<T> tmp; | ||||
|     dequeue(tmp); | ||||
|     t = c10::List<T>(tmp); | ||||
|   } | ||||
|  | ||||
|   void enqueue(const TypeAndSize& value) { | ||||
|     enqueue(value.sym_sizes); | ||||
|     enqueue(value.options); | ||||
|   } | ||||
|   void dequeue(TypeAndSize& value) { | ||||
|     dequeue(value.sym_sizes); | ||||
|     dequeue(value.options); | ||||
|   } | ||||
|  | ||||
|   void enqueue(const InputMetadata& value) { | ||||
|     enqueue(value.options()); | ||||
|     enqueue(value.shape_as_dim_vector().vec()); | ||||
|     enqueue(value.is_tensor_subclass()); | ||||
|     TORCH_INTERNAL_ASSERT(!value.is_nested_tensor()); | ||||
|   } | ||||
|   // Special case: InputMetadata has no copy ctor | ||||
|   // TODO(rzou): ?? | ||||
|   void dequeue(InputMetadata& value) { | ||||
|     at::TensorOptions options; | ||||
|     dequeue(options); | ||||
|     std::vector<at::SymInt> shape; | ||||
|     dequeue(shape); | ||||
|     bool is_tensor_subclass = false; | ||||
|     dequeue(is_tensor_subclass); | ||||
|     SymIntSmallVec sym_shape; | ||||
|     for (const auto& s : shape) { | ||||
|       sym_shape.emplace_back(s); | ||||
|     } | ||||
|     value = InputMetadata(options, sym_shape, is_tensor_subclass, false); | ||||
|   } | ||||
|  | ||||
|   void enqueue(const ska::flat_hash_map<std::string, at::IValue>& dct) { | ||||
|     std::vector<std::string> keys; | ||||
|     std::vector<at::IValue> values; | ||||
|     for (const auto& [key, value] : dct) { | ||||
|       keys.emplace_back(key); | ||||
|       values.emplace_back(value); | ||||
|     } | ||||
|     enqueue(keys); | ||||
|     enqueue(values); | ||||
|   } | ||||
|   void enqueue(const at::IValue& iv) { | ||||
|     stack.emplace_back(iv); | ||||
|   } | ||||
|   void dequeue(at::IValue& iv) { | ||||
|     iv = stack[idx++]; | ||||
|   } | ||||
|   void dequeue(ska::flat_hash_map<std::string, at::IValue>& dct) { | ||||
|     std::vector<std::string> keys; | ||||
|     std::vector<at::IValue> values; | ||||
|     dequeue(keys); | ||||
|     dequeue(values); | ||||
|     dct.clear(); | ||||
|     for (const auto i : c10::irange(keys.size())) { | ||||
|       dct.insert({keys[i], values[i]}); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void enqueue(const at::TensorOptions& value) { | ||||
|     enqueue(value.requires_grad_opt()); | ||||
|     enqueue(value.memory_format_opt()); | ||||
|     enqueue(value.device_opt()); | ||||
|     enqueue(value.dtype_opt()); | ||||
|     enqueue(value.layout_opt()); | ||||
|     enqueue(value.pinned_memory_opt()); | ||||
|   } | ||||
|   void dequeue(at::TensorOptions& value) { | ||||
|     auto result = at::TensorOptions(); | ||||
|     c10::optional<bool> requires_grad_opt; | ||||
|     dequeue(requires_grad_opt); | ||||
|     if (requires_grad_opt) { | ||||
|       result = result.requires_grad(*requires_grad_opt); | ||||
|     } | ||||
|     c10::optional<c10::MemoryFormat> memory_format_opt; | ||||
|     dequeue(memory_format_opt); | ||||
|     if (memory_format_opt) { | ||||
|       result = result.memory_format(*memory_format_opt); | ||||
|     } | ||||
|     c10::optional<c10::Device> device_opt; | ||||
|     dequeue(device_opt); | ||||
|     if (device_opt) { | ||||
|       result = result.device(*device_opt); | ||||
|     } | ||||
|     c10::optional<caffe2::TypeMeta> dtype_opt; | ||||
|     dequeue(dtype_opt); | ||||
|     if (dtype_opt) { | ||||
|       result = result.dtype(*dtype_opt); | ||||
|     } | ||||
|     c10::optional<c10::Layout> layout_opt; | ||||
|     dequeue(layout_opt); | ||||
|     if (layout_opt) { | ||||
|       result = result.layout(*layout_opt); | ||||
|     } | ||||
|     c10::optional<bool> pinned_memory_opt; | ||||
|     dequeue(pinned_memory_opt); | ||||
|     if (pinned_memory_opt) { | ||||
|       result = result.pinned_memory(*pinned_memory_opt); | ||||
|     } | ||||
|     value = result; | ||||
|   } | ||||
|  | ||||
|   void enqueue(const caffe2::TypeMeta& value) { | ||||
|     enqueue(at::typeMetaToScalarType(value)); | ||||
|   } | ||||
|   void dequeue(caffe2::TypeMeta& value) { | ||||
|     at::ScalarType result = at::kFloat; | ||||
|     dequeue(result); | ||||
|     value = caffe2::TypeMeta::fromScalarType(result); | ||||
|   } | ||||
|  | ||||
|   template <typename T> | ||||
|   void enqueue(const c10::OptionalArray<T>& t) { | ||||
|     enqueue(t.list); | ||||
|   } | ||||
|   template <typename T> | ||||
|   void dequeue(c10::OptionalArray<T>& t) { | ||||
|     dequeue(t.list); | ||||
|   } | ||||
|  | ||||
|   template <typename T> | ||||
|   void enqueue(const std::optional<T>& t) { | ||||
|     enqueue(t.has_value()); | ||||
|     if (t.has_value()) { | ||||
|       enqueue(*t); | ||||
|     } | ||||
|   } | ||||
|   template <typename T> | ||||
|   void dequeue(c10::optional<T>& value) { | ||||
|     bool has_value = false; | ||||
|     dequeue(has_value); | ||||
|     T tmp; | ||||
|     if (has_value) { | ||||
|       dequeue(tmp); | ||||
|     } | ||||
|     value = tmp; | ||||
|   } | ||||
|  | ||||
|   void enqueue(const at::TensorGeometry& t) { | ||||
|     enqueue(t.sym_sizes().vec()); | ||||
|     enqueue(t.sym_strides().vec()); | ||||
|     enqueue(t.sym_storage_offset()); | ||||
|   } | ||||
|   void dequeue(at::TensorGeometry& t) { | ||||
|     std::vector<at::SymInt> sym_sizes; | ||||
|     std::vector<at::SymInt> sym_strides; | ||||
|     at::SymInt sym_storage_offset; | ||||
|     dequeue(sym_sizes); | ||||
|     dequeue(sym_strides); | ||||
|     dequeue(sym_storage_offset); | ||||
|     t = at::TensorGeometry(sym_sizes, sym_strides, sym_storage_offset); | ||||
|   } | ||||
|  | ||||
|   template <typename T> | ||||
|   void enqueue(const std::vector<T>& t) { | ||||
|     enqueue(static_cast<int64_t>(t.size())); | ||||
|     for (const T& i : t) { | ||||
|       enqueue(i); | ||||
|     } | ||||
|   } | ||||
|   template <typename T> | ||||
|   void dequeue(std::vector<T>& t) { | ||||
|     int64_t size = 0; | ||||
|     dequeue(size); | ||||
|     t.clear(); | ||||
|     for (int64_t idx = 0; idx < size; idx++) { | ||||
|       t.emplace_back(); | ||||
|       dequeue(t.back()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void enqueue(const c10::SymInt& t) { | ||||
|     stack.emplace_back(t); | ||||
|   } | ||||
|   void dequeue(c10::SymInt& t) { | ||||
|     t = stack[idx++].toSymInt(); | ||||
|   } | ||||
|  | ||||
|   void enqueue(int64_t t) { | ||||
|     stack.emplace_back(t); | ||||
|   } | ||||
|   void dequeue(int64_t& t) { | ||||
|     t = stack[idx++].toInt(); | ||||
|   } | ||||
|  | ||||
|   void enqueue(const std::vector<c10::SymInt>& t) { | ||||
|     enqueue_ivalue(t); | ||||
|   } | ||||
|   void dequeue(std::vector<c10::SymInt>& t) { | ||||
|     t = stack[idx++].toSymIntVector(); | ||||
|   } | ||||
|  | ||||
|   void enqueue(const std::vector<int64_t>& t) { | ||||
|     enqueue_ivalue(t); | ||||
|   } | ||||
|   void dequeue(std::vector<int64_t>& t) { | ||||
|     t = stack[idx++].toIntVector(); | ||||
|   } | ||||
|  | ||||
|   template <class ivalue_t> | ||||
|   void enqueue_ivalue(const ivalue_t& t) { | ||||
|     stack.emplace_back(t); | ||||
|   } | ||||
|   template <class ivalue_t> | ||||
|   void dequeue_ivalue(ivalue_t& value) { | ||||
|     value = stack[idx++].to<ivalue_t>(); | ||||
|   } | ||||
| #define HANDLE_IVALUE(ivalue_t)                            \ | ||||
|   void enqueue(const ivalue_t& value) {                    \ | ||||
|     return enqueue_ivalue<ivalue_t>(value);                \ | ||||
|   }                                                        \ | ||||
|   void enqueue(const std::vector<ivalue_t>& value) {       \ | ||||
|     return enqueue_ivalue<std::vector<ivalue_t>>(value);   \ | ||||
|   }                                                        \ | ||||
|   void enqueue(const c10::optional<ivalue_t>& value) {     \ | ||||
|     return enqueue_ivalue<c10::optional<ivalue_t>>(value); \ | ||||
|   }                                                        \ | ||||
|   void dequeue(ivalue_t& value) {                          \ | ||||
|     return dequeue_ivalue<ivalue_t>(value);                \ | ||||
|   }                                                        \ | ||||
|   void dequeue(std::vector<ivalue_t>& value) {             \ | ||||
|     return dequeue_ivalue<std::vector<ivalue_t>>(value);   \ | ||||
|   }                                                        \ | ||||
|   void dequeue(c10::optional<ivalue_t>& value) {           \ | ||||
|     return dequeue_ivalue<c10::optional<ivalue_t>>(value); \ | ||||
|   } | ||||
|   HANDLE_IVALUE(at::Tensor) | ||||
|   HANDLE_IVALUE(c10::ScalarType) | ||||
|   HANDLE_IVALUE(c10::Scalar) | ||||
|   HANDLE_IVALUE(c10::Layout) | ||||
|   HANDLE_IVALUE(c10::Device) | ||||
|   HANDLE_IVALUE(c10::MemoryFormat) | ||||
|   HANDLE_IVALUE(bool) | ||||
|   HANDLE_IVALUE(double) | ||||
|   HANDLE_IVALUE(std::string) | ||||
| #undef HANDLE_IVALUE | ||||
| }; | ||||
|  | ||||
| } // namespace torch::dynamo::autograd | ||||
|  | ||||
| template <> | ||||
|  | ||||
| @ -52,6 +52,12 @@ Notes: | ||||
| namespace torch::dynamo::autograd { | ||||
| using c10::SymInt; | ||||
|  | ||||
| static PyObject* kPyCompiler; | ||||
|  | ||||
| PyObject* current_py_compiler() { | ||||
|   return kPyCompiler; | ||||
| } | ||||
|  | ||||
| static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) { | ||||
|   PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size())); | ||||
|   for (const auto i : c10::irange(inputs.size())) { | ||||
| @ -89,6 +95,23 @@ static void check(bool result) { | ||||
|     check(nullptr); | ||||
| } | ||||
|  | ||||
| static variable_list validate_outputs( | ||||
|     variable_list& outputs, | ||||
|     const ivalue_list& saved) { | ||||
|   SavedState r; | ||||
|   r.stack = saved; | ||||
|   std::vector<c10::optional<InputMetadata>> value; | ||||
|   r.dequeue(value); | ||||
|  | ||||
|   torch::autograd::validate_outputs( | ||||
|       value, outputs, [&](const std::string& msg) { | ||||
|         std::ostringstream ss; | ||||
|         ss << "[Compiled Autograd Tracing:]" << msg; | ||||
|         return ss.str(); | ||||
|       }); | ||||
|   return outputs; | ||||
| } | ||||
|  | ||||
| // snapshot of python verbose logging toggle | ||||
| static PyObject* python_verbose_logger = nullptr; | ||||
|  | ||||
| @ -495,6 +518,91 @@ void set_ivalue_proxies( | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename Func> | ||||
| static variable_list call_function( | ||||
|     PyObject* py_compiler, | ||||
|     const char* name, | ||||
|     Func fn, | ||||
|     const variable_list& inputs, | ||||
|     const ivalue_list& saved_state, | ||||
|     int64_t num_outputs, | ||||
|     const std::string& debug) { | ||||
|   // Need this to do PyObject* -> IValue conversion | ||||
|   std::vector<at::TypePtr> schema; | ||||
|   schema.reserve(saved_state.size()); | ||||
|   for (const auto& ivalue : saved_state) { | ||||
|     schema.emplace_back(ivalue.type()); | ||||
|   } | ||||
|  | ||||
|   // We are going to bind the following function to Python | ||||
|   auto py_func = py::cpp_function( | ||||
|       [schema, fn]( | ||||
|           std::vector<c10::optional<at::Tensor>>& inputs, | ||||
|           const py::args& args) -> py::object { | ||||
|         // It reconstructs the saved_state from args via the schema | ||||
|         std::vector<at::IValue> stack; | ||||
|         TORCH_INTERNAL_ASSERT(args.size() == schema.size()); | ||||
|         auto tuple_args = jit::tuple_slice(args); | ||||
|         for (uint64_t idx = 0; idx < schema.size(); idx++) { | ||||
|           stack.emplace_back( | ||||
|               jit::toIValue(tuple_args[idx], schema[idx], c10::nullopt)); | ||||
|         } | ||||
|         std::vector<at::Tensor> inputs_; | ||||
|         for (const auto& inp : inputs) { | ||||
|           if (inp.has_value()) { | ||||
|             inputs_.emplace_back(*inp); | ||||
|           } else { | ||||
|             inputs_.emplace_back(); | ||||
|           } | ||||
|         } | ||||
|         auto outputs = fn(inputs_, stack); | ||||
|         return jit::toPyObject(at::IValue(outputs)); | ||||
|       }); | ||||
|  | ||||
|   // convert ivalue_list -> PyObject* | ||||
|   PyObject* py_saved_state = | ||||
|       PyTuple_New(static_cast<Py_ssize_t>(schema.size())); | ||||
|   for (const auto i : c10::irange(schema.size())) { | ||||
|     py::object obj = jit::toPyObject(saved_state[i]); | ||||
|     Py_INCREF(obj.ptr()); | ||||
|     PyTuple_SET_ITEM(py_saved_state, i, obj.ptr()); | ||||
|   } | ||||
|  | ||||
|   // call the corresponding method on the py_compiler | ||||
|   // That method will figure out what to do with the function | ||||
|   // (it can either inline it or plop it straight into the FX graph). | ||||
|   py::handle handle(py_compiler); | ||||
|   py::object stuff = handle.attr(name)( | ||||
|       py_func, inputs, py::handle(py_saved_state), num_outputs, debug); | ||||
|  | ||||
|   // Convert the output from PyObject* to vector<Tensor> | ||||
|   auto tmp = py::cast<std::vector<std::optional<at::Tensor>>>(stuff); | ||||
|   variable_list outputs; | ||||
|   for (const auto& t : tmp) { | ||||
|     if (t.has_value()) { | ||||
|       outputs.emplace_back(t.value()); | ||||
|     } else { | ||||
|       outputs.emplace_back(); | ||||
|     } | ||||
|   } | ||||
|   return outputs; | ||||
| } | ||||
|  | ||||
| static at::Tensor call_accumulate( | ||||
|     PyObject* py_compiler, | ||||
|     const at::Tensor& old_var, | ||||
|     const at::Tensor& new_var) { | ||||
|   if (!old_var.defined()) { | ||||
|     return new_var; | ||||
|   } | ||||
|   if (!new_var.defined()) { | ||||
|     return old_var; | ||||
|   } | ||||
|   py::handle handle(py_compiler); | ||||
|   py::object stuff = handle.attr("accumulate")(old_var, new_var); | ||||
|   return py::cast<at::Tensor>(stuff); | ||||
| } | ||||
|  | ||||
| static TraceState call_begin_capture( | ||||
|     PyObject* self, | ||||
|     CacheNode& cache, | ||||
| @ -648,6 +756,7 @@ CacheNode* _compiled_autograd_impl( | ||||
|     // cache miss, need to capture FX graph | ||||
|     ClosingTHPObjectPtr py_compiler( | ||||
|         check(PyObject_CallNoArgs((the_autograd_compiler)))); | ||||
|     kPyCompiler = py_compiler.get(); | ||||
|  | ||||
|     TraceState state = call_begin_capture( | ||||
|         py_compiler, *cache, compiler_call, output_edges.size()); | ||||
| @ -714,17 +823,42 @@ CacheNode* _compiled_autograd_impl( | ||||
|       } | ||||
|  | ||||
|       SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call); | ||||
|       variable_list outputs = call.node->apply_with_saved(inputs, saved); | ||||
|       auto saved_state = call.node->retrieve_saved(saved); | ||||
|       // std::cout << call.node->name() << std::endl; | ||||
|       // std::cout << saved_state.size() << std::endl; | ||||
|       // for (const auto& ivalue: saved_state) { | ||||
|       //   if (ivalue.isTensor()) { | ||||
|       //     std::cout << "tensor" << std::endl; | ||||
|       //   } else { | ||||
|       //     ivalue.dump(); | ||||
|       //   } | ||||
|       // } | ||||
|       auto outputs = call_function( | ||||
|           py_compiler, | ||||
|           "apply_functional", | ||||
|           call.node->get_functional(), | ||||
|           inputs, | ||||
|           saved_state, | ||||
|           call.node->num_outputs(), | ||||
|           call.node->name()); | ||||
|       // variable_list outputs = call.node->apply_with_saved(inputs, saved); | ||||
|  | ||||
|       saved.debug_asserts(); | ||||
|       saved.before(call.node->next_edges()); | ||||
|       validate_outputs( | ||||
|           call.node->next_edges(), outputs, [&](const std::string& msg) { | ||||
|             std::ostringstream ss; | ||||
|             ss << "[Compiled Autograd Tracing: " << call.node->name() << "] " | ||||
|                << msg; | ||||
|             return ss.str(); | ||||
|           }); | ||||
|  | ||||
|       auto input_metadata = collect_input_metadata(call.node->next_edges()); | ||||
|       SavedState state; | ||||
|       state.enqueue(input_metadata); | ||||
|       ivalue_list& input_metadata_state = state.stack; | ||||
|       outputs = call_function( | ||||
|           py_compiler, | ||||
|           "validate_outputs", | ||||
|           validate_outputs, | ||||
|           outputs, | ||||
|           input_metadata_state, | ||||
|           outputs.size(), | ||||
|           "validate_outputs"); | ||||
|  | ||||
|       saved.after(call.node->next_edges()); | ||||
|       saved.debug_asserts(); | ||||
|  | ||||
| @ -746,13 +880,14 @@ CacheNode* _compiled_autograd_impl( | ||||
|         auto& output = outputs[i]; | ||||
|         const auto& next = call.node->next_edge(i); | ||||
|         if (next.is_valid() && output.defined()) { | ||||
|           input_buffers.lookup(next.function.get()) | ||||
|               .add( | ||||
|                   next.input_nr, std::move(output), std::nullopt, std::nullopt); | ||||
|           auto& buffer = input_buffers.lookup(next.function.get()); | ||||
|           buffer.buffer[next.input_nr] = call_accumulate( | ||||
|               py_compiler, buffer.buffer[next.input_nr], output); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     kPyCompiler = nullptr; | ||||
|     PyObject* res = check(call_end_capture(py_compiler, state.outputs)); | ||||
|     TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple"); | ||||
|     TORCH_CHECK( | ||||
|  | ||||
| @ -4,4 +4,5 @@ | ||||
| // see [Note: Compiled Autograd] | ||||
| namespace torch::dynamo::autograd { | ||||
| PyObject* torch_c_dynamo_compiled_autograd_init(); | ||||
| PyObject* current_py_compiler(); | ||||
| } // namespace torch::dynamo::autograd | ||||
|  | ||||
| @ -369,8 +369,18 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional<int32_t> N) { | ||||
|           } | ||||
|         case TypeKind::BoolType: | ||||
|           return IValue(py::cast<std::vector<bool>>(obj)); | ||||
|         case TypeKind::TensorType: | ||||
|           return IValue(py::cast<std::vector<at::Tensor>>(obj)); | ||||
|         case TypeKind::TensorType: { | ||||
|           auto thing = py::cast<std::vector<std::optional<at::Tensor>>>(obj); | ||||
|           auto thing2 = std::vector<at::Tensor>(); | ||||
|           for (const auto& inp : thing) { | ||||
|             if (inp.has_value()) { | ||||
|               thing2.emplace_back(*inp); | ||||
|             } else { | ||||
|               thing2.emplace_back(); | ||||
|             } | ||||
|           } | ||||
|           return IValue(thing2); | ||||
|         } | ||||
|         default: | ||||
|           return createGenericList(obj, elem_type); | ||||
|       } | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	