Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-26 16:44:54 +08:00)

Compare commits: ciflow/ind...python_com (1 commit)
Commit e5937dc68c

r2.py (new file, 140 lines)

@@ -0,0 +1,140 @@
# type: ignore
import torch
import torch.utils.cpp_extension

def compiler_fn(gm):
    # return gm
    return torch.compile(gm, backend="eager", fullgraph=False)

# ===========================================================
# Basic test with a hook that has side effects


# Test case 1: a hook
x = torch.tensor([1., 2., 3.], requires_grad=True)
y = x ** 2
z = y.sum()

im_grad = []

def hook(grad):
    im_grad.append(grad)
    return 2 * grad

y.register_hook(hook)

with torch._dynamo.compiled_autograd.enable(compiler_fn):
    z.backward()

assert torch.allclose(x.grad, 4 * x)
assert torch.allclose(im_grad[0], torch.ones_like(y))

# ===========================================================
# Unsupported C++ autograd node should graph break.
# This is better than the current compiled autograd behavior of "error out"
# and brings us a step closer to having "compiled autograd on by default".
# In theory we can also add a config that automatically treats
# it as an opaque callable, but such a config is unsound.

cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
  static constexpr bool is_traceable = false;
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& x) {
    return x;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext *ctx,
      torch::autograd::variable_list grad_output) {
    // not traceable
    *grad_output[0].data_ptr<float>() = 3.14;
    return grad_output;
  }
};
torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
  return CustomOpAutogradFunction::apply(x);
}
TORCH_LIBRARY_FRAGMENT(mylib, m) {
    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
}
"""

module = torch.utils.cpp_extension.load_inline(
    name="mylib",
    cpp_sources=cpp_source,
    functions="custom_op_backed_by_autograd_fn",
    verbose=True,
)

x = torch.ones(2, 2, requires_grad=True)
out = torch.ops.mylib.custom_op_backed_by_autograd_fn(
    x
)
loss = out.sum()
with torch._dynamo.compiled_autograd.enable(compiler_fn):
    loss.backward()

expected = torch.ones_like(x) * 3.14
assert torch.allclose(x.grad, expected)

# ===========================================================
# Tests that we don't bake in "guessed" metadata.
# This test case would have errored out in the previous
# compiled autograd.

import torch
import torch.utils.cpp_extension
cpp_source2 = """
struct CustomOpAutogradFunction2 : public torch::autograd::Function<CustomOpAutogradFunction2> {
  static constexpr bool is_traceable = true;
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& x) {
    return x;
  }
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext *ctx,
      torch::autograd::variable_list grad_output) {
    if (grad_output[0].is_contiguous()) {
        return {2 * grad_output[0]};
    } else {
        return {3 * grad_output[0]};
    }
  }
};
torch::Tensor custom_op_backed_by_autograd_fn2(torch::Tensor x) {
  return CustomOpAutogradFunction2::apply(x);
}
TORCH_LIBRARY_FRAGMENT(mylib, m) {
    m.def("custom_op_backed_by_autograd_fn2", custom_op_backed_by_autograd_fn2);
}
"""

module = torch.utils.cpp_extension.load_inline(
    name="mylib",
    cpp_sources=cpp_source2,
    functions="custom_op_backed_by_autograd_fn2",
    verbose=True,
)


x = torch.tensor([[1., 2., 3.], [4, 5, 6]], requires_grad=True)
y = torch.ops.mylib.custom_op_backed_by_autograd_fn2(x)
z = y.clone()
w = z.sum()

def hook(grad):
    # return a contiguous grad.
    # The previous compiled autograd would have "guessed" that
    # the tensor is not contiguous.
    assert not grad.is_contiguous()
    return grad.contiguous()

z.register_hook(hook)

with torch._dynamo.compiled_autograd.enable(lambda x: x):
    w.backward()

assert torch.allclose(x.grad, 2 * torch.ones_like(x))
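The graph-break fallback exercised above leans on Dynamo's standard escape hatch: a function excluded from tracing runs eagerly while the surrounding code stays compiled. A minimal sketch of that mechanism in isolation (toy function names, eager backend), independent of this patch:

```python
import torch
import torch._dynamo

@torch._dynamo.disable
def opaque(x):
    # Dynamo does not trace into this; it runs eagerly and
    # induces a graph break at the call site.
    return x * 3.14

@torch.compile(backend="eager")
def f(x):
    y = x + 1          # traced
    z = opaque(y)      # graph break here
    return z - 1       # traced in a second graph

print(f(torch.ones(3)))
```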
							
								
								
									
torch/_compiled_autograd.py (new file, 190 lines)

@@ -0,0 +1,190 @@
# type: ignore
import threading
import torch
from ._compile import _disable_dynamo
from ._C import _autograd
# TODO(rzou): why doesn't torch.fx.wrap work directly?
from torch.fx._symbolic_trace import _create_wrapped_func as wrap

"""
TODO(rzou): did we really need a new file? I did it to appease trace_rules.
"""


def python_autograd(saved_state, hooks, nodecalls, num_outputs, arange):
    """Given the state of the autograd graph (the saved tensors/sizes/scalars,
    hooks, and the actual nodes), execute it in Python.

    Compiled Autograd uses the equivalent of torch.fx.symbolic_trace over
    this function to produce a graph that can then be Dynamo'ed.

    NB: Before executing this function (or a graph acquired from it)
    on real Tensors, call set_global_nodecalls(nodecalls) to set the
    current autograd node state. We intentionally hide this state
    from the graph so that Dynamo doesn't need to deal with proxying it into
    the graph.

    TODO(rzou): Compiled Autograd is responsible for calling set_global_nodecalls
    using the current nodecalls data structure. If the user did not specify
    retain_graph=True, then something needs to free it later,
    so we don't end up keeping the nodes around forever.
    """
    node_to_idx_data = {node_id(call.node): idx for idx, call in enumerate(nodecalls)}

    def node_to_idx(node):
        return node_to_idx_data[torch._compiled_autograd.node_id(node)]

    input_buffers = {}

    def lookup_input_buffer(node_idx, num_inputs):
        if node_idx not in input_buffers:
            input_buffers[node_idx] = [None] * num_inputs
        return input_buffers[node_idx]

    saved_state = iter(SavedState(
        nodecalls,
        saved_state[0],
        saved_state[1],
        saved_state[2],
    ))

    graph_outputs = [None] * num_outputs

    for idx, call in enumerate(nodecalls):
        node_idx = arange[idx]
        inputs = lookup_input_buffer(idx, call.node.num_inputs())

        # Given all of the saved state, retrieve the saved state that matters
        # for the current node call.
        apply_state, validate_outputs_state = next(saved_state)

        for hook_idx, input_idx in call.tensor_pre_hooks:
            inputs[input_idx] = call_hook(hooks[hook_idx], inputs[input_idx], hook_type="pre_hook")
        for input_nr, result_idx in call.graph_output:
            graph_outputs[result_idx] = inputs[input_nr]
        if not call.needed:
            continue
        if call.node.is_compiled_autograd_traceable():
            outputs = apply_with_saved(node_idx, inputs, *apply_state)
        else:
            outputs = apply_with_saved_dynamo_disabled(node_idx, inputs, *apply_state)
        outputs = validate_outputs(node_idx, outputs, *validate_outputs_state)
        for hook_idx, input_idx in call.post_hooks:
            call_hook(hooks[hook_idx], outputs, inputs, hook_type="post_hook")
        for output_idx in range(call.node.num_outputs()):
            output = outputs[output_idx]
            next_edge = call.node.next_edge(output_idx)
            if not next_edge.is_valid():
                continue
            next_node = next_edge.function
            input_buffer = lookup_input_buffer(node_to_idx(next_node), next_node.num_inputs())
            updated_buffer = accumulate(input_buffer[next_edge.input_nr], output)
            input_buffer[next_edge.input_nr] = updated_buffer

    return graph_outputs


global_nodecalls = threading.local()


def get_node(idx):
    return global_nodecalls.thread_local[idx].node


def set_global_nodecalls(nodecalls):
    global_nodecalls.thread_local = nodecalls


@wrap
def apply_with_saved(node_idx, inputs, saved_tensors, saved_sizes, saved_scalars):
    """
    Applies the node at global_nodecalls[node_idx] using the inputs and saved values.
    """
    node = get_node(node_idx)
    outputs = _autograd.apply_with_saved(global_nodecalls.thread_local[node_idx], inputs, saved_tensors, list(saved_sizes), saved_scalars)
    return outputs


@_disable_dynamo
@wrap
def apply_with_saved_dynamo_disabled(node_idx, inputs, saved_tensors, saved_sizes, saved_scalars):
    """
    This is apply_with_saved, but also induces a graph break in Dynamo.
    """
    return apply_with_saved(node_idx, inputs, saved_tensors, saved_sizes, saved_scalars)


@wrap
def validate_outputs(node_idx, outputs, saved_tensors, saved_sizes, saved_scalars):
    """
    Validates the outputs of the node at global_nodecalls[node_idx]. This requires
    swizzling out some input metadata state of the next nodes, which is why
    it also accepts some saved variables.
    """
    outputs = _autograd.validate_outputs_with_saved(global_nodecalls.thread_local[node_idx], outputs, saved_tensors, list(saved_sizes), saved_scalars)
    return outputs


def node_id(node):
    if node is None:
        breakpoint()
    assert node is not None
    return _autograd.node_id(node)


def arange(num):
    return list(range(num))


@wrap
def call_hook(*args, **kwargs):
    return torch._dynamo.external_utils.call_hook(*args, **kwargs)


class IterableWrapper:
    def __init__(self, noniterable, size):
        self.noniterable = noniterable
        self.idx = 0
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        assert self.idx < self.size
        result = self.noniterable[self.idx]
        self.idx += 1
        return result


class SavedState:
    def __init__(self, nodecalls, tensors, sizes, scalars):
        self.tensors = tensors
        self.sizes = sizes
        self.scalars = scalars
        self.nodecalls = iter(nodecalls)

    def __iter__(self):
        return self

    def __next__(self):
        call = next(self.nodecalls)

        def get_next(collection_info):
            tensors = [next(self.tensors) for _ in range(collection_info.num_saved_tensors)]
            sizes = [next(self.sizes) for _ in range(collection_info.num_saved_sizes)]
            scalars = [next(self.scalars) for _ in range(collection_info.num_saved_ivalues)]
            return (tensors, sizes, scalars)

        saved_state_for_apply = get_next(call.compiled_args_info)
        saved_state_for_validate_output = get_next(call.next_edges_info)
        return saved_state_for_apply, saved_state_for_validate_output


@wrap
def accumulate(old_var, var):
    if old_var is None:
        return var
    if var is None:
        return old_var
    return old_var + var
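The module wraps helpers such as accumulate and call_hook so that symbolic tracing records them as single call_function nodes instead of tracing through their data-dependent bodies. A minimal, self-contained sketch of that torch.fx.wrap behavior, with toy names, independent of this patch:

```python
import torch
import torch.fx

@torch.fx.wrap
def accumulate(old, new):
    # data-dependent control flow: hidden from the tracer because
    # the function is wrapped and recorded as one graph node
    return new if old is None else old + new

def f(x, y):
    return accumulate(None, x) * y

gm = torch.fx.symbolic_trace(f)
print(gm.graph)  # accumulate appears as a single call_function node
assert torch.equal(gm(torch.ones(2), torch.full((2,), 3.0)), torch.full((2,), 3.0))
```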
@@ -82,6 +82,49 @@ class AutogradCompilerInstance:
    def source(name, idx) -> GetItemSource:
        return GetItemSource(LocalSource(name), idx)

    def capture(self, tensors, sizes, scalars, origins, nodecalls, num_outputs):
        dynamic_sizes = tuple(s for s in sizes if s is not None)

        counters["compiled_autograd"]["captures"] += 1
        inputs_origins, sizes_origins, scalars_origins = origins

        self.fx_tracer.root = torch.nn.Module()
        self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
        self.fx_tracer.tensor_attrs = {}
        inputs_proxy, dynamic_sizes_proxy, scalars_proxy, self.hooks_proxy = (
            self.fx_tracer.create_proxy("placeholder", name, (), {})
            for name in self.graph_placeholders
        )

        sizes_proxy = [None] * len(sizes)
        dynamic_sizes_next = 0
        for idx in range(len(sizes)):
            if sizes[idx] is not None:
                sizes_proxy[idx] = dynamic_sizes[dynamic_sizes_next]
                dynamic_sizes_next += 1

        from torch._compiled_autograd import IterableWrapper, python_autograd, arange

        arange_proxy = self.fx_tracer.create_proxy(
            kind="call_function",
            target=arange,
            args=(len(nodecalls),),
            kwargs={}
        )

        graph_outputs = python_autograd(
            (
                IterableWrapper(inputs_proxy, len(tensors)),
                IterableWrapper(sizes_proxy, len(sizes)),
                IterableWrapper(scalars_proxy, len(scalars)),
            ),
            self.hooks_proxy,
            nodecalls,
            num_outputs,
            arange_proxy,
        )
        return self.end_capture(graph_outputs)

    def begin_capture(
        self,
        inputs: List[torch.Tensor],
@@ -308,8 +351,10 @@ class AutogradCompilerInstance:
            (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
            {},
        )
        self.rename_aot_dispatcher_nodes()
        self.reorder_accumulate_grad_nodes()
        # TODO(rzou): we didn't inline the AOTDispatcher nodes
        # self.rename_aot_dispatcher_nodes()
        # TODO(rzou): we need to transform AccumulateGrad nodes into torch.inductor.accumulate_grad_.
        # self.reorder_accumulate_grad_nodes()
        runtime_inputs_to_move: List[int] = []
        if snapshot_cudagraph_enabled():
            runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
@@ -317,6 +362,7 @@ class AutogradCompilerInstance:
        graph = GraphModule(
            self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
        )
        graph.print_readable()
        set_locals_to_steal(graph, ["inputs"])
        lazy_graph_code = lazy_format_graph_code(
            "Compiled autograd graph",
@@ -562,3 +608,5 @@ def reset() -> None:
    assert not in_compiled_autograd_region
    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
    torch._C._dynamo.compiled_autograd.set_verbose_logger(None)

from torch._compiled_autograd import set_global_nodecalls

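capture() drives the fx tracer by hand: it creates placeholder proxies with create_proxy and then runs python_autograd over them, rather than calling symbolic_trace directly. A minimal sketch of graph construction through Tracer.create_proxy under that pattern (illustrative names, a plain Tracer instead of the patch's PythonKeyTracer):

```python
import torch
import torch.fx

tracer = torch.fx.Tracer()
tracer.root = torch.nn.Module()
tracer.graph = torch.fx.Graph()
tracer.tensor_attrs = {}
x = tracer.create_proxy("placeholder", "inputs", (), {})      # like inputs_proxy
y = tracer.create_proxy("call_function", torch.sin, (x,), {})  # recorded op
tracer.graph.output(y.node)
gm = torch.fx.GraphModule(tracer.root, tracer.graph)
print(gm.code)
assert torch.allclose(gm(torch.ones(3)), torch.sin(torch.ones(3)))
```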
@@ -950,6 +950,9 @@ class OutputGraph:
            list_name = arg.source.local_name
            assert list_name in self.code_options["co_varnames"]
            for x in needs_alias[list_name]:
                if not hasattr(x.source, "index"):
                    # TODO(rzou): idk
                    breakpoint()
                list_idx = x.source.index
                if list_idx not in visited:
                    alias_name = self.new_var(
@@ -134,6 +134,13 @@ If you are removing an existing torch level API:

"""
manual_torch_name_rule_map = {
    "torch._compiled_autograd.CA_apply_with_saved": TorchInGraphFunctionVariable,
    "torch._compiled_autograd.accumulate2": TorchInGraphFunctionVariable,
    "torch._compiled_autograd.CA_validate_outputs": TorchInGraphFunctionVariable,
    # "torch._compiled_autograd.CA_apply_with_saved_dynamo_disabled": TorchInGraphFunctionVariable,
    "torch._compiled_autograd.CA_update_input_buffers": TorchInGraphFunctionVariable,
    "torch._compiled_autograd.CA_input_buffers_init": TorchInGraphFunctionVariable,
    "torch._compiled_autograd.CA_input_buffers_lookup": TorchInGraphFunctionVariable,
    "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
    "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
    "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@@ -3237,6 +3244,7 @@ if torch.distributed.is_available():
# We are using python module name instead of file or directory object to avoid circular dependency.
# Please keep this sorted alphabetically.
MOD_INLINELIST = [
    "torch._compiled_autograd",
    "torch._decomp",
    "torch._dynamo._trace_wrapped_higher_order_op",
    "torch._dynamo.comptime",

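Registering a name as TorchInGraphFunctionVariable tells Dynamo to put calls to that function into the FX graph rather than tracing through its body; the user-facing analogue is torch._dynamo.allow_in_graph. A hedged sketch of that effect with a toy helper:

```python
import torch
import torch._dynamo

def my_helper(x):
    return x + 1  # body is not traced; the call itself lands in the graph

torch._dynamo.allow_in_graph(my_helper)

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    return my_helper(x) * 2

print(f(torch.ones(3)))
```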
@@ -1219,7 +1219,12 @@ class VariableBuilder:
        maybe_gm = self.tx.output.local_scope.get("self")
        if isinstance(
            self.source, LocalSource
        ) and self.source.local_name in get_locals_to_steal(maybe_gm):
        # TODO(rzou): We changed compiled autograd to pass all of the inputs saved
        # instead of a de-duplicated list. Unfortunately that makes the input
        # stealing logic go haywire. We can either fix it or figure out
        # how to deal with a de-duplicated list (the problem is
        # mapping the de-duplicated saved tensors back to the nodes that need them).
        ) and self.source.local_name in get_locals_to_steal(maybe_gm) and False:
            # The input tensor list to dynamo from compiled autograd may contain activations
            # which are freed as they are used in inductor. Dynamo's default behavior is to
            # lift all tensors to the graph inputs, but this will cause dynamo to hold an
@@ -1249,13 +1254,17 @@ class VariableBuilder:
                source_i = GetItemSource(base=source, index=i, index_is_slice=False)
                # access unpacked tensor from this list instead of from a lifted arg
                self.tx.output.input_source_to_var[source_i] = tensor_variable
                tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(
                    value[i]
                )
                if isinstance(tensor_variable, TensorVariable):
                    tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(
                        value[i]
                    )

                guard = functools.partial(
                    GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
                )
                    guard = functools.partial(
                        GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i])
                    )
                else:
                    # TODO(rzou): None guard?
                    pass
                guards.append(source_i.make_guard(guard))

            install_guard(*guards, skip=1)

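The TODO above is about mapping a de-duplicated list of saved tensors back to the node calls that consume them. A toy illustration of that bookkeeping (function name and identity key are illustrative; the C++ side keys on TensorImpl pointers):

```python
import torch

def dedupe(tensors):
    """Return (unique, index_map) such that tensors[i] is unique[index_map[i]]."""
    slot = {}          # identity -> slot in `unique`
    unique, index_map = [], []
    for t in tensors:
        key = id(t)    # stand-in for the TensorImpl pointer used in C++
        if key not in slot:
            slot[key] = len(unique)
            unique.append(t)
        index_map.append(slot[key])
    return unique, index_map

a, b = torch.ones(2), torch.zeros(2)
unique, idx = dedupe([a, a, b])
assert len(unique) == 2 and idx == [0, 0, 1]
```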
@@ -188,16 +188,25 @@ struct CppNode : public Node {
  void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
  void save_variables_to_ctx();

  bool is_compiled_autograd_traceable() override {
    static_assert(
        std::is_same_v<std::remove_cv_t<decltype(T::is_traceable)>, bool>);
    return T::is_traceable;
  }

  void compiled_args(CompiledNodeArgs& args) override {
    static_assert(
        std::is_same_v<std::remove_cv_t<decltype(T::is_traceable)>, bool>);
    if (!T::is_traceable) {
      throw std::runtime_error(
          std::string(
              "Attempting to trace a potentially unsafe C++ autograd function: ") +
          name() +
          ". It may be possible to trace it safely, please refer to the instructions in: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/.");
    }
    // if (!T::is_traceable) {
    //   throw std::runtime_error(
    //       std::string(
    //           "Attempting to trace a potentially unsafe C++ autograd
    //           function: ") +
    //       name() +
    //       ". It may be possible to trace it safely, please refer to the
    //       instructions in:
    //       https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/.");
    // }

    // although neither of the 2 methods below have uniqueness guarantees
    // it is unlikely for them to collide at the same time

@@ -3,6 +3,7 @@
#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/dynamo/compiled_autograd.h>

#include <ATen/ATen.h>

@@ -563,6 +563,10 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
  /// release variables as they run.
  virtual void will_release_variables() {}

  virtual bool is_compiled_autograd_traceable() {
    return true;
  }

  /// Returns true if this function is traceable. An op is traceable if all
  /// operations happening within `apply()` are performed on autograd
  /// `Variables` (i.e. apply mostly instantiates and applies other functions).

@@ -1,4 +1,5 @@
#include <torch/csrc/python_headers.h>
#include <memory>

#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/SavedTensorHooks.h>
@@ -14,6 +15,7 @@
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/input_metadata.h>
@@ -26,6 +28,7 @@
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/kineto_shim.h>
@@ -42,6 +45,7 @@

using torch::impl::py_context_manager;
using torch::impl::py_context_manager_DEPRECATED;
using namespace torch::dynamo::autograd;

namespace {

@@ -79,6 +83,55 @@ struct EnablePythonDispatcher {
  c10::impl::PyInterpreter* old_;
};

std::vector<at::Tensor> toVec(
    const std::vector<std::optional<at::Tensor>>& ts) {
  std::vector<at::Tensor> result;
  for (const auto& opt_tensor : ts) {
    if (opt_tensor.has_value()) {
      result.push_back(opt_tensor.value());
    } else {
      result.emplace_back();
    }
  }
  return result;
}

variable_list validate_outputs_with_saved(
    const NodeCall& nodecall,
    std::vector<at::Tensor>& outputs,
    const std::vector<at::Tensor>& saved_tensors,
    const std::vector<std::optional<at::SymInt>>& saved_sizes,
    const std::vector<at::IValue>& saved_ivalues) {
  auto saved = SwapSavedVariables(
      saved_tensors, saved_sizes, saved_ivalues, nullptr, nodecall);
  saved.before(nodecall.node->next_edges());
  torch::autograd::validate_outputs(
      nodecall.node->next_edges(), outputs, [&](const std::string& msg) {
        std::ostringstream ss;
        ss << "[Compiled Autograd Tracing: " << nodecall.node->name() << "] "
           << msg;
        return ss.str();
      });
  saved.after(nodecall.node->next_edges());
  return outputs;
}

variable_list apply_with_saved314(
    const NodeCall& nodecall,
    const std::vector<std::optional<at::Tensor>>& inputs,
    const std::vector<std::optional<at::Tensor>>& saved_tensors,
    const std::vector<std::optional<at::SymInt>>& saved_sizes,
    const std::vector<at::IValue>& saved_ivalues) {
  auto saved = SwapSavedVariables(
      toVec(saved_tensors), saved_sizes, saved_ivalues, nullptr, nodecall);
  auto outputs = nodecall.node->apply_with_saved(toVec(inputs), saved);
  return outputs;
}

uint64_t node_id(const std::shared_ptr<Node>& node) {
  return reinterpret_cast<uint64_t>(node.get());
}

struct EnablePreDispatch {
  EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
  c10::impl::IncludeDispatchKeyGuard guard_;
@@ -491,6 +544,50 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
    }
  });

  // compiled_autograd stuff
  py::class_<torch::autograd::Node, std::shared_ptr<torch::autograd::Node>>(
      m, "Node")
      .def("compiled_args", &torch::autograd::Node::compiled_args)
      .def("next_edge", &torch::autograd::Node::next_edge)
      .def(
          "is_compiled_autograd_traceable",
          &torch::autograd::Node::is_compiled_autograd_traceable)
      .def("name", &torch::autograd::Node::name)
      .def("num_outputs", &torch::autograd::Node::num_outputs)
      .def("num_inputs", &torch::autograd::Node::num_inputs);
  py::class_<torch::autograd::Edge>(m, "Edge")
      .def("is_valid", &torch::autograd::Edge::is_valid)
      .def_readonly("input_nr", &torch::autograd::Edge::input_nr)
      .def_readonly("function", &torch::autograd::Edge::function);
  py::class_<CollectionInfo>(m, "CollectionInfo")
      .def_readonly("num_saved_tensors", &CollectionInfo::num_saved_tensors)
      .def_readonly("num_saved_sizes", &CollectionInfo::num_saved_sizes)
      .def_readonly("num_saved_ivalues", &CollectionInfo::num_saved_ivalues);
  py::class_<torch::dynamo::autograd::NodeCall>(m, "NodeCall")
      .def_readonly("node", &NodeCall::node)
      .def_readonly("compiled_args_info", &NodeCall::compiled_args_info)
      .def_readonly("next_edges_info", &NodeCall::next_edges_info)
      .def_readonly("tensor_pre_hooks", &NodeCall::tensor_pre_hooks)
      .def_readonly("post_hooks", &NodeCall::post_hooks)
      .def_readonly("graph_output", &NodeCall::graph_output)
      .def_readonly("needed", &NodeCall::needed);
  py::class_<torch::dynamo::autograd::CompiledNodeArgs>(m, "CompiledNodeArgs")
      .def(py::init<AutogradCompilerCall&, NodeCall&>());
  py::class_<torch::dynamo::autograd::AutogradCompilerCall>(
      m, "AutogradCompilerCall")
      .def(py::init<>());
  m.def("apply_with_saved", &apply_with_saved314);
  m.def("validate_outputs_with_saved", &validate_outputs_with_saved);
  m.def("node_id", &node_id);
  // py::class_<SwapInterface,PySwapInterface>(m, "SwapInterface");
  //  py::class_<SwapWithReal,SwapInterface>(m, "SwapWithReal")
  //    .def(py::init<std::vector<at::Tensor>,std::vector<c10::SymInt>,std::vector<c10::IValue>>())
  //    ;
  //  py::class_<SwapSavedVariables>(m, "SwapSavedVariables")
  //    .def(py::init<std::vector<at::Tensor>,std::vector<c10::SymInt>,std::vector<c10::IValue>,PyObject*,const
  //    NodeCall&>())
  //    ;

  _C_m.def("_activate_gpu_trace", []() { activateGPUTrace(); });

  py_context_manager_DEPRECATED<c10::InferenceMode, bool>(

@@ -69,7 +69,15 @@ struct CacheKey {
  const uint8_t* key;
};

struct NodeCall {
struct CollectionInfo {
  int num_saved_tensors = 0;
  int num_saved_sizes = 0;
  int num_saved_ivalues = 0;
};

enum CollectionMode { COMPILED_ARGS, NEXT_EDGES };

struct TORCH_API NodeCall {
  NodeCall(uint32_t id_, std::shared_ptr<Node> node_)
      : id(id_), node(std::move(node_)) {}

@@ -84,6 +92,24 @@ struct NodeCall {
  std::vector<int> post_hooks;
  std::vector<int> post_acc_grad_hooks;
  std::vector<std::pair<int, int>> graph_output;

  CollectionInfo& collection_info() {
    if (mode == CollectionMode::NEXT_EDGES) {
      return next_edges_info;
    } else {
      return compiled_args_info;
    }
  }

  // Given the full list of saved arguments (saved tensors, saved sizes,
  // saved scalars), we want to be able to map them back to which node
  // they came from.
  // The way we do this is that we store information on how many
  // tensors/sizes/scalars each Node uses.
  CollectionMode mode = CollectionMode::COMPILED_ARGS;
  CollectionInfo compiled_args_info;
  CollectionInfo next_edges_info;

  bool needed = true;
};

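The comment above describes the scheme consumed by SavedState on the Python side: each NodeCall records how many saved tensors/sizes/scalars it owns, and the flat saved-value lists are re-partitioned by walking those counts in node order. A toy sketch of that re-partitioning (names illustrative):

```python
def split_by_counts(flat, counts):
    # replay a flat saved-value list back into per-node chunks
    it = iter(flat)
    return [[next(it) for _ in range(n)] for n in counts]

saved_tensors = ["t0", "t1", "t2", "t3"]
per_node = [1, 0, 3]   # e.g. num_saved_tensors recorded per NodeCall
assert split_by_counts(saved_tensors, per_node) == [["t0"], [], ["t1", "t2", "t3"]]
```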
@@ -143,9 +169,9 @@ struct TensorArgs {
    auto impl = tensor.unsafeGetTensorImpl();
    auto it = _args.find(impl);
    if (it == _args.end()) {
      TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1);
      // TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1);
      it = _args.emplace(impl, TensorArg(_next_id++)).first;
      inputs.emplace_back(tensor);
      // inputs.emplace_back(tensor);
      if (active_node_call_idx.has_value()) {
        input_origins.emplace_back(active_node_call_idx.value());
      }
@@ -160,6 +186,9 @@ struct TensorArgs {
  }

  TensorArg& add(const at::Tensor& tensor) {
    // unconditionally add the tensor to inputs... Dynamo will de-dupe them
    // later
    inputs.emplace_back(tensor);
    return lookup(tensor, true);
  }

@@ -208,6 +237,11 @@ struct LiftedIValueArgs {
    return iv_arg.proxy;
  }

  at::IValue& next_proxy() {
    auto& iv_arg = args.at(next++);
    return iv_arg.proxy;
  }

  void add(const at::IValue* iv) {
    args.emplace_back(iv);
    if (active_node_call_idx.has_value()) {
@@ -278,13 +312,16 @@ class CompiledNodeArgs {
  }

  void collect(const at::Tensor& t) {
    _node_call.collection_info().num_saved_tensors++;
    collect(_compiler.tensor_args.add(t));
  }
  void collect(const SavedVariable& sv, bool is_output) {
    _node_call.collection_info().num_saved_tensors++;
    collect(
        _compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
  }
  void collect(const c10::SymInt& t) {
    _node_call.collection_info().num_saved_sizes++;
    _compiler.add_size_input(t);
  }
  void collect(const std::vector<SavedVariable>& t, bool is_output) {
@@ -366,6 +403,7 @@ class CompiledNodeArgs {
        !nested &&
        (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat())) {
      // can't lift ivalues nested in collections
      _node_call.collection_info().num_saved_ivalues++;
      _compiler.lifted_ivalue_args.add(&iv);
    } else {
      try {
@@ -629,17 +667,110 @@ struct TraceState {
  variable_list outputs;
};

struct TORCH_API SwapInterface {
  virtual ~SwapInterface() = default;
  virtual std::optional<at::Tensor> tensor(const at::Tensor& tensor) = 0;
  virtual std::optional<at::Tensor> tensor(const SavedVariable& tensor) = 0;
  virtual std::optional<c10::SymInt> next_size() = 0;
  virtual c10::IValue next_ivalue() = 0;
};

struct SwapWithProxies : public SwapInterface {
  explicit SwapWithProxies(AutogradCompilerCall& compiler, TraceState& state)
      : compiler_(compiler), state_(state) {}

  ~SwapWithProxies() override = default;

  std::optional<at::Tensor> tensor(const at::Tensor& tensor) override {
    TensorArg& arg = compiler_.tensor_args.lookup(tensor);
    if (arg.defined()) {
      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
      return arg.proxy_tensor;
    }
    return std::nullopt;
  }

  std::optional<at::Tensor> tensor(const SavedVariable& t) override {
    TensorArg& arg = compiler_.tensor_args.lookup(t);
    if (arg.defined()) {
      return arg.proxy_tensor;
    }
    return std::nullopt;
  }

  std::optional<c10::SymInt> next_size() override {
    return state_.next_sym_size();
  }
  c10::IValue next_ivalue() override {
    return compiler_.lifted_ivalue_args.next_proxy();
  }

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  AutogradCompilerCall& compiler_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  TraceState& state_;
};

// The previous compiled autograd implementation was about swapping in
// ProxyTensors for a node. Given a single node and some saved
// tensors/sizes/scalars, we needed some way to swap in those saved
// tensors/sizes/scalars. That's what SwapWithReal is.
struct SwapWithReal : public SwapInterface {
  explicit SwapWithReal(
      std::vector<at::Tensor> tensors,
      std::vector<std::optional<c10::SymInt>> sizes,
      std::vector<c10::IValue> ivalues)
      : tensors_(std::move(tensors)),
        sizes_(std::move(sizes)),
        ivalues_(std::move(ivalues)) {}

  ~SwapWithReal() override = default;

  std::optional<at::Tensor> tensor(const at::Tensor& _ignored) override {
    auto result = tensors_[tensors_idx];
    tensors_idx++;
    return result;
  }

  std::optional<at::Tensor> tensor(const SavedVariable& _ignored) override {
    TORCH_INTERNAL_ASSERT(tensors_idx < tensors_.size());
    auto result = tensors_[tensors_idx];
    tensors_idx++;
    return result;
  }

  std::optional<c10::SymInt> next_size() override {
    TORCH_INTERNAL_ASSERT(sizes_idx < sizes_.size());
    auto result = sizes_[sizes_idx];
    sizes_idx++;
    return result;
  }

  c10::IValue next_ivalue() override {
    TORCH_INTERNAL_ASSERT(ivalues_idx < ivalues_.size());
    auto result = ivalues_[ivalues_idx];
    ivalues_idx++;
    return result;
  }

  std::vector<at::Tensor> tensors_;
  size_t tensors_idx = 0;
  std::vector<std::optional<c10::SymInt>> sizes_;
  size_t sizes_idx = 0;
  std::vector<c10::IValue> ivalues_;
  size_t ivalues_idx = 0;
};
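The SwapInterface split lets the same before/after traversal draw replacement values from two sources: proxies during tracing (SwapWithProxies) or the flat saved lists at runtime (SwapWithReal). A toy Python sketch of the stash/swap/restore pattern behind it (all names illustrative):

```python
class Swapper:
    """Swap replacement values into mutable slots, then restore the originals."""
    def __init__(self, replacements):
        self.replacements = iter(replacements)
        self.stash = []

    def before(self, box):
        self.stash.append(box[0])            # stash the real value
        box[0] = next(self.replacements)     # swap in the proxy/saved value

    def after(self, box):
        box[0] = self.stash.pop()            # restore the original

box = ["real"]
s = Swapper(["proxy"])
s.before(box); assert box == ["proxy"]       # tracing/apply happens here
s.after(box);  assert box == ["real"]
```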

class SwapSavedVariables {
  // SwapSavedVariables is used during the tracing/compilation phase after a
  // cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
  // allows tracing to happen, then swaps them back afterwards.
 public:
  void before(at::Tensor& t) {
    TensorArg& arg = compiler.tensor_args.lookup(t);
    auto replacement = state->tensor(t);
    stashed_tensors.save(&t, std::move(t));
    if (arg.defined()) {
      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
      t = arg.proxy_tensor;
    if (replacement.has_value()) {
      t = *replacement;
    }
  }
  void after(at::Tensor& t) {
@@ -647,12 +778,11 @@ class SwapSavedVariables {
  }

  void before(SavedVariable& t) {
    TensorArg& arg = compiler.tensor_args.lookup(t);
    auto replacement = state->tensor(t);
    stashed_variables.save(&t, std::move(t));
    if (arg.defined()) {
    if (replacement.has_value()) {
      bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
      t = SavedVariable(arg.proxy_tensor, false);
      t = SavedVariable(replacement.value(), false);
      at::SavedTensorDefaultHooks::set_tracing(prior);
    }
  }
@@ -662,7 +792,7 @@ class SwapSavedVariables {

  void before(c10::SymInt& t) {
    stashed_symints.save(&t, c10::SymInt(t));
    auto opt_value = state.next_sym_size();
    auto opt_value = state->next_size();
    if (opt_value.has_value()) {
      t = *opt_value; // dynamic shape
    }
@@ -677,7 +807,7 @@ class SwapSavedVariables {
    } else {
      stashed_ivalues.save(&iv, at::IValue(iv));
      if (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat()) {
        iv = compiler.lifted_ivalue_args.next_proxy(&iv);
        iv = state->next_ivalue();
      }
    }
  }
@@ -824,7 +954,23 @@ class SwapSavedVariables {
      TraceState& s,
      PyObject* p,
      const NodeCall& n)
      : compiler(c), state(s), py_compiler(p), curr_node_call(n) {}
      : py_compiler(p), curr_node_call(n) {
    state = std::make_shared<SwapWithProxies>(c, s);
  }

  SwapSavedVariables(
      std::vector<at::Tensor> a,
      std::vector<std::optional<at::SymInt>> b,
      std::vector<at::IValue> c,
      PyObject* p,
      const NodeCall& n)
      : state(std::static_pointer_cast<SwapInterface>(
            std::make_shared<SwapWithReal>(
                std::move(a),
                std::move(b),
                std::move(c)))),
        py_compiler(p),
        curr_node_call(n) {}

  PyObject* get_py_compiler() {
    return py_compiler;
@@ -875,9 +1021,10 @@ class SwapSavedVariables {
  };

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  AutogradCompilerCall& compiler;
  // AutogradCompilerCall& compiler;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  TraceState& state;
  std::shared_ptr<SwapInterface> state;
  // TraceState& state;
  // This is a borrowed reference; we do not increment or decrement ownership.
  // Its lifecycle is entirely longer than this object's.
  PyObject* py_compiler;

@@ -451,6 +451,37 @@ void set_ivalue_proxies(
  }
}

static PyObject* call_capture(
    PyObject* self,
    CacheNode& cache,
    AutogradCompilerCall& compiler_call,
    size_t num_outputs,
    PyObject* nodecalls) {
  static PyObject* method_name = PyUnicode_InternFromString("capture");
  THPObjectPtr pyinput(THPVariable_WrapList(compiler_call.tensor_args.inputs));

  THPObjectPtr pysizeinput(cache.wrap_dynamic_inputs());
  std::vector<std::optional<c10::SymInt>> dynamic_inputs =
      cache.unwrap_dynamic_inputs(py::cast<py::list>(pysizeinput.get()).ptr());

  THPObjectPtr pyivalueargsinput(
      wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args));
  THPObjectPtr pynodeorigins(
      wrap_node_origins(compiler_call, PyTuple_GET_SIZE(pysizeinput.get())));
  PyObject* py_num_outputs = THPUtils_packUInt32(num_outputs);
  return check(PyObject_CallMethodObjArgs(
      self,
      method_name,
      pyinput.get(),
      // TODO(rzou): is this leaking memory?
      py::cast(dynamic_inputs).ptr(),
      pyivalueargsinput.get(),
      pynodeorigins.get(),
      nodecalls,
      py_num_outputs,
      nullptr));
}

static TraceState call_begin_capture(
    PyObject* self,
    CacheNode& cache,
@@ -552,7 +583,9 @@ CacheNode* _compiled_autograd_impl(
        compiler_call.set_active_node_call_idx(i);
      }
      if (node_args.cond(call.needed)) {
        call.mode = CollectionMode::COMPILED_ARGS;
        fn->compiled_args(node_args);
        call.mode = CollectionMode::NEXT_EDGES;
        node_args.collect(call.node->next_edges());
      }
      CacheKey key = node_args.key();
@@ -600,112 +633,15 @@ CacheNode* _compiled_autograd_impl(
    ClosingTHPObjectPtr py_compiler(
        check(PyObject_CallNoArgs((the_autograd_compiler))));

    TraceState state = call_begin_capture(
        py_compiler, *cache, compiler_call, output_edges.size());
    InputBuffers input_buffers;
    // nodes
    py::object nodecalls = py::cast(calls);
    PyObject* res = call_capture(
        py_compiler,
        *cache,
        compiler_call,
        output_edges.size(),
        nodecalls.ptr());

    for (size_t i = 0; i < calls.size(); i++) {
      NodeCall& call = *calls[i];
      // TODO(jansel): consider adding some of this stuff:
      // guard(local_graph_task); NodeGuard ndguard(task.fn_); const auto
      // opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
      // c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
      // CheckpointValidGuard cpvguard(graph_task);
      // at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
      // if (C10_UNLIKELY(step_callbacks.has_value())) { ... }

      variable_list inputs =
          std::move(input_buffers.lookup(call.node.get()).buffer);
      input_buffers.erase(call.node.get());

      if (!call.tensor_pre_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        for (const auto& hook : call.tensor_pre_hooks) {
          pyinputs = check(PyObject_CallMethod(
              py_compiler,
              "tensor_pre_hook",
              "Oii",
              pyinputs.get(),
              hook.first,
              hook.second));
        }
        inputs = THPVariable_UnpackList(pyinputs);
      }
      for (const auto& graph_output : call.graph_output) {
        int input_nr = graph_output.first;
        int output_index = graph_output.second;
        TORCH_INTERNAL_ASSERT(
            output_index < static_cast<int>(state.outputs.size()));
        TORCH_INTERNAL_ASSERT(!state.outputs[output_index].defined());
        state.outputs[output_index] = inputs[input_nr];
      }
      if (!call.needed) {
        continue;
      }
      if (!call.pre_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        for (const auto hook : call.pre_hooks) {
          pyinputs = check(PyObject_CallMethod(
              py_compiler.get(), "pre_hook", "Oi", pyinputs.get(), hook));
        }
        inputs = THPVariable_UnpackList(pyinputs);
      }

      std::string _node_name = call.node->name();
      THPObjectPtr node_name(PyUnicode_FromString(_node_name.data()));
      TORCH_INTERNAL_ASSERT(node_name != nullptr);
      THPObjectPtr set_node_origin(
          PyObject_GetAttrString(py_compiler.get(), "set_node_origin"));

      PyObject* pyobj = Py_None;
      if (auto pynode = std::dynamic_pointer_cast<PyNode>(call.node)) {
        pyobj = pynode->obj;
      }

      check(PyObject_CallFunction(
          set_node_origin, "OIO", node_name.get(), i, pyobj, nullptr));

      SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
      variable_list outputs = call.node->apply_with_saved(inputs, saved);

      saved.debug_asserts();
      saved.before(call.node->next_edges());
      validate_outputs(
          call.node->next_edges(), outputs, [&](const std::string& msg) {
            std::ostringstream ss;
            ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
               << msg;
            return ss.str();
          });
      saved.after(call.node->next_edges());
      saved.debug_asserts();

      if (!call.post_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        THPObjectPtr pyoutputs(THPVariable_WrapList(outputs));
        for (const auto hook : call.post_hooks) {
          pyoutputs = check(PyObject_CallMethod(
              py_compiler.get(),
              "post_hook",
              "OOi",
              pyoutputs.get(),
              pyinputs.get(),
              hook));
        }
        outputs = THPVariable_UnpackList(pyoutputs);
      }
      for (const auto i : c10::irange(outputs.size())) {
        auto& output = outputs[i];
        const auto& next = call.node->next_edge(i);
        if (next.is_valid() && output.defined()) {
          input_buffers.lookup(next.function.get())
              .add(
                  next.input_nr, std::move(output), std::nullopt, std::nullopt);
        }
      }
    }

    PyObject* res = check(call_end_capture(py_compiler, state.outputs));
    TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
    TORCH_CHECK(
        PyTuple_Size(res) == 2,
@@ -718,15 +654,25 @@ CacheNode* _compiled_autograd_impl(
    TORCH_CHECK(
        PyCallable_Check(cache->compiled_fn),
        "Expected end_capture to return compiled_fn");
    state.debug_asserts();
    // TODO(rzou): what is this?
    // state.debug_asserts();
  } // End cache miss region

  // TODO(rzou): need some mechanism to release the variables when we're ready.
  // TODO(jansel): clear grads we will overwrite below
  if (!graph_task.keep_graph_) {
    for (auto& call : calls) {
      call->node->release_variables();
    }
  // if (!graph_task.keep_graph_) {
  //   for (auto& call : calls) {
  //     call->node->release_variables();
  //   }
  // }

  // TODO(rzou): we probably shouldn't be copying the nodes in the hot path?
  std::vector<NodeCall> persistent_node_calls;
  for (NodeCall* call : calls) {
    persistent_node_calls.push_back(*call);
  }
  auto ca = py::module::import("torch._dynamo.compiled_autograd");
  ca.attr("set_global_nodecalls")(persistent_node_calls);

  *graph_arg_inputs = THPVariable_WrapList(compiler_call.tensor_args.inputs);
  *graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);
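Because the traced graph only ever refers to nodes by integer index, the NodeCall copies registered here through set_global_nodecalls act as a per-thread side table that outlives the capture. A toy Python sketch of that indirection (illustrative names; the real registry lives in torch/_compiled_autograd.py above):

```python
import threading

_registry = threading.local()

def set_global_nodecalls(nodecalls):
    _registry.nodecalls = nodecalls

def get_node(idx):
    # the graph captures only `idx`; the node object stays out of the graph
    return _registry.nodecalls[idx]

set_global_nodecalls(["MulBackward0", "SumBackward0"])
assert get_node(1) == "SumBackward0"
```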