Revert "[SymmetricMemory] improve the API for stream_write_value32 (#139934)"

This reverts commit 2f3a5a15ef701ffab9a880cf822ff8e5224a4b33.

Reverted https://github.com/pytorch/pytorch/pull/139934 on behalf of https://github.com/malfet due to Broke distributed tests, see https://github.com/pytorch/pytorch/actions/runs/11770673088/job/32784210441 ([comment](https://github.com/pytorch/pytorch/pull/139934#issuecomment-2468641512))
PyTorch MergeBot
2024-11-11 17:02:05 +00:00
parent 2fe110ff3a
commit 5f4a21dc58
8 changed files with 63 additions and 112 deletions
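For orientation: this revert drops the tensor + offset static overload of `stream_write_value32` introduced in #139934 and restores the handle method `stream_write_value32(addr, val)` that takes a raw device address. A minimal sketch of the restored usage, pieced together from the test hunk below (the `torch._C._distributed_c10d` import path and the process-group setup are assumptions; they are not shown in this diff):

```python
import torch
import torch.distributed as dist
from torch._C._distributed_c10d import _SymmetricMemory  # assumed import path

# Assumes dist.init_process_group(...) has run with one CUDA device per rank,
# as in SymmetricMemoryTest below.
t = _SymmetricMemory.empty_strided_p2p(
    size=(64,),
    stride=(1,),
    dtype=torch.float32,
    device=torch.device("cuda", torch.cuda.current_device()),
    group_name=dist.group.WORLD.group_name,
)
symm_mem = _SymmetricMemory.rendezvous(t)

signals = torch.zeros(4, dtype=torch.uint32, device=t.device)
# Restored API: pass a raw device address plus a uint32 value.
symm_mem.stream_write_value32(
    int(signals.data_ptr()) + 2 * signals.element_size(), 1
)

# Reverted API from #139934 (removed by this commit): a static method taking
# (tensor, element offset, value) instead of a raw address.
# _SymmetricMemory.stream_write_value32(signals, 2, 1)
```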

View File

@@ -19,7 +19,7 @@ from torch.distributed._symmetric_memory import (
restride_A_for_fused_matmul_reduce_scatter,
restride_A_shard_for_fused_all_gather_matmul,
)
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM90OrLater
from torch.testing._internal.common_cuda import SM90OrLater
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
@@ -624,6 +624,30 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_stream_write_value(self):
self._init_process()
group_name = dist.group.WORLD.group_name
t = _SymmetricMemory.empty_strided_p2p(
size=(64,),
stride=(1,),
dtype=torch.float32,
device=self.device,
group_name=group_name,
).fill_(self.rank + 42)
symm_mem = _SymmetricMemory.rendezvous(t)
tensor = torch.zeros(4, dtype=torch.uint32, device=self.device)
expect = torch.tril(torch.ones(4, 4, device=self.device)).to(torch.uint32)
for i in range(4):
symm_mem.stream_write_value32(
int(tensor.data_ptr()) + i * tensor.element_size(), 1
)
torch.testing.assert_close(tensor, expect[i])
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
@@ -880,26 +904,6 @@ class LoweringTest(MultiProcessTestCase):
class SymmMemSingleProcTest(TestCase):
@skipIfRocm
@requires_cuda
@skipIf(
_get_torch_cuda_version() < (12, 0),
"stream_write_value32 currently only supports cuda version>=12.0",
)
def test_stream_write_value32(self):
tensor = torch.zeros(4, dtype=torch.uint32, device="cuda")
expect = torch.tril(torch.ones(4, 4, device="cuda")).to(torch.uint32)
for i in range(4):
_SymmetricMemory.stream_write_value32(tensor, i, 1)
torch.testing.assert_close(tensor, expect[i])
with self.assertRaises(RuntimeError):
_SymmetricMemory.stream_write_value32(tensor, offset=-1, val=1)
with self.assertRaises(RuntimeError):
_SymmetricMemory.stream_write_value32(tensor, offset=0, val=4294967296)
@skipIfRocm
@requires_cuda
def test_memset32(self):

View File

@@ -671,11 +671,4 @@ class _SymmetricMemory:
def barrier(self, channel: int = 0) -> None: ...
def put_signal(self, dst_rank: int, channel: int = 0) -> None: ...
def wait_signal(self, src_rank: int, channel: int = 0) -> None: ...
@staticmethod
def memset32(
tensor: torch.Tensor, offset: int, val: int, count: int
) -> torch.Tensor: ...
@staticmethod
def stream_write_value32(
tensor: torch.Tensor, offset: int, val: int
) -> torch.Tensor: ...
def stream_write_value32(self, addr: int, val: int) -> None: ...

View File

@@ -587,6 +587,24 @@ int CUDASymmetricMemory::get_world_size() {
return world_size_;
}
void CUDASymmetricMemory::stream_write_value32(uintptr_t addr, uint32_t val) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
auto driver_api = c10::cuda::DriverAPI::get();
// According to the documentation of CUstreamWriteValue_flags,
// cuStreamWriteValue32 will provide a memory fence before the write, which
// has similar semantics to __threadfence_system() but is scoped to the
// stream rather than a CUDA thread.
driver_api->cuStreamWriteValue32_(
at::cuda::getCurrentCUDAStream(),
reinterpret_cast<CUdeviceptr>((void*)addr),
val,
0);
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}
void* CUDASymmetricMemoryAllocator::alloc(
size_t size,
int device_idx,

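The comment in this hunk carries the key semantic: per the CUstreamWriteValue_flags documentation, cuStreamWriteValue32 performs a memory fence before the write, scoped to the stream. A rough producer-side sketch of the pattern that relies on this, under the same assumptions as the sketch above (a consumer kernel on another stream, such as the `_async_input_mm` call in the last hunk of this diff, would poll the flag and only then read the buffer):

```python
import torch
import torch.distributed as dist
from torch._C._distributed_c10d import _SymmetricMemory  # assumed import path

# Assumes an initialized process group and one CUDA device per rank.
buf = _SymmetricMemory.empty_strided_p2p(
    size=(1024,),
    stride=(1,),
    dtype=torch.float32,
    device=torch.device("cuda", torch.cuda.current_device()),
    group_name=dist.group.WORLD.group_name,
)
symm_mem = _SymmetricMemory.rendezvous(buf)
flag = torch.zeros(1, dtype=torch.uint32, device=buf.device)

# Fill the buffer, then flip the flag, both on the current stream. Because the
# driver call fences before the 32-bit write, an observer that sees flag == 1
# also sees the completed buffer writes.
buf.fill_(float(dist.get_rank()))
symm_mem.stream_write_value32(int(flag.data_ptr()), 1)
```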
View File

@@ -57,6 +57,8 @@ class CUDASymmetricMemory : public SymmetricMemory {
int get_rank() override;
int get_world_size() override;
void stream_write_value32(uintptr_t addr, uint32_t val) override;
private:
std::vector<HandleType> handles_;
size_t block_size_;

View File

@@ -513,16 +513,8 @@ at::Tensor memset32_(
"symm_mem::memset32_: input must be a flat, contiguous uint32 tensor.");
TORCH_CHECK(
offset >= 0,
"symm_mem::memset32_: offset must be greater than or equal to 0 (got ",
offset,
")");
TORCH_CHECK(
count > 0,
"symm_mem::memset32_: count must be a positive integer (got ",
count,
")");
offset > 0 && count > 0,
"symm_mem::memset32_: offset and count must be positive integers.");
TORCH_CHECK(
val >= 0 &&
@@ -555,59 +547,6 @@ at::Tensor memset32_(
return input;
}
at::Tensor stream_write_value32_(
at::Tensor& input,
int64_t offset,
int64_t val) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
TORCH_CHECK(
input.dim() == 1 && input.is_contiguous() &&
input.scalar_type() == c10::ScalarType::UInt32,
"symm_mem::stream_write_value32_: input must be a flat, contiguous "
"uint32 tensor.");
TORCH_CHECK(
offset >= 0,
"symm_mem::stream_write_value32_: offset must be greater than or "
"equal to 0 (got ",
offset,
")");
TORCH_CHECK(
val >= 0 &&
static_cast<size_t>(val) <= std::numeric_limits<uint32_t>::max(),
"symm_mem::stream_write_value32_: "
"val must be in the range of [0, 4294967295] (uint32_t).")
auto element_size = c10::elementSize(input.scalar_type());
TORCH_CHECK(
offset < input.numel(),
"symm_mem::stream_write_value32_: offset (",
offset,
") exceeded the numel of the input (",
input.numel(),
")");
auto addr = reinterpret_cast<uint32_t*>(input.data_ptr()) + offset;
c10::cuda::CUDAGuard guard(input.device());
auto driver_api = c10::cuda::DriverAPI::get();
// According to the documentation of CUstreamWriteValue_flags,
// cuStreamWriteValue32 will provide a memory fence before the write, which
// has similar semantics to __threadfence_system() but is scoped to the
// stream rather than a CUDA thread.
C10_CUDA_DRIVER_CHECK(driver_api->cuStreamWriteValue32_(
at::cuda::getCurrentCUDAStream(),
reinterpret_cast<CUdeviceptr>(addr),
val,
0));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
return input;
}
} // namespace
TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
@@ -672,11 +611,6 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
{at::Tag::pt2_compliant_tag});
#endif
m.def(
"stream_write_value32_(Tensor(a!) input, int offset, int val) -> Tensor(a!)",
torch::dispatch(c10::DispatchKey::CUDA, ::stream_write_value32_),
{at::Tag::pt2_compliant_tag});
m.def(
"memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)",
torch::dispatch(c10::DispatchKey::CUDA, ::memset32_),

View File

@@ -71,6 +71,8 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
virtual int get_rank() = 0;
virtual int get_world_size() = 0;
virtual void stream_write_value32(uintptr_t addr, uint32_t val) = 0;
};
class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {

View File

@@ -1118,21 +1118,13 @@ This class does not support ``__members__`` property.)");
py::arg("src_rank"),
py::arg("channel") = 0,
py::arg("timeout_ms") = 0)
.def(
"stream_write_value32",
&SymmetricMemory::stream_write_value32,
py::arg("addr"),
py::arg("val"))
// Util functions that are often used together with symmetric memory but
// not necessarily directly on symmetric memory.
.def_static(
"stream_write_value32",
[](at::Tensor& input, int64_t offset, int64_t val) {
// The range of `val` is checked inside the op
auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("symm_mem::stream_write_value32_", "")
.typed<at::Tensor(at::Tensor&, int64_t, int64_t)>();
return op.call(input, offset, val);
},
py::arg("input"),
py::arg("offset"),
py::arg("val"))
.def_static(
"memset32",
[](at::Tensor& input, int64_t offset, int64_t val, int64_t count) {

View File

@@ -690,7 +690,10 @@ def _fused_all_gather_matmul_native(
A_shards = A.chunk(world_size)
A_shards[rank].copy_(A_shard)
_SymmetricMemory.stream_write_value32(A_signals, rank, 1)
symm_mem.stream_write_value32(
int(A_signals.data_ptr()) + rank * A_signals.element_size(),
1,
)
out = torch.ops.symm_mem._async_input_mm(A, B, A_signals, rank)
for step in range(1, world_size):
@@ -699,7 +702,10 @@ def _fused_all_gather_matmul_native(
with torch.cuda.stream(backend_stream):
A_shards[src_rank].copy_(src_buf)
# cuStreamWriteValue32 issues a system level fence before the write
_SymmetricMemory.stream_write_value32(A_signals, src_rank, 1)
symm_mem.stream_write_value32(
int(A_signals.data_ptr()) + src_rank * A_signals.element_size(),
1,
)
current_stream.wait_stream(backend_stream)
backend_stream.wait_stream(current_stream)
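Both replacements above compute the raw signal address by hand. A hypothetical convenience wrapper (not part of the PyTorch API) could recover the tensor + offset ergonomics of the reverted static method on top of the restored handle method:

```python
import torch

def write_u32(symm_mem, tensor: torch.Tensor, index: int, val: int) -> None:
    # Hypothetical helper: write `val` into `tensor[index]` through the restored
    # address-based stream_write_value32 handle method.
    assert tensor.dtype == torch.uint32 and tensor.is_contiguous()
    assert 0 <= index < tensor.numel()
    addr = int(tensor.data_ptr()) + index * tensor.element_size()
    symm_mem.stream_write_value32(addr, val)

# With it, the two calls in this hunk would read:
#   write_u32(symm_mem, A_signals, rank, 1)
#   write_u32(symm_mem, A_signals, src_rank, 1)
```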