[PGNCCL] Add FP8 support (#152706)

NCCL added support for `Float8e4m3` and `Float8e5m2` in version 2.24.

NVIDIA GPUs do not appear to support the "no negative zero" (UZ) variants `Float8_e4m3fnuz` and `Float8_e5m2fnuz` (see https://onnx.ai/onnx/technical/float8.html), so we continue to error out for these two dtypes when they are used in a reduction op.
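As a rough illustration of what this enables on the Python side (a minimal sketch, not code from this PR; it assumes a `torchrun`-style env:// launch, NCCL >= 2.24, and sm90-class GPUs):

```python
# fp8_allreduce.py -- hypothetical example script, not part of this PR.
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Float8 all-reduce now maps to ncclFloat8e4m3 / ncclFloat8e5m2.
t = torch.ones(1024, dtype=torch.float32, device="cuda").to(torch.float8_e4m3fn)
dist.all_reduce(t)  # default op is SUM; each element becomes world_size

# The fnuz ("no negative zero") variants are still rejected at the reduction check.
bad = torch.empty(1024, dtype=torch.float8_e4m3fnuz, device="cuda")
try:
    dist.all_reduce(bad)
except RuntimeError as e:
    print(f"rank {rank}: {e}")  # Unsupported Float8 type for NCCL reduction

dist.destroy_process_group()
```

Launched with something like `torchrun --nproc-per-node=2 fp8_allreduce.py` on a multi-GPU node.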

Test plan:
- test_allreduce_float8
- test_reduce_scatter_float8

Resolves #148344

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152706
Approved by: https://github.com/d4l3k, https://github.com/eqy, https://github.com/fduwjj, https://github.com/cyyever
Author: Ke Wen
Date: 2025-05-02 17:14:36 -07:00
Committed by: PyTorch MergeBot
Parent: a1516d9e6e
Commit: 7a2df6a00b

4 changed files with 75 additions and 83 deletions


@@ -3273,27 +3273,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
                 ),
             )
 
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    def test_all_reduce_coalesced_nccl_float8_errors(self):
-        store = c10d.FileStore(self.file_name, self.world_size)
-        c10d.init_process_group(
-            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
-        )
-        process_group = c10d.distributed_c10d._get_default_group()
-        device = torch.device(f"cuda:{self.rank:d}")
-        tensors = [
-            torch.full(
-                (60 + i,), self.rank + 1 + i, device=device, dtype=torch.float
-            ).to(torch.float8_e4m3fn)
-            for i in range(5)
-        ]
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Float8 dtypes are not currenlty supported for NCCL reductions",
-        ):
-            torch.distributed.all_reduce_coalesced(tensors, group=process_group)
-
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     def test_all_reduce_coalesced_manager_nccl(self):
@@ -3685,56 +3664,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
                 dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
         self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size)
 
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    def test_reduce_scatter_base_k_float8_errors(self):
-        store = dist.FileStore(self.file_name, self.world_size)
-        dist.init_process_group(
-            "nccl",
-            world_size=self.world_size,
-            rank=self.rank,
-            store=store,
-        )
-        output_tensor = (
-            torch.zeros(2, dtype=torch.float32).to(torch.float8_e4m3fn).to(self.rank)
-        )
-        input_tensors = (
-            torch.arange(self.world_size * 2, dtype=torch.float32)
-            .to(torch.float8_e4m3fn)
-            .to(self.rank)
-        )
-        input_tensors = torch.reshape(input_tensors, (self.world_size, 2))
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Float8 dtypes are not currenlty supported for NCCL reductions",
-        ):
-            dist.reduce_scatter_tensor(output_tensor, input_tensors)
-
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    def test_reduce_scatter_tensor_coalesced_float8_errors(self):
-        store = dist.FileStore(self.file_name, self.world_size)
-        dist.init_process_group(
-            "nccl",
-            world_size=self.world_size,
-            rank=self.rank,
-            store=store,
-        )
-        output_tensors = torch.zeros(2, 2).to(torch.float8_e5m2).to(self.rank)
-        input_tensors = [
-            torch.ones(2, 2).to(torch.float8_e5m2).to(self.rank)
-            for _ in range(self.world_size)
-        ]
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Float8 dtypes are not currenlty supported for NCCL reductions",
-        ):
-            with dist._coalescing_manager():
-                for i in range(self.world_size):
-                    dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
-        self.assertEqual(output_tensors, input_tensors[self.rank])
-
 
 class SetDeviceMethod(Enum):
     TORCH_CUDA_SET = auto()  # torch.cuda.set_device


@@ -28,6 +28,8 @@ from torch.testing._internal.common_distributed import (
     init_multigpu_helper,
     MultiProcContinousTest,
     requires_nccl,
+    requires_nccl_version,
+    sm_is_or_higher_than,
     TEST_SKIPS,
 )
 from torch.testing._internal.common_utils import (
@@ -243,6 +245,24 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
             with self.assertRaisesRegex(ValueError, "Cannot use " + err + " with NCCL"):
                 allreduce(tensors, op)
 
+    @requires_nccl_version((2, 24), "Need NCCL 2.24+ for Float8")
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    def test_allreduce_float8(self):
+        device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
+        if not sm_is_or_higher_than(device, 9, 0):
+            self.skipTest("Float8 requires sm >= 90")
+        numel = 1024
+        tensor = torch.ones(numel, dtype=torch.float32, device=device).to(
+            torch.float8_e4m3fn
+        )
+        dist.all_reduce(tensor)
+        expected = (
+            torch.empty_like(tensor).fill_(self.world_size).to(torch.float8_e4m3fn)
+        )
+        torch.testing.assert_close(tensor, expected)
+
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_alltoall_ops_with_cudafree_race(self):
@@ -892,6 +912,27 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
         # Verification
         self.assertEqual(output_t[0], self.rank * self.world_size)
 
+    @requires_nccl_version((2, 24), "Need NCCL 2.24+ for Float8")
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    def test_reduce_scatter_float8(self):
+        device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
+        if not sm_is_or_higher_than(device, 9, 0):
+            self.skipTest("Float8 requires sm >= 90")
+        numel = 1024
+        output_tensor = torch.zeros(numel, dtype=torch.float32, device=device).to(
+            torch.float8_e5m2
+        )
+        input_tensor = torch.ones(
+            self.world_size * numel, dtype=torch.float32, device=device
+        ).to(torch.float8_e5m2)
+        dist.reduce_scatter_tensor(output_tensor, input_tensor)
+        expected = (
+            torch.empty_like(output_tensor).fill_(self.world_size).to(torch.float8_e5m2)
+        )
+        torch.testing.assert_close(output_tensor, expected)
+
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_barrier(self):


@@ -70,6 +70,10 @@ static_assert(
 #define NCCL_HAS_QOS
 #endif
 
+#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 24, 0)
+#define NCCL_SUPPORTS_FP8
+#endif
+
 // Macro to throw on a non-successful NCCL return value.
 #define C10D_NCCL_CHECK(cmd, failureReason) \
   do {                                      \


@@ -66,8 +66,15 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
     {at::kLong, ncclInt64},
     {at::kHalf, ncclHalf},
     {at::kBool, ncclUint8},
+#ifdef NCCL_SUPPORTS_FP8
+    {at::kFloat8_e5m2, ncclFloat8e5m2},
+    {at::kFloat8_e4m3fn, ncclFloat8e4m3},
+#else
     {at::kFloat8_e5m2, ncclUint8},
     {at::kFloat8_e4m3fn, ncclUint8},
+#endif
+    // NVIDIA GPUs does not support the UZ version standing for "no negative
+    // zero". See https://onnx.ai/onnx/technical/float8.html
     {at::kFloat8_e4m3fnuz, ncclUint8},
     {at::kFloat8_e5m2fnuz, ncclUint8},
 #if HAS_NCCL_BF16_DATATYPE
@@ -75,6 +82,17 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
 #endif // HAS_NCCL_BF16_DATATYPE
 };
 
+inline bool isUnsupportedFloat8(at::ScalarType t) {
+  return (
+      t == at::ScalarType::Float8_e5m2fnuz ||
+      t == at::ScalarType::Float8_e4m3fnuz ||
+      t == at::ScalarType::Float8_e8m0fnu
+#ifndef NCCL_SUPPORTS_FP8
+      || t == at::ScalarType::Float8_e5m2 || t == at::ScalarType::Float8_e4m3fn
+#endif
+  );
+}
+
 // Helper function that gets the data type and issues error if not supported
 ncclDataType_t getNcclDataType(at::ScalarType type) {
   auto it = ncclDataType.find(type);
@@ -4111,8 +4129,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
   auto tensor = tensors.back();
   TORCH_CHECK(
-      !isFloat8Type(tensor.scalar_type()),
-      "Float8 dtypes are not currenlty supported for NCCL reductions");
+      !isUnsupportedFloat8(tensor.scalar_type()),
+      "Unsupported Float8 type for NCCL reduction");
 #ifdef IS_NCCLX
   tensor = tensor.coalesce();
   at::Tensor outputTensor =
@@ -4231,8 +4249,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
     }
   }
   TORCH_CHECK(
-      !isFloat8Type(tensor.scalar_type()),
-      "Float8 dtypes are not currenlty supported for NCCL reductions");
+      !isUnsupportedFloat8(tensor.scalar_type()),
+      "Unsupported Float8 type for NCCL reduction");
   RECORD_PARAM_COMMS_DATA(
       std::make_tuple(
           static_cast<int64_t>(seqCollective_) + 1,
@@ -4260,8 +4278,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
     const AllreduceCoalescedOptions& opts) {
   auto total_numel = check_gpu_tensors_same_device(tensors);
   TORCH_CHECK(
-      !isFloat8Type(tensors.back().scalar_type()),
-      "Float8 dtypes are not currenlty supported for NCCL reductions");
+      !isUnsupportedFloat8(tensors.back().scalar_type()),
+      "Unsupported Float8 type for NCCL reduction");
 
   RECORD_PARAM_COMMS_DATA(
       std::make_tuple(
@@ -4656,8 +4674,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
   check_gpu_single_tensor(outputTensor);
   auto inputTensors_ = inputTensors.back();
   TORCH_CHECK(
-      !isFloat8Type(outputTensor.scalar_type()),
-      "Float8 dtypes are not currenlty supported for NCCL reductions");
+      !isUnsupportedFloat8(outputTensor.scalar_type()),
+      "Unsupported Float8 type for NCCL reduction");
 
   RECORD_PARAM_COMMS_DATA(
       std::make_tuple(
@@ -4760,8 +4778,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
   const auto& tensor = outputTensor;
   TORCH_CHECK(
-      !isFloat8Type(tensor.scalar_type()),
-      "Float8 dtypes are not currenlty supported for NCCL reductions");
+      !isUnsupportedFloat8(tensor.scalar_type()),
+      "Unsupported Float8 type for NCCL reduction");
 
   RECORD_PARAM_COMMS_DATA(
       std::make_tuple(
          static_cast<int64_t>(seqCollective_) + 1,
@@ -4819,8 +4837,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
     std::vector<at::Tensor>& inputs,
     const ReduceScatterOptions& opts) {
   TORCH_CHECK(
-      !isFloat8Type(inputs.back().scalar_type()),
-      "Float8 dtypes are not currenlty supported for NCCL reductions");
+      !isUnsupportedFloat8(inputs.back().scalar_type()),
+      "Unsupported Float8 type for NCCL reduction");
 
   RECORD_PARAM_COMMS_DATA(
       std::make_tuple(