Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24876

This contains very basic functionality for adding a 'send' autograd function to our autograd graph. The purpose of this change is to validate that the basic structure proposed here makes sense. Once it does, we can build upon it to address more complicated scenarios.

At a high level, we've added the following functionality:
1) Define a very simple 'SendRpcBackward' autograd function.
2) Attach this function to the appropriate tensors when we call an RPC.
3) Store the send function in our distributed autograd context.

ghstack-source-id: 89359708

Test Plan: unit tests.

Differential Revision: D16903255

fbshipit-source-id: 6c04794a8e58b199795404225fd9da0c1440460e
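For point (1), the actual class is declared in torch/csrc/distributed/autograd/functions/sendrpc_backward.h (included by the header below). As an illustration only, a minimal sketch of a pass-through autograd node is shown here, assuming the torch::autograd::Node base class; the class name and its pass-through behavior are assumptions, not the in-tree definition.

// Illustrative sketch only; the real SendRpcBackward may differ.
#include <torch/csrc/autograd/function.h>

struct SendRpcBackwardSketch : public torch::autograd::Node {
  torch::autograd::variable_list apply(
      torch::autograd::variable_list&& grads) override {
    // Hypothetical behavior: forward incoming gradients unchanged to the
    // next edges collected from the tensors sent over RPC.
    return std::move(grads);
  }
};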
22 lines
685 B
C++
#pragma once

#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>

#include <torch/types.h>

namespace torch {
namespace distributed {
namespace autograd {

// This method is used to attach the 'send' autograd function to the autograd
// graph when we use RPC. This method creates a new 'send' autograd function
// and attaches the provided tensors as next_edges to the 'send' function.
//
// Returns a shared_ptr to the autograd function, so that we can hold a
// reference to it.
TORCH_API std::shared_ptr<SendRpcBackward> addSendRpcBackward(
    const std::vector<torch::Tensor>& tensors);

} // namespace autograd
} // namespace distributed
} // namespace torch
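To make the declaration above concrete, here is a minimal sketch of what a definition of addSendRpcBackward() could look like, based on the comment in the header and points (2) and (3) of the commit summary. The autograd helpers used (collect_next_edges, set_next_edges, add_input_metadata) are assumed from the in-tree C++ autograd API, and the distributed-context call is left as a hypothetical placeholder; this is not the authoritative implementation.

// Sketch of a possible addSendRpcBackward() definition (assumptions noted).
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>

namespace torch {
namespace distributed {
namespace autograd {

std::shared_ptr<SendRpcBackward> addSendRpcBackward(
    const std::vector<torch::Tensor>& tensors) {
  // Create a new 'send' autograd function.
  auto grad_fn = std::make_shared<SendRpcBackward>();

  // Attach the tensors being sent over RPC as next_edges, so gradients
  // arriving at this node flow back into the tensors' autograd history.
  grad_fn->set_next_edges(torch::autograd::collect_next_edges(tensors));

  // Record input metadata for each tensor so the backward pass can check
  // gradient shapes and types (assumed helper on the autograd node).
  for (const auto& tensor : tensors) {
    grad_fn->add_input_metadata(tensor);
  }

  // Per the commit summary, the send function would also be stored in the
  // current distributed autograd context; the exact API is not shown in this
  // header, so it is only indicated here as a hypothetical call.
  // e.g. currentDistAutogradContext().addSendFunction(grad_fn);

  return grad_fn;
}

} // namespace autograd
} // namespace distributed
} // namespace torch

A caller inside the RPC layer would invoke addSendRpcBackward(tensors) on the tensors being sent and keep the returned shared_ptr alive, which matches the header's note that the function returns a shared_ptr so a reference can be held.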