mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24876 This contains very basic functionality for adding a 'send' autograd function to our autograd graph. The purpose of this change is to validate that the basic structure proposed here makes sense. Once it does, we can build upon it to address more complicated scenarios. At a high level, we've added the following functionality: 1) Define a very simple 'SendRpcBackward' autograd function. 2) Attach this function to the appropriate tensors when we call an RPC. 3) Store the send function in our distributed autograd context. ghstack-source-id: 89359708 Test Plan: unit tests. Differential Revision: D16903255 fbshipit-source-id: 6c04794a8e58b199795404225fd9da0c1440460e
		
			
				
	
	
		
			25 lines
		
	
	
		
			897 B
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			25 lines
		
	
	
		
			897 B
		
	
	
	
		
			C++
		
	
	
	
	
	
| #pragma once
 | |
| 
 | |
| #include <torch/csrc/autograd/function.h>
 | |
| 
 | |
| namespace torch {
 | |
| namespace distributed {
 | |
| namespace autograd {
 | |
| 
 | |
| // As part of our distributed autograd implementation, whenever we send an RPC
 | |
| // from one node to another, we add a 'SendRpcBackward' autograd function to the
 | |
| // autograd graph. This is more or less a placeholder function that is used to
 | |
| // kickoff the autograd engine on the current worker on the backward pass. The
 | |
| // edges for this autograd function are the inputs to the RPC method.
 | |
| //
 | |
| // During the backward pass, this function is queued for execution in the
 | |
| // autograd engine which eventually runs the rest of the autograd graph.
 | |
| struct TORCH_API SendRpcBackward : public torch::autograd::Node {
 | |
|   torch::autograd::variable_list apply(
 | |
|       torch::autograd::variable_list&& grads) override;
 | |
| };
 | |
| 
 | |
| } // namespace autograd
 | |
| } // namespace distributed
 | |
| } // namespace torch
 |