pytorch/torch/csrc/distributed/c10d/python_comm_hook.h
Min Si 1ad0048b64 Refactor distributed to use absolute header path (#85780)
Headers under torch/csrc/distributed may be referenced with a relative path, e.g., "<c10d/...>". However, the Meta internal build cannot handle relative paths gracefully when the NCCL PG is hipified to support AMD/RCCL, because the "hipified" header files are generated in other directories. Moreover, absolute paths for header inclusion are the prevailing convention in most PyTorch components. Thus, this patch refactors all header paths in torch/csrc/distributed to be absolute.

See D39835774 for more details about Meta internal complication.

**How to test**: commit 9e5d199 removes -I./torch/csrc/distributed from the compile options, so building with it verifies that no relative-path use of torch/csrc/distributed headers was missed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85780
Approved by: https://github.com/kumpera, https://github.com/huydhn
2022-09-30 05:13:50 +00:00


#pragma once
#include <torch/csrc/distributed/c10d/comm.hpp>
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>

namespace c10d {

class TORCH_PYTHON_API PythonCommHook : public CommHookInterface {
 public:
  // Takes a state and a callable hook. The inputs are Python objects.
  // The state is passed to the hook in the runHook method, and it can be used
  // to maintain and update any state information during the execution of the
  // hook. The hook performs user-specified processing and returns a future
  // indicating asynchronous communication of gradients.
  PythonCommHook(py::object state, py::object hook)
      : state_(std::move(state)), hook_(std::move(hook)) {}

  ~PythonCommHook() override;

  c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;

  at::Tensor parseHookResult(const c10::IValue& result) override;

 private:
  // Only needed for stateful communication.
  py::object state_;
  py::object hook_;
};

} // namespace c10d
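
The state-plus-callable pattern described in the class comment can be sketched without any PyTorch or pybind11 dependency. The following is only an illustration under stated assumptions: SketchBucket, SketchState, SketchHookInterface, CallableHook, and divideByWorldSize are hypothetical names invented here and are not part of c10d, std::future stands in for c10::ivalue::Future, and a plain std::vector<float> stands in for the gradient tensor. It shows the one idea the header relies on: the hook object owns a state and a callable, and hands the state to the callable on every runHook call.

```cpp
#include <cassert>
#include <functional>
#include <future>
#include <utility>
#include <vector>

// Hypothetical stand-in for GradBucket: a flat buffer of gradients.
struct SketchBucket {
  std::vector<float> grads;
};

// Hypothetical stand-in for the Python `state` object.
struct SketchState {
  int world_size;
  int calls;
};

// Hypothetical stand-in for CommHookInterface (not the real c10d class).
class SketchHookInterface {
 public:
  virtual ~SketchHookInterface() = default;
  virtual std::future<std::vector<float>> runHook(SketchBucket& bucket) = 0;
};

// Owns a state and a callable; passes the state to the callable in runHook.
class CallableHook : public SketchHookInterface {
 public:
  using Hook = std::function<std::future<std::vector<float>>(
      SketchState&, SketchBucket&)>;

  CallableHook(SketchState state, Hook hook)
      : state_(std::move(state)), hook_(std::move(hook)) {}

  std::future<std::vector<float>> runHook(SketchBucket& bucket) override {
    assert(hook_ && "hook must be callable");
    return hook_(state_, bucket);
  }

  const SketchState& state() const { return state_; }

 private:
  SketchState state_;
  Hook hook_;
};

// Example user hook: counts its invocations in the state and returns a
// future whose value is the bucket's gradients divided by world_size.
inline std::future<std::vector<float>> divideByWorldSize(
    SketchState& state, SketchBucket& bucket) {
  ++state.calls;
  std::vector<float> grads = bucket.grads;  // copy owned by the deferred task
  const float scale = 1.0f / static_cast<float>(state.world_size);
  // Deferred launch: the work runs when the caller waits on the future.
  return std::async(
      std::launch::deferred, [grads = std::move(grads), scale]() mutable {
        for (float& g : grads) {
          g *= scale;
        }
        return grads;
      });
}
```

Because the state lives inside the hook object, a caller can invoke runHook once per bucket and the callable sees the same state each time, which is what makes stateful hooks possible. A real hook would launch a collective operation instead of a local division.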