mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PGNCCL] Add FP8 support (#152706)
NCCL added support for `Float8e4m3` and `Float8e5m2` in 2.24. NVIDIA GPUs does not seem to support the following "no negative zero" versions: `Float8_e4m3fnuz` and `Float8_e5m2fnuz`, see https://onnx.ai/onnx/technical/float8.html. So we continue to error out for these two upon a reduction op. Test plan: - test_allreduce_float8 - test_reduce_scatter_float8 Resolves #148344 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152706 Approved by: https://github.com/d4l3k, https://github.com/eqy, https://github.com/fduwjj, https://github.com/cyyever
This commit is contained in:
@ -3273,27 +3273,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
|
||||
),
|
||||
)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_all_reduce_coalesced_nccl_float8_errors(self):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
c10d.init_process_group(
|
||||
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
|
||||
)
|
||||
process_group = c10d.distributed_c10d._get_default_group()
|
||||
device = torch.device(f"cuda:{self.rank:d}")
|
||||
tensors = [
|
||||
torch.full(
|
||||
(60 + i,), self.rank + 1 + i, device=device, dtype=torch.float
|
||||
).to(torch.float8_e4m3fn)
|
||||
for i in range(5)
|
||||
]
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Float8 dtypes are not currenlty supported for NCCL reductions",
|
||||
):
|
||||
torch.distributed.all_reduce_coalesced(tensors, group=process_group)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_all_reduce_coalesced_manager_nccl(self):
|
||||
@ -3685,56 +3664,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
|
||||
dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
|
||||
self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_reduce_scatter_base_k_float8_errors(self):
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
dist.init_process_group(
|
||||
"nccl",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
output_tensor = (
|
||||
torch.zeros(2, dtype=torch.float32).to(torch.float8_e4m3fn).to(self.rank)
|
||||
)
|
||||
input_tensors = (
|
||||
torch.arange(self.world_size * 2, dtype=torch.float32)
|
||||
.to(torch.float8_e4m3fn)
|
||||
.to(self.rank)
|
||||
)
|
||||
input_tensors = torch.reshape(input_tensors, (self.world_size, 2))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Float8 dtypes are not currenlty supported for NCCL reductions",
|
||||
):
|
||||
dist.reduce_scatter_tensor(output_tensor, input_tensors)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_reduce_scatter_tensor_coalesced_float8_errors(self):
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
dist.init_process_group(
|
||||
"nccl",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
output_tensors = torch.zeros(2, 2).to(torch.float8_e5m2).to(self.rank)
|
||||
input_tensors = [
|
||||
torch.ones(2, 2).to(torch.float8_e5m2).to(self.rank)
|
||||
for _ in range(self.world_size)
|
||||
]
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Float8 dtypes are not currenlty supported for NCCL reductions",
|
||||
):
|
||||
with dist._coalescing_manager():
|
||||
for i in range(self.world_size):
|
||||
dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
|
||||
self.assertEqual(output_tensors, input_tensors[self.rank])
|
||||
|
||||
|
||||
class SetDeviceMethod(Enum):
|
||||
TORCH_CUDA_SET = auto() # torch.cuda.set_device
|
||||
|
@ -28,6 +28,8 @@ from torch.testing._internal.common_distributed import (
|
||||
init_multigpu_helper,
|
||||
MultiProcContinousTest,
|
||||
requires_nccl,
|
||||
requires_nccl_version,
|
||||
sm_is_or_higher_than,
|
||||
TEST_SKIPS,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -243,6 +245,24 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
|
||||
with self.assertRaisesRegex(ValueError, "Cannot use " + err + " with NCCL"):
|
||||
allreduce(tensors, op)
|
||||
|
||||
@requires_nccl_version((2, 24), "Need NCCL 2.24+ for Float8")
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_allreduce_float8(self):
|
||||
device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
|
||||
if not sm_is_or_higher_than(device, 9, 0):
|
||||
self.skipTest("Float8 requires sm >= 90")
|
||||
|
||||
numel = 1024
|
||||
tensor = torch.ones(numel, dtype=torch.float32, device=device).to(
|
||||
torch.float8_e4m3fn
|
||||
)
|
||||
dist.all_reduce(tensor)
|
||||
|
||||
expected = (
|
||||
torch.empty_like(tensor).fill_(self.world_size).to(torch.float8_e4m3fn)
|
||||
)
|
||||
torch.testing.assert_close(tensor, expected)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_alltoall_ops_with_cudafree_race(self):
|
||||
@ -892,6 +912,27 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
|
||||
# Verification
|
||||
self.assertEqual(output_t[0], self.rank * self.world_size)
|
||||
|
||||
@requires_nccl_version((2, 24), "Need NCCL 2.24+ for Float8")
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_reduce_scatter_float8(self):
|
||||
device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
|
||||
if not sm_is_or_higher_than(device, 9, 0):
|
||||
self.skipTest("Float8 requires sm >= 90")
|
||||
|
||||
numel = 1024
|
||||
output_tensor = torch.zeros(numel, dtype=torch.float32, device=device).to(
|
||||
torch.float8_e5m2
|
||||
)
|
||||
input_tensor = torch.ones(
|
||||
self.world_size * numel, dtype=torch.float32, device=device
|
||||
).to(torch.float8_e5m2)
|
||||
dist.reduce_scatter_tensor(output_tensor, input_tensor)
|
||||
|
||||
expected = (
|
||||
torch.empty_like(output_tensor).fill_(self.world_size).to(torch.float8_e5m2)
|
||||
)
|
||||
torch.testing.assert_close(output_tensor, expected)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_barrier(self):
|
||||
|
@ -70,6 +70,10 @@ static_assert(
|
||||
#define NCCL_HAS_QOS
|
||||
#endif
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 24, 0)
|
||||
#define NCCL_SUPPORTS_FP8
|
||||
#endif
|
||||
|
||||
// Macro to throw on a non-successful NCCL return value.
|
||||
#define C10D_NCCL_CHECK(cmd, failureReason) \
|
||||
do { \
|
||||
|
@ -66,8 +66,15 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
|
||||
{at::kLong, ncclInt64},
|
||||
{at::kHalf, ncclHalf},
|
||||
{at::kBool, ncclUint8},
|
||||
#ifdef NCCL_SUPPORTS_FP8
|
||||
{at::kFloat8_e5m2, ncclFloat8e5m2},
|
||||
{at::kFloat8_e4m3fn, ncclFloat8e4m3},
|
||||
#else
|
||||
{at::kFloat8_e5m2, ncclUint8},
|
||||
{at::kFloat8_e4m3fn, ncclUint8},
|
||||
#endif
|
||||
// NVIDIA GPUs does not support the UZ version standing for "no negative
|
||||
// zero". See https://onnx.ai/onnx/technical/float8.html
|
||||
{at::kFloat8_e4m3fnuz, ncclUint8},
|
||||
{at::kFloat8_e5m2fnuz, ncclUint8},
|
||||
#if HAS_NCCL_BF16_DATATYPE
|
||||
@ -75,6 +82,17 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
|
||||
#endif // HAS_NCCL_BF16_DATATYPE
|
||||
};
|
||||
|
||||
inline bool isUnsupportedFloat8(at::ScalarType t) {
|
||||
return (
|
||||
t == at::ScalarType::Float8_e5m2fnuz ||
|
||||
t == at::ScalarType::Float8_e4m3fnuz ||
|
||||
t == at::ScalarType::Float8_e8m0fnu
|
||||
#ifndef NCCL_SUPPORTS_FP8
|
||||
|| t == at::ScalarType::Float8_e5m2 || t == at::ScalarType::Float8_e4m3fn
|
||||
#endif
|
||||
);
|
||||
}
|
||||
|
||||
// Helper function that gets the data type and issues error if not supported
|
||||
ncclDataType_t getNcclDataType(at::ScalarType type) {
|
||||
auto it = ncclDataType.find(type);
|
||||
@ -4111,8 +4129,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
|
||||
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
|
||||
auto tensor = tensors.back();
|
||||
TORCH_CHECK(
|
||||
!isFloat8Type(tensor.scalar_type()),
|
||||
"Float8 dtypes are not currenlty supported for NCCL reductions");
|
||||
!isUnsupportedFloat8(tensor.scalar_type()),
|
||||
"Unsupported Float8 type for NCCL reduction");
|
||||
#ifdef IS_NCCLX
|
||||
tensor = tensor.coalesce();
|
||||
at::Tensor outputTensor =
|
||||
@ -4231,8 +4249,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(
|
||||
!isFloat8Type(tensor.scalar_type()),
|
||||
"Float8 dtypes are not currenlty supported for NCCL reductions");
|
||||
!isUnsupportedFloat8(tensor.scalar_type()),
|
||||
"Unsupported Float8 type for NCCL reduction");
|
||||
RECORD_PARAM_COMMS_DATA(
|
||||
std::make_tuple(
|
||||
static_cast<int64_t>(seqCollective_) + 1,
|
||||
@ -4260,8 +4278,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
|
||||
const AllreduceCoalescedOptions& opts) {
|
||||
auto total_numel = check_gpu_tensors_same_device(tensors);
|
||||
TORCH_CHECK(
|
||||
!isFloat8Type(tensors.back().scalar_type()),
|
||||
"Float8 dtypes are not currenlty supported for NCCL reductions");
|
||||
!isUnsupportedFloat8(tensors.back().scalar_type()),
|
||||
"Unsupported Float8 type for NCCL reduction");
|
||||
|
||||
RECORD_PARAM_COMMS_DATA(
|
||||
std::make_tuple(
|
||||
@ -4656,8 +4674,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
|
||||
check_gpu_single_tensor(outputTensor);
|
||||
auto inputTensors_ = inputTensors.back();
|
||||
TORCH_CHECK(
|
||||
!isFloat8Type(outputTensor.scalar_type()),
|
||||
"Float8 dtypes are not currenlty supported for NCCL reductions");
|
||||
!isUnsupportedFloat8(outputTensor.scalar_type()),
|
||||
"Unsupported Float8 type for NCCL reduction");
|
||||
|
||||
RECORD_PARAM_COMMS_DATA(
|
||||
std::make_tuple(
|
||||
@ -4760,8 +4778,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
|
||||
|
||||
const auto& tensor = outputTensor;
|
||||
TORCH_CHECK(
|
||||
!isFloat8Type(tensor.scalar_type()),
|
||||
"Float8 dtypes are not currenlty supported for NCCL reductions");
|
||||
!isUnsupportedFloat8(tensor.scalar_type()),
|
||||
"Unsupported Float8 type for NCCL reduction");
|
||||
RECORD_PARAM_COMMS_DATA(
|
||||
std::make_tuple(
|
||||
static_cast<int64_t>(seqCollective_) + 1,
|
||||
@ -4819,8 +4837,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
|
||||
std::vector<at::Tensor>& inputs,
|
||||
const ReduceScatterOptions& opts) {
|
||||
TORCH_CHECK(
|
||||
!isFloat8Type(inputs.back().scalar_type()),
|
||||
"Float8 dtypes are not currenlty supported for NCCL reductions");
|
||||
!isUnsupportedFloat8(inputs.back().scalar_type()),
|
||||
"Unsupported Float8 type for NCCL reduction");
|
||||
|
||||
RECORD_PARAM_COMMS_DATA(
|
||||
std::make_tuple(
|
||||
|
Reference in New Issue
Block a user