Revert "[SymmetricMemory] improve the API for stream_write_value32 (#139934)"
This reverts commit 2f3a5a15ef701ffab9a880cf822ff8e5224a4b33. Reverted https://github.com/pytorch/pytorch/pull/139934 on behalf of https://github.com/malfet due to Broke distributed tests, see https://github.com/pytorch/pytorch/actions/runs/11770673088/job/32784210441 ([comment](https://github.com/pytorch/pytorch/pull/139934#issuecomment-2468641512))
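For orientation, the sketch below contrasts the two call shapes this revert swaps. #139934 exposed `stream_write_value32` as a static method on `_SymmetricMemory` that takes a uint32 tensor, an element offset, and a value; the revert restores the earlier instance method on a rendezvous'd handle that takes a raw device address. This is a minimal illustration drawn from the test changes in the diff below: the helper function names are made up, and a CUDA build of PyTorch with an initialized process group and a rendezvous'd symmetric-memory handle are assumed.

```python
# Illustrative sketch only; both helpers are hypothetical wrappers around the
# call shapes shown in the diff below. Requires a CUDA build of PyTorch.
import torch
from torch._C._distributed_c10d import _SymmetricMemory


def signal_with_reverted_api(flags: torch.Tensor, rank: int) -> None:
    # Shape introduced by #139934 (removed by this revert): static method,
    # (uint32 tensor, element offset, value).
    _SymmetricMemory.stream_write_value32(flags, rank, 1)


def signal_with_restored_api(symm_mem, flags: torch.Tensor, rank: int) -> None:
    # Shape restored by this revert: instance method on a rendezvous'd handle,
    # (raw device address, value).
    symm_mem.stream_write_value32(
        int(flags.data_ptr()) + rank * flags.element_size(), 1
    )
```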
@@ -19,7 +19,7 @@ from torch.distributed._symmetric_memory import (
     restride_A_for_fused_matmul_reduce_scatter,
     restride_A_shard_for_fused_all_gather_matmul,
 )
-from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM90OrLater
+from torch.testing._internal.common_cuda import SM90OrLater
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     skip_if_lt_x_gpu,
@@ -624,6 +624,30 @@ class SymmetricMemoryTest(MultiProcessTestCase):
 
         dist.destroy_process_group()
 
+    @skipIfRocm
+    @skip_if_lt_x_gpu(2)
+    def test_stream_write_value(self):
+        self._init_process()
+        group_name = dist.group.WORLD.group_name
+
+        t = _SymmetricMemory.empty_strided_p2p(
+            size=(64,),
+            stride=(1,),
+            dtype=torch.float32,
+            device=self.device,
+            group_name=group_name,
+        ).fill_(self.rank + 42)
+        symm_mem = _SymmetricMemory.rendezvous(t)
+
+        tensor = torch.zeros(4, dtype=torch.uint32, device=self.device)
+        expect = torch.tril(torch.ones(4, 4, device=self.device)).to(torch.uint32)
+
+        for i in range(4):
+            symm_mem.stream_write_value32(
+                int(tensor.data_ptr()) + i * tensor.element_size(), 1
+            )
+            torch.testing.assert_close(tensor, expect[i])
+
 
 @instantiate_parametrized_tests
 @requires_cuda_p2p_access()
@@ -880,26 +904,6 @@ class LoweringTest(MultiProcessTestCase):
 
 
 class SymmMemSingleProcTest(TestCase):
-    @skipIfRocm
-    @requires_cuda
-    @skipIf(
-        _get_torch_cuda_version() < (12, 0),
-        "stream_write_value32 currently only supports cuda version>=12.0",
-    )
-    def test_stream_write_value32(self):
-        tensor = torch.zeros(4, dtype=torch.uint32, device="cuda")
-        expect = torch.tril(torch.ones(4, 4, device="cuda")).to(torch.uint32)
-
-        for i in range(4):
-            _SymmetricMemory.stream_write_value32(tensor, i, 1)
-            torch.testing.assert_close(tensor, expect[i])
-
-        with self.assertRaises(RuntimeError):
-            _SymmetricMemory.stream_write_value32(tensor, offset=-1, val=1)
-
-        with self.assertRaises(RuntimeError):
-            _SymmetricMemory.stream_write_value32(tensor, offset=0, val=4294967296)
-
     @skipIfRocm
     @requires_cuda
     def test_memset32(self):
@@ -671,11 +671,4 @@ class _SymmetricMemory:
     def barrier(self, channel: int = 0) -> None: ...
     def put_signal(self, dst_rank: int, channel: int = 0) -> None: ...
     def wait_signal(self, src_rank: int, channel: int = 0) -> None: ...
-    @staticmethod
-    def memset32(
-        tensor: torch.Tensor, offset: int, val: int, count: int
-    ) -> torch.Tensor: ...
-    @staticmethod
-    def stream_write_value32(
-        tensor: torch.Tensor, offset: int, val: int
-    ) -> torch.Tensor: ...
+    def stream_write_value32(self, addr: int, val: int) -> None: ...
@@ -587,6 +587,24 @@ int CUDASymmetricMemory::get_world_size() {
   return world_size_;
 }
 
+void CUDASymmetricMemory::stream_write_value32(uintptr_t addr, uint32_t val) {
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+  auto driver_api = c10::cuda::DriverAPI::get();
+  // According to the documentation of CUstreamWriteValue_flags,
+  // cuStreamWriteValue32 will provide a memory fence before the write, which
+  // has similar semantics to __threadfence_system() but is scoped to the
+  // stream rather than a CUDA thread.
+  driver_api->cuStreamWriteValue32_(
+      at::cuda::getCurrentCUDAStream(),
+      reinterpret_cast<CUdeviceptr>((void*)addr),
+      val,
+      0);
+#else
+  TORCH_CHECK(
+      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
+#endif
+}
+
 void* CUDASymmetricMemoryAllocator::alloc(
     size_t size,
     int device_idx,
@@ -57,6 +57,8 @@ class CUDASymmetricMemory : public SymmetricMemory {
   int get_rank() override;
   int get_world_size() override;
 
+  void stream_write_value32(uintptr_t addr, uint32_t val) override;
+
  private:
   std::vector<HandleType> handles_;
   size_t block_size_;
@@ -513,16 +513,8 @@ at::Tensor memset32_(
       "symm_mem::memset32_: input must be a flat, contiguous uint32 tensor.");
 
   TORCH_CHECK(
-      offset >= 0,
-      "symm_mem::memset32_: offset must be greater than or equal to 0 (got ",
-      offset,
-      ")");
-
-  TORCH_CHECK(
-      count > 0,
-      "symm_mem::memset32_: count must be a positive integer (got ",
-      count,
-      ")");
+      offset > 0 && count > 0,
+      "symm_mem::memset32_: offset and count must be positive integers.");
 
   TORCH_CHECK(
       val >= 0 &&
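The check change above is not purely cosmetic: the refined checks being removed accept `offset == 0`, while the restored combined check requires both `offset` and `count` to be strictly positive. A hedged illustration follows; it uses positional arguments because the keyword names of the `memset32` binding are not shown in this diff, and it assumes a CUDA device and a build that exposes `_SymmetricMemory.memset32`.

```python
# Illustration only; assumes a CUDA device and a build exposing
# _SymmetricMemory.memset32 (see the pybind def_static further below).
import torch
from torch._C._distributed_c10d import _SymmetricMemory

t = torch.zeros(8, dtype=torch.uint32, device="cuda")

# Accepted by both versions of the check: strictly positive offset and count.
_SymmetricMemory.memset32(t, 2, 1, 4)  # writes 1 into elements 2..5

# offset == 0 passes the refined check being removed here (offset >= 0) but
# fails the restored combined check (offset > 0 && count > 0), so it is left
# commented out:
# _SymmetricMemory.memset32(t, 0, 1, 4)
```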
@@ -555,59 +547,6 @@ at::Tensor memset32_(
   return input;
 }
 
-at::Tensor stream_write_value32_(
-    at::Tensor& input,
-    int64_t offset,
-    int64_t val) {
-#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
-  TORCH_CHECK(
-      input.dim() == 1 && input.is_contiguous() &&
-          input.scalar_type() == c10::ScalarType::UInt32,
-      "symm_mem::stream_write_value32_: input must be a flat, contiguous "
-      "uint32 tensor.");
-
-  TORCH_CHECK(
-      offset >= 0,
-      "symm_mem::stream_write_value32_: offset must be greater than or "
-      "equal to 0 (got ",
-      offset,
-      ")");
-
-  TORCH_CHECK(
-      val >= 0 &&
-          static_cast<size_t>(val) <= std::numeric_limits<uint32_t>::max(),
-      "symm_mem::stream_write_value32_: "
-      "val must be in the range of [0, 4294967295] (uint32_t).")
-
-  auto element_size = c10::elementSize(input.scalar_type());
-  TORCH_CHECK(
-      offset < input.numel(),
-      "symm_mem::stream_write_value32_: offset (",
-      offset,
-      ") exceeded the numel of the input (",
-      input.numel(),
-      ")");
-
-  auto addr = reinterpret_cast<uint32_t*>(input.data_ptr()) + offset;
-
-  c10::cuda::CUDAGuard guard(input.device());
-  auto driver_api = c10::cuda::DriverAPI::get();
-  // According to the documentation of CUstreamWriteValue_flags,
-  // cuStreamWriteValue32 will provide a memory fence before the write, which
-  // has similar semantics to __threadfence_system() but is scoped to the
-  // stream rather than a CUDA thread.
-  C10_CUDA_DRIVER_CHECK(driver_api->cuStreamWriteValue32_(
-      at::cuda::getCurrentCUDAStream(),
-      reinterpret_cast<CUdeviceptr>(addr),
-      val,
-      0));
-#else
-  TORCH_CHECK(
-      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
-#endif
-  return input;
-}
-
 } // namespace
 
 TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
@@ -672,11 +611,6 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
       {at::Tag::pt2_compliant_tag});
 
 #endif
-  m.def(
-      "stream_write_value32_(Tensor(a!) input, int offset, int val) -> Tensor(a!)",
-      torch::dispatch(c10::DispatchKey::CUDA, ::stream_write_value32_),
-      {at::Tag::pt2_compliant_tag});
-
   m.def(
       "memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)",
       torch::dispatch(c10::DispatchKey::CUDA, ::memset32_),
@@ -71,6 +71,8 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
 
   virtual int get_rank() = 0;
   virtual int get_world_size() = 0;
+
+  virtual void stream_write_value32(uintptr_t addr, uint32_t val) = 0;
 };
 
 class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
@@ -1118,21 +1118,13 @@ This class does not support ``__members__`` property.)");
           py::arg("src_rank"),
           py::arg("channel") = 0,
           py::arg("timeout_ms") = 0)
+      .def(
+          "stream_write_value32",
+          &SymmetricMemory::stream_write_value32,
+          py::arg("addr"),
+          py::arg("val"))
       // Util functions that are often used together with symmetric memory but
       // not necessarily directly on symmetric memory.
-      .def_static(
-          "stream_write_value32",
-          [](at::Tensor& input, int64_t offset, int64_t val) {
-            // The range of `val` is checked inside the op
-            auto op =
-                c10::Dispatcher::singleton()
-                    .findSchemaOrThrow("symm_mem::stream_write_value32_", "")
-                    .typed<at::Tensor(at::Tensor&, int64_t, int64_t)>();
-            return op.call(input, offset, val);
-          },
-          py::arg("input"),
-          py::arg("offset"),
-          py::arg("val"))
       .def_static(
           "memset32",
          [](at::Tensor& input, int64_t offset, int64_t val, int64_t count) {
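The `def_static` block removed above is a thin wrapper: it looks up the `symm_mem::stream_write_value32_` schema in the dispatcher and calls the CUDA-registered op whose registration is removed a few hunks earlier. On a build that still contains that registration, the same functionality is reachable directly through `torch.ops`; a minimal sketch, assuming a CUDA device:

```python
# Sketch of the dispatcher path taken by the removed static binding; only
# meaningful on builds that still register symm_mem::stream_write_value32_.
import torch

flags = torch.zeros(8, dtype=torch.uint32, device="cuda")

# Equivalent to the removed _SymmetricMemory.stream_write_value32(flags, 3, 1):
# the pybind lambda resolves the schema and calls the CUDA-dispatched op.
torch.ops.symm_mem.stream_write_value32_(flags, 3, 1)

# memset32_ follows the same registration pattern and keeps its op after
# this revert (only its argument checks change).
torch.ops.symm_mem.memset32_(flags, 2, 1, 4)
```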
@@ -690,7 +690,10 @@ def _fused_all_gather_matmul_native(
     A_shards = A.chunk(world_size)
 
     A_shards[rank].copy_(A_shard)
-    _SymmetricMemory.stream_write_value32(A_signals, rank, 1)
+    symm_mem.stream_write_value32(
+        int(A_signals.data_ptr()) + rank * A_signals.element_size(),
+        1,
+    )
 
     out = torch.ops.symm_mem._async_input_mm(A, B, A_signals, rank)
     for step in range(1, world_size):
@@ -699,7 +702,10 @@ def _fused_all_gather_matmul_native(
         with torch.cuda.stream(backend_stream):
             A_shards[src_rank].copy_(src_buf)
             # cuStreamWriteValue32 issues a system level fence before the write
-            _SymmetricMemory.stream_write_value32(A_signals, src_rank, 1)
+            symm_mem.stream_write_value32(
+                int(A_signals.data_ptr()) + src_rank * A_signals.element_size(),
+                1,
+            )
 
     current_stream.wait_stream(backend_stream)
     backend_stream.wait_stream(current_stream)
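The two hunks above are the producer side of a copy-then-signal pattern: each shard copy on the backend stream is followed by a 32-bit flag write, and `torch.ops.symm_mem._async_input_mm` consumes shards as their flags become visible. Below is a condensed sketch of that producer step using the restored address-based API; `produce_shard` and its arguments are hypothetical stand-ins for the local variables of `_fused_all_gather_matmul_native`.

```python
# Minimal sketch of the copy-then-signal pattern used above; the helper name
# and its arguments are hypothetical. Assumes a rendezvous'd `symm_mem`
# handle, per-rank buffers `bufs`, and a flat uint32 `signals` tensor.
import torch


def produce_shard(symm_mem, bufs, signals, src_rank, src_buf, backend_stream):
    # Copy the shard on a side stream, then write 1 into that rank's signal
    # slot. cuStreamWriteValue32 fences prior work on the stream, so a
    # consumer that observes the flag also observes the copied data.
    with torch.cuda.stream(backend_stream):
        bufs[src_rank].copy_(src_buf)
        symm_mem.stream_write_value32(
            int(signals.data_ptr()) + src_rank * signals.element_size(),
            1,
        )
```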