mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Enable coalescing path on XPU and dispatch to XPU tensor barrier if XCCL backend is specified. (#143735)
**Motivation:** - Enable coalescing path on XPU for `batch_isend_irecv`. - If XCCL backend is specified, then construct a XPU tensor to ensure `barrier` dispatch to XCCL backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143735 Approved by: https://github.com/kwen2501
This commit is contained in:
committed by
PyTorch MergeBot
parent
21cbee5d9b
commit
1800f5f461
@ -157,6 +157,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
return backendType_;
|
||||
}
|
||||
|
||||
// Returns true if the given backend type implements sequence-number
// tracking (used by setSequenceNumberForGroup /
// getSequenceNumberForGroup to gate calls into the default backend).
inline bool backendSupportsSequenceNumbers(BackendType backendType) {
  // GLOO, NCCL, XCCL and UCC all support sequence numbers; every other
  // backend type (e.g. MPI, custom plugins) does not.
  return backendType == BackendType::GLOO ||
      backendType == BackendType::NCCL ||
      backendType == BackendType::XCCL || backendType == BackendType::UCC;
}
|
||||
|
||||
virtual void startCoalescing(c10::DeviceType deviceType) {
|
||||
// only nccl has implemented startCoalescing so only execute for nccl
|
||||
// backends
|
||||
@ -639,9 +646,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
virtual void setSequenceNumberForGroup() {
|
||||
auto backendType = getBackendType();
|
||||
// TODO: HACK for backend name to get sequence number for that backend.
|
||||
if (backendType == ProcessGroup::BackendType::GLOO ||
|
||||
backendType == ProcessGroup::BackendType::NCCL ||
|
||||
backendType == ProcessGroup::BackendType::UCC) {
|
||||
if (backendSupportsSequenceNumbers(backendType)) {
|
||||
getDefaultBackend()->setSequenceNumberForGroup();
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
@ -660,9 +665,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
auto backendType = getBackendType();
|
||||
|
||||
// TODO: HACK for backend name to get sequence number for that backend.
|
||||
if (backendType == ProcessGroup::BackendType::GLOO ||
|
||||
backendType == ProcessGroup::BackendType::NCCL ||
|
||||
backendType == ProcessGroup::BackendType::UCC) {
|
||||
if (backendSupportsSequenceNumbers(backendType)) {
|
||||
return getDefaultBackend()->getSequenceNumberForGroup();
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
@ -757,6 +760,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
tensor = at::empty(
|
||||
{1},
|
||||
at::TensorOptions().device(at::DeviceType::CUDA).dtype(at::kByte));
|
||||
} else if (backendType_ == c10d::ProcessGroup::BackendType::XCCL) {
|
||||
// set xpu tensor for override cpu dispatch
|
||||
tensor = at::empty(
|
||||
{1},
|
||||
at::TensorOptions().device(at::DeviceType::XPU).dtype(at::kByte));
|
||||
} else {
|
||||
// Default to using cpu implementation
|
||||
tensor = at::empty(
|
||||
|
@ -1859,10 +1859,9 @@ def _new_process_group_helper(
|
||||
"created, please use a different group name"
|
||||
)
|
||||
|
||||
if device_id is not None and (device_id.index is None or device_id.type != "cuda"):
|
||||
if device_id is not None and (device_id.index is None or device_id.type == "cpu"):
|
||||
raise ValueError(
|
||||
"init_process_group device_id parameter must be a cuda device with an "
|
||||
"id, e.g. cuda:0, not just cuda or cpu"
|
||||
"init_process_group device_id parameter must be an accelerator with an index"
|
||||
)
|
||||
|
||||
# Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
|
||||
|
Reference in New Issue
Block a user