Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[fr] [xpu] Add FlightRecorder support for ProcessGroupXCCL (#158568)
Adds support for FlightRecorder in ProcessGroupXCCL. See https://github.com/intel/torch-xpu-ops/pull/1867 for XCCL implementation and more details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158568
Approved by: https://github.com/guangyey, https://github.com/fduwjj
commit 9b4adc4db7
parent 9e491f753e
committed by PyTorch MergeBot
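
To illustrate what the new hook exposes, here is a minimal sketch (not part of this commit) that pulls the XCCL flight recorder trace through the `_dump_xccl_trace` binding added below and unpickles it. It assumes a PyTorch build with XCCL, an initialized XCCL process group, and flight recording enabled:

import pickle

# _dump_xccl_trace is bound on torch._C._distributed_c10d in XCCL builds (see the
# pybind changes below); like _dump_nccl_trace, it returns pickled bytes.
from torch._C._distributed_c10d import _dump_xccl_trace

raw = _dump_xccl_trace(
    includeCollectives=True, includeStackTraces=False, onlyActive=False
)
trace = pickle.loads(raw)
print(trace["version"])           # "2.10" after this change
print(trace["comm_lib_version"])  # renamed from "nccl_version"
print(trace["pg_config"])         # process group configuration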
@@ -386,7 +386,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
   ASSERT_TRUE(
       setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
   auto tempFilename = c10::str(
-      std::filesystem::temp_directory_path().string(), "/nccl_trace_rank_");
+      std::filesystem::temp_directory_path().string(), "/comm_lib_trace_rank_");
   ASSERT_TRUE(
       setenv("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", tempFilename.c_str(), 1) == 0);
   // Enable nccl flight recorder.

@@ -401,7 +401,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
   // The only difference is that we are storing traces also in memory for
   // validation.
   std::string fileNamePrefix = c10d::getCvarString(
-      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
+      {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/comm_lib_trace_rank_");
   std::unique_ptr<TestDebugInfoWriter> wrterForTestPtr =
       std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
   std::vector<uint8_t>& traces = wrterForTestPtr->getTraces();

@@ -2449,7 +2449,7 @@ class ProcessGroupGlooFRTest(ProcessGroupGlooTest):

     def _verify_trace(self, t, is_json):
         ver = t["version"]
-        self.assertEqual(ver, "2.9")
+        self.assertEqual(ver, "2.10")
         pg_config = t["pg_config"]
         self.assertEqual(len(pg_config), 1)
         default_pg_info = pg_config["0"]

@@ -4361,10 +4361,12 @@ class NCCLTraceTestBase(MultiProcessTestCase):
 class NCCLTraceTest(NCCLTraceTestBase):
     def _verify_trace(self, t, include_collectives, timing_enabled, is_json):
         ver = t["version"]
-        self.assertEqual(ver, "2.9")
-        nccl_version = t["nccl_version"]
-        torch_nccl_version = torch.cuda.nccl.version()
-        self.assertEqual(nccl_version, ".".join(str(v) for v in torch_nccl_version))
+        self.assertEqual(ver, "2.10")
+        comm_lib_version = t["comm_lib_version"]
+        torch_comm_lib_version = torch.cuda.nccl.version()
+        self.assertEqual(
+            comm_lib_version, ".".join(str(v) for v in torch_comm_lib_version)
+        )
         pg_config = t["pg_config"]
         self.assertEqual(len(pg_config), 1)
         default_pg_info = pg_config["0"]

third_party/xpu.txt (vendored, 2 changed lines)
@@ -1 +1 @@
77cc792cd265179745d335579d233e6d4f9a2667
77cc792cd265179745d335579d233e6d4f9a2667

@@ -388,8 +388,10 @@ class Op:
         self, event: dict[Any, Any], memberships: dict[str, set[Any]], pg_name: str
     ):
         self.profiling_name = event["profiling_name"]
-        nccl, name = self.profiling_name.split(":")
-        assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'"
+        comm_lib_backend, name = self.profiling_name.split(":")
+        assert comm_lib_backend in ["nccl", "xccl"], (
+            f"name formatting error? {comm_lib_backend} != 'nccl' or 'xccl'"
+        )
         parts = name.split(" ")
         type = parts[0]
         meta = parts[1] if len(parts) == 2 else None

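A standalone illustration of the relaxed backend-prefix check above (the sample profiling names here are hypothetical, not taken from a real trace):

# Standalone illustration of the relaxed check; sample strings are hypothetical.
for profiling_name in ["nccl:all_reduce", "xccl:all_reduce"]:
    comm_lib_backend, name = profiling_name.split(":")
    assert comm_lib_backend in ["nccl", "xccl"]
    parts = name.split(" ")
    op_type = parts[0]
    meta = parts[1] if len(parts) == 2 else None
    print(comm_lib_backend, op_type, meta)
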
@@ -298,6 +298,8 @@ class Backend:
         def _timeout(self) -> timedelta: ...
         @_timeout.setter
         def _timeout(self, val: timedelta) -> None: ...
+        global_ranks_in_group: list[int]
+        group_name: str

     def __init__(
         self,

@@ -608,8 +610,6 @@ class ProcessGroupGloo(Backend):
     class Options(Backend.Options):
         devices: list[ProcessGroupGloo.Device]
         threads: int
-        global_ranks_in_group: list[int]
-        group_name: str

         def __init__(self): ...

@@ -651,8 +651,6 @@ class ProcessGroupNCCL(Backend):
         is_high_priority_stream: bool
         split_from: ProcessGroupNCCL
         split_color: int
-        global_ranks_in_group: list[int]
-        group_name: str

         def __init__(self, is_high_priority_stream: bool = False): ...

@@ -830,12 +828,18 @@ class _SymmetricMemory:
     def signal_pad_size(self) -> int: ...

 class ProcessGroupXCCL(Backend):
+    class Options(Backend.Options):
+        def __init__(self): ...
+
     def __init__(
         self,
         store: Store,
         rank: int,
         size: int,
-    ): ...
+        options: Options,
+    ) -> None: ...
+    @property
+    def options(self) -> Options: ...  # type: ignore[override]

 def _set_process_group(pg: ProcessGroup) -> None: ...
 def _current_process_group() -> ProcessGroup: ...

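Per the updated stub, direct construction of ProcessGroupXCCL now takes an Options argument. A hedged sketch of what that looks like (assumes an XCCL-enabled build; the TCPStore address, port, rank, and world size below are placeholders):

from datetime import timedelta

import torch.distributed as dist
from torch._C._distributed_c10d import ProcessGroupXCCL  # only present in XCCL builds

rank, world_size = 0, 1
store = dist.TCPStore("127.0.0.1", 29500, world_size, rank == 0)  # placeholder rendezvous

opts = ProcessGroupXCCL.Options()
opts._timeout = timedelta(minutes=30)   # setter inherited from Backend.Options
opts.global_ranks_in_group = [0]        # fields now live on Backend.Options
opts.group_name = "default_pg"
pg = ProcessGroupXCCL(store, rank, world_size, opts)
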
@@ -47,6 +47,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
     // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
     const std::string backend;
     std::string group_name;
+    std::vector<uint64_t> global_ranks_in_group;
   };

   explicit Backend(int rank, int size);

@@ -39,7 +39,7 @@ DebugInfoWriter& DebugInfoWriter::getWriter(int rank) {
     auto cacheDirPath = std::filesystem::path(homeDir + "/.cache/torch");
     // Create the .cache directory if it doesn't exist
     std::filesystem::create_directories(cacheDirPath);
-    auto defaultLocation = cacheDirPath / "nccl_trace_rank_";
+    auto defaultLocation = cacheDirPath / "comm_lib_trace_rank_";

    // For internal bc compatibility, we keep the old the ENV check.
    std::string fileNamePrefix = getCvarString(

@@ -20,10 +20,10 @@ namespace c10d {
 // (minor when adding fields, major when changing existing fields)
 // Also update both JSON and Pickle dumps to make use of the newly defined
 // field(s).
-DEFINE_CONSTANT(version_val, "2.9")
+DEFINE_CONSTANT(version_val, "2.10")
 DEFINE_CONSTANT(entries_key, "entries")
 DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state")
-DEFINE_CONSTANT(nccl_version_key, "nccl_version")
+DEFINE_CONSTANT(comm_lib_version_key, "comm_lib_version")
 DEFINE_CONSTANT(version_key, "version")
 DEFINE_CONSTANT(pg_config_key, "pg_config")
 DEFINE_CONSTANT(pg_status_key, "pg_status")

@@ -179,7 +179,7 @@ struct FlightRecorder {
   std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_ = {};
   std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
       pg_name_to_ranks_ = {};
-  std::string nccl_version_;
+  std::string comm_lib_version_;

   std::optional<size_t> record(
       size_t pg_id,

@@ -200,7 +200,7 @@ struct FlightRecorder {
       const std::tuple<std::string, std::string>& pg_name,
       std::vector<uint64_t> ranks);

-  void record_accelerator_version(const std::string nccl_version);
+  void record_accelerator_version(const std::string comm_lib_version);

   void update_state(Entry& r);

@@ -128,12 +128,12 @@ void FlightRecorder<EventType>::record_pg_ranks(

 template <typename EventType>
 void FlightRecorder<EventType>::record_accelerator_version(
-    const std::string nccl_version) {
+    const std::string comm_lib_version) {
   if (!enabled_) {
     return;
   }
   std::lock_guard<std::mutex> guard(mutex_);
-  nccl_version_ = std::move(nccl_version);
+  comm_lib_version_ = std::move(comm_lib_version);
 }

 template <typename EventType>

@@ -425,7 +425,7 @@ std::string FlightRecorder<EventType>::dump_json(
     bool onlyActive) {
   json result;
   result[version_key_str] = version_val_str;
-  result[nccl_version_key_str] = nccl_version_;
+  result[comm_lib_version_key_str] = comm_lib_version_;
   result[pg_config_key_str] = getPgConfigJson();
   result[pg_status_key_str] = getPgStatusJson();

@@ -522,7 +522,7 @@ std::string FlightRecorder<EventType>::dump(
   // common values
   result.insert(version_key, version_val);
   result.insert(pg_config_key, getPgConfig());
-  result.insert(nccl_version_key_str, nccl_version_);
+  result.insert(comm_lib_version_key_str, comm_lib_version_);
   result.insert(pg_status_key, getPgStatus());

   // collective trace

@@ -255,7 +255,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
      return c10::make_intrusive<Options>(timeout);
    }

-    std::vector<uint64_t> global_ranks_in_group;
    std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
    int threads;
  };

@@ -545,7 +545,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
    // the int value of `NCCL_SPLIT_NOCOLOR` (-1) instead.
    int split_color{-2};
 #endif
-    std::vector<uint64_t> global_ranks_in_group;
  };

  // Helper class related to TORCH_NCCL_DESYNC_DEBUG

@@ -3086,7 +3086,11 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           py::arg("backend"),
           py::arg("timeout") = kProcessGroupDefaultTimeout)
       .def_readonly("backend", &::c10d::Backend::Options::backend)
-      .def_readwrite("_timeout", &::c10d::Backend::Options::timeout);
+      .def_readwrite("_timeout", &::c10d::Backend::Options::timeout)
+      .def_readwrite(
+          "global_ranks_in_group",
+          &::c10d::Backend::Options::global_ranks_in_group)
+      .def_readwrite("group_name", &::c10d::Backend::Options::group_name);

 #ifdef USE_C10D_GLOO
   static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";

@@ -3102,12 +3106,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           processGroupGloo, "_Options", backendOptions)
           .def(py::init<>())
           .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
-          .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads)
-          .def_readwrite(
-              "global_ranks_in_group",
-              &::c10d::ProcessGroupGloo::Options::global_ranks_in_group)
-          .def_readwrite(
-              "group_name", &::c10d::ProcessGroupGloo::Options::group_name);
+          .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);

   processGroupGloo
       .def_static(

@@ -3469,11 +3468,6 @@ Example::
           "split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
       .def_readwrite(
           "split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
-      .def_readwrite(
-          "global_ranks_in_group",
-          &::c10d::ProcessGroupNCCL::Options::global_ranks_in_group)
-      .def_readwrite(
-          "group_name", &::c10d::ProcessGroupNCCL::Options::group_name)
       .def(
           "__copy__",
           [](const ::c10d::ProcessGroupNCCL::Options& self) {

@@ -3512,17 +3506,49 @@ Example::
       .def(
           py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
                       int rank,
-                      int size) {
+                      int size,
+                      c10::intrusive_ptr<::c10d::ProcessGroupXCCL::Options>
+                          options) {
            // gil_scoped_release is not safe as a call_guard in init.
            // https://github.com/pybind/pybind11/issues/5473
            py::gil_scoped_release nogil{};

            return c10::make_intrusive<::c10d::ProcessGroupXCCL>(
-                store, rank, size);
+                store, rank, size, std::move(options));
          }),
          py::arg("store"),
          py::arg("rank"),
-          py::arg("size"));
+          py::arg("size"),
+          py::arg("options"),
+          R"(Create a new ProcessGroupXCCL instance.)");
+
+  intrusive_ptr_class_<::c10d::ProcessGroupXCCL::Options>(
+      processGroupXCCL, "Options", backendOptions)
+      .def(py::init<>());
+  module
+      .def(
+          "_dump_xccl_trace",
+          [](std::optional<bool> includeCollectives,
+             std::optional<bool> includeStackTraces,
+             std::optional<bool> onlyActive) {
+            return py::bytes(::c10d::dump_xccl_trace(
+                includeCollectives.value_or(true),
+                includeStackTraces.value_or(true),
+                onlyActive.value_or(false)));
+          },
+          py::arg("includeCollectives") = std::optional<bool>(),
+          py::arg("includeStackTraces") = std::optional<bool>(),
+          py::arg("onlyActive") = std::optional<bool>(),
+          R"(
+Arguments:
+    includeCollectives(bool, optional): Whether to include collective work traces. Default is True.
+    includeStackTraces(bool, optional): Whether to include stacktraces in the collective work traces. Default is True.
+    onlyActive (bool, optional): Whether to only include active collective work traces. Default is False.
+Returns:
+    Stringified pickle work traces.
+    Default settings return everything - i.e. contains XCCL comm dumps and collective traces.
+)")
+      .def("get_xccl_version", [] { return ::c10d::getXcclVersion(); });

 #endif

 #ifdef USE_C10D_UCC

@@ -2035,8 +2035,12 @@ def _new_process_group_helper(
         elif backend_str == Backend.XCCL:
             if not is_xccl_available():
                 raise RuntimeError("Distributed package doesn't have XCCL built in")
+            backend_options = ProcessGroupXCCL.Options()
+            backend_options.global_ranks_in_group = global_ranks_in_group
+            backend_options.group_name = group_name
+            backend_options._timeout = timeout
             backend_class = ProcessGroupXCCL(
-                backend_prefix_store, group_rank, group_size
+                backend_prefix_store, group_rank, group_size, backend_options
             )
             backend_type = ProcessGroup.BackendType.XCCL
         else:

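End to end, the new Options plumbing is exercised whenever an XCCL group is created through the normal API. A hedged sketch of typical user code (assumes an XCCL-enabled build and the usual MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE environment):

import torch.distributed as dist

# Routes through _new_process_group_helper above, which now builds
# ProcessGroupXCCL.Options and passes it to the backend constructor.
dist.init_process_group(backend="xccl")
# ... run collectives; if flight recording is enabled, they are captured ...
dist.destroy_process_group()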