mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156321 Approved by: https://github.com/jingsh ghstack dependencies: #156313, #156314, #156315, #156316, #156317, #156319
143 lines
4.3 KiB
C++
143 lines
4.3 KiB
C++
#pragma once
|
|
|
|
#include <ATen/ATen.h>
|
|
#include <ATen/core/ivalue.h>
|
|
#include <torch/csrc/Export.h>
|
|
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
|
#include <utility>
|
|
|
|
namespace c10d {
|
|
|
|
// Broadcasts many tensors to all processes in the process group.
//
// Parameters:
//   process_group: the group whose processes take part in the broadcast.
//   tensors: the tensors to broadcast. NOTE(review): presumably these are
//     overwritten in place on the receiving ranks -- confirm against the
//     implementation.
//   buffer_size: size limit used when coalescing tensors; presumably the
//     maximum size of each group of tensors sent in a single collective.
//     Units (bytes vs. elements) are not visible in this header -- TODO
//     confirm.
//   rank: defaults to 0. Presumably the source rank whose tensor values are
//     broadcast to everyone else -- verify against the implementation.
|
|
TORCH_API void broadcast_coalesced(
|
|
const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
|
|
at::TensorList tensors,
|
|
size_t buffer_size,
|
|
int rank = 0);
|
|
|
|
// This class passes bucket contents tensor to DDP communication hook.
|
|
class TORCH_API GradBucket {
|
|
public:
|
|
explicit GradBucket(
|
|
size_t index,
|
|
size_t bucket_count,
|
|
at::Tensor tensor,
|
|
std::vector<size_t> offsets,
|
|
std::vector<size_t> lengths,
|
|
std::vector<c10::IntArrayRef> sizes_vec,
|
|
std::vector<at::Tensor> parameters,
|
|
std::optional<at::Tensor> sparse_grad_indices)
|
|
: index_(index),
|
|
bucket_count_(bucket_count),
|
|
buffer_(std::move(tensor)),
|
|
offsets_(std::move(offsets)),
|
|
lengths_(std::move(lengths)),
|
|
sizes_vec_(std::move(sizes_vec)),
|
|
parameters_(std::move(parameters)),
|
|
sparse_grad_indices_(std::move(sparse_grad_indices)) {}
|
|
|
|
// Returns the index of the bucket, which is unique across all the buckets.
|
|
size_t getIndex() const {
|
|
return index_;
|
|
}
|
|
|
|
const at::Tensor& getBuffer() const {
|
|
return buffer_;
|
|
}
|
|
|
|
// Returns a mutable buffer compared with the above method.
|
|
at::Tensor& getBufferRef() {
|
|
return buffer_;
|
|
}
|
|
|
|
// Overwrites the buffer at a specific index.
|
|
void setBuffer(at::Tensor& buffer) {
|
|
buffer_ = buffer;
|
|
}
|
|
|
|
// Each tensor in the list that getGradients corresponds to a
|
|
// parameter.
|
|
std::vector<at::Tensor> getGradients() const;
|
|
|
|
// Returns model parameters belonging to this bucket. They are returned in the
|
|
// same order as gradient tensors via getGradients(). For example,
|
|
// getParameters[i] will have its gradient stored in
|
|
// getGradients[i]
|
|
const std::vector<at::Tensor> getParameters() const {
|
|
return parameters_;
|
|
}
|
|
|
|
// Returns whether this bucket is the last bucket to allreduce in an
|
|
// iteration.
|
|
bool isLast() const {
|
|
return index_ == bucket_count_ - 1;
|
|
}
|
|
|
|
std::optional<at::Tensor>& getSparseGradIndices() {
|
|
return sparse_grad_indices_;
|
|
}
|
|
|
|
private:
|
|
size_t index_;
|
|
size_t bucket_count_;
|
|
at::Tensor buffer_;
|
|
|
|
// Per-variable info in buffer_.
|
|
std::vector<size_t> offsets_;
|
|
std::vector<size_t> lengths_;
|
|
std::vector<c10::IntArrayRef> sizes_vec_;
|
|
|
|
// Model parameters for this bucket.
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
|
const std::vector<at::Tensor> parameters_;
|
|
|
|
// Predefined sparse indices for this bucket (only used for sparse tensors).
|
|
// The gradients will be updated to have indices with these tensor values
|
|
std::optional<at::Tensor> sparse_grad_indices_;
|
|
};
|
|
|
|
// Abstract base class of both `PythonCommHook` and `CppCommHook` (see also
// `CppCommHookInterface` below, which derives from it).
|
|
// Subclasses must implement 1) `runHook`, which communicates gradients
|
|
// asynchronously, and 2) `parseHookResult`, which converts the hook's
|
|
// result into a tensor.
// NOTE: under common C++ ABIs the declaration order of the virtual functions
// below determines the vtable layout; reordering them breaks binary
// compatibility with code compiled against an older version of this header.
|
|
class TORCH_API CommHookInterface {
|
|
public:
|
|
// Virtual so that a concrete hook can be destroyed through a pointer to
// this base class.
virtual ~CommHookInterface() = default;
|
|
|
|
// Passes the input grad bucket to the registered communication hook.
|
|
// Once the tensors in the bucket are ready, kicks off the hook asynchronously
|
|
// and returns a future that holds the communication results.
// The bucket is passed by non-const reference, which allows the hook to
// modify it (e.g. through GradBucket::getBufferRef()).
|
|
virtual c10::intrusive_ptr<c10::ivalue::Future> runHook(
|
|
GradBucket& bucket) = 0;
|
|
|
|
// Returns the resulting tensor once the communication hook result is
|
|
// ready. The resulting tensor will then be copied to the grads of
|
|
// individual parameters.
// NOTE(review): `result` is presumably the value held by the future returned
// from runHook() -- confirm against the callers.
|
|
virtual at::Tensor parseHookResult(const c10::IValue& result) = 0;
|
|
};
|
|
|
|
namespace detail {
|
|
// Converts the result of a C++ communication hook (an IValue) into a tensor.
// This helper function is called both by CppCommHookInterface below and inside
|
|
// the reducer.
// NOTE(review): which IValue kinds are accepted (e.g. a single tensor vs. a
// list of tensors) is not visible in this header -- confirm against the
// implementation before relying on a specific kind.
|
|
TORCH_API at::Tensor parseCppCommHookResult(const c10::IValue& result);
|
|
} // namespace detail
|
|
|
|
// This CppCommHook interface only requires implementing runHook method that
|
|
// potentially uses a state.
|
|
template <typename T>
|
|
class CppCommHookInterface : public CommHookInterface {
|
|
public:
|
|
explicit CppCommHookInterface(T state) : state_(std::move(state)) {}
|
|
|
|
~CppCommHookInterface() override = default;
|
|
|
|
at::Tensor parseHookResult(const c10::IValue& result) override {
|
|
return detail::parseCppCommHookResult(result);
|
|
}
|
|
|
|
protected:
|
|
T state_;
|
|
};
|
|
|
|
} // namespace c10d
|