mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	[ca] side-effect free inital trace: compiled_args (#147804)
const methods to prevent accidental mutation. changes mainly in Error nodes and PyNode. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147804 Approved by: https://github.com/jansel ghstack dependencies: #147242, #147796
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							5758743f3c
						
					
				
				
					commit
					ec768d8dc0
				
			| @ -34,6 +34,7 @@ | ||||
| #include <torch/csrc/utils/python_strings.h> | ||||
| #include <torch/csrc/utils/tensor_dtypes.h> | ||||
|  | ||||
| #include <autograd/function.h> | ||||
| #include <functional> | ||||
| #include <memory> | ||||
| #include <stdexcept> | ||||
| @ -185,9 +186,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list { | ||||
|   return to_variable_list(r.get(), is_variable_input); | ||||
| } | ||||
|  | ||||
| auto PyNode::defer_to_dynamo( | ||||
| auto PyNode::apply_with_saved_impl( | ||||
|     const variable_list& inputs, | ||||
|     const std::optional<PyObject*>& compiler) -> variable_list { | ||||
|     const SwapSavedVariables& saved) -> variable_list { | ||||
|   pybind11::gil_scoped_acquire gil; | ||||
|   at::OptionalDeviceGuard _device_guard; | ||||
|   THPFunction* py_fn = (THPFunction*)obj; | ||||
| @ -235,24 +236,24 @@ auto PyNode::defer_to_dynamo( | ||||
|   } | ||||
|   THPObjectPtr saved_tensors(unpack_saved_variables( | ||||
|       py_fn, [](const Variable& var) { return THPVariable_Wrap(var); })); | ||||
|   TORCH_INTERNAL_ASSERT( | ||||
|       _backward_idx.has_value(), | ||||
|       "indices should already be set by compiled_args, called before apply_with_saved"); | ||||
|  | ||||
|   auto [bwd_idx, maybe_bwd_state_idx] = saved.retrieve_pynode_objs(this); | ||||
|  | ||||
|   PyObject* backward_state_idx = Py_None; | ||||
|   if (_backward_state_idx.has_value()) { | ||||
|     backward_state_idx = THPUtils_packInt64(_backward_state_idx.value()); | ||||
|   if (maybe_bwd_state_idx.has_value()) { | ||||
|     backward_state_idx = THPUtils_packUInt64(maybe_bwd_state_idx.value()); | ||||
|     // this might be simplifiable now that we no longer inline | ||||
|     Py_CLEAR(py_fn->compiled_autograd_backward_state); | ||||
|   } | ||||
|   THPObjectPtr r(PyObject_CallMethod( | ||||
|       // NOLINTNEXTLINE(bugprone-unchecked-optional-access) | ||||
|       compiler.value(), | ||||
|       saved.get_py_compiler(), | ||||
|       "proxy_call_backward", | ||||
|       "OOOiOO", | ||||
|       pyInputs.get(), | ||||
|       fwdInputMetadatas.get(), | ||||
|       saved_tensors.get(), | ||||
|       *_backward_idx, | ||||
|       bwd_idx, | ||||
|       obj, | ||||
|       backward_state_idx)); | ||||
|  | ||||
| @ -301,7 +302,7 @@ bool PyNode::is_aot_backward() const { | ||||
|   return py::hasattr(py::getattr(handle, "_forward_cls"), "_aot_id"); | ||||
| } | ||||
|  | ||||
| void PyNode::compiled_args(CompiledNodeArgs& args) { | ||||
| void PyNode::compiled_args(CompiledNodeArgs& args) const { | ||||
|   static PyObject* method_name = | ||||
|       PyUnicode_InternFromString("_compiled_autograd_key"); | ||||
|   THPObjectPtr pykey(PyObject_CallMethodObjArgs(obj, method_name, nullptr)); | ||||
| @ -346,14 +347,15 @@ void PyNode::compiled_args(CompiledNodeArgs& args) { | ||||
|   args.collect(f->input_info); | ||||
|  | ||||
|   Py_INCREF(obj); | ||||
|   _backward_idx = args.add_backward(c10::SafePyObject(obj, getPyInterpreter())); | ||||
|  | ||||
|   c10::SafePyObject backward_obj(obj, getPyInterpreter()); | ||||
|   std::optional<c10::SafePyObject> backward_state_obj; | ||||
|   PyObject* bw_state = f->compiled_autograd_backward_state; | ||||
|   if (args.cond(bw_state != nullptr)) { | ||||
|     Py_INCREF(bw_state); | ||||
|     _backward_state_idx = args.add_backward_state( | ||||
|         c10::SafePyObject(bw_state, getPyInterpreter())); | ||||
|     backward_state_obj = c10::SafePyObject(bw_state, getPyInterpreter()); | ||||
|   } | ||||
|   args.collect_pynode_objs( | ||||
|       this, std::move(backward_obj), std::move(backward_state_obj)); | ||||
| } | ||||
|  | ||||
| variable_list PyNode::apply_with_saved( | ||||
| @ -368,8 +370,7 @@ variable_list PyNode::apply_with_saved( | ||||
|   saved.before(f->output_info); | ||||
|   saved.before(f->input_info); | ||||
|   f->compiled_autograd_tracing = true; | ||||
|   variable_list result = | ||||
|       defer_to_dynamo(variable_list(inputs), saved.get_py_compiler()); | ||||
|   variable_list result = apply_with_saved_impl(variable_list(inputs), saved); | ||||
|   f->compiled_autograd_tracing = false; | ||||
|   saved.after(f->compiled_autograd_symints); | ||||
|   saved.after(f->saved_variables); | ||||
|  | ||||
		Reference in New Issue
	
	Block a user