#!/usr/bin/env python3
# Owner(s): ["oncall: distributed"]
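
# Exercises sending a module held in shared memory over RPC: a custom pickler
# routes tensor and storage serialization through torch.multiprocessing's
# shared-memory reductions, so the underlying data is shared rather than copied.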

import contextlib
import copyreg
import os
import sys

import torch
import torch.distributed as dist


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import torch.distributed.rpc as rpc
import torch.multiprocessing.reductions as TorchMpReductions
from torch import multiprocessing
from torch.distributed.rpc.api import _use_rpc_pickler
from torch.distributed.rpc.internal import _InternalRPCPickler
from torch.testing._internal.common_utils import run_tests, TestCase


@contextlib.contextmanager
def fs_sharing():
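    # Temporarily switch torch.multiprocessing to the "file_system" sharing
    # strategy, restoring the previous strategy on exit.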
    prev_strategy = multiprocessing.get_sharing_strategy()
    multiprocessing.set_sharing_strategy("file_system")
    try:
        yield
    finally:
        multiprocessing.set_sharing_strategy(prev_strategy)


class ShareMemoryRPCPickler(_InternalRPCPickler):
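    # An _InternalRPCPickler whose dispatch table sends tensors and storages
    # through shared memory instead of serializing their contents.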
    def __init__(self) -> None:
        super().__init__()
        # pyre-fixme[4]: Attribute must be annotated.
        self._dispatch_table = copyreg.dispatch_table.copy()

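        # Route every storage and tensor class through the shared-memory
        # reductions from torch.multiprocessing rather than the default
        # copying pickle path.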
        for t in torch._storage_classes:
            self._dispatch_table[t] = TorchMpReductions.reduce_storage

        for t in torch._tensor_classes:
            self._dispatch_table[t] = TorchMpReductions.reduce_tensor
        self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor
        self._dispatch_table[torch.nn.parameter.Parameter] = (
            TorchMpReductions.reduce_tensor
        )


def worker_loop(a):
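    # Entry point for the spawned child process (`a` is the process index
    # supplied by multiprocessing.spawn): join the RPC group as worker1 and
    # block in shutdown until the exchange completes.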
    rpc.init_rpc("worker1", rank=1, world_size=2)
    rpc.shutdown()


def worker_fn(m):
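    # Remote call target: the body is empty on purpose; the test only checks
    # that the module argument survives the trip through the pickler.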
    pass


class TestRPCPickler(TestCase):
    def test_case(self):
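        # Spawn the peer process first, then drive the RPC exchange from this
        # process with the shared-memory pickler installed.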
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"

        with fs_sharing():
            r = multiprocessing.spawn(worker_loop, join=False)

            try:
                with _use_rpc_pickler(ShareMemoryRPCPickler()):
                    rpc.init_rpc("worker0", rank=0, world_size=2)
                    m = torch.nn.Linear(1, 2)
                    m.share_memory()
                    rref = rpc.remote("worker1", worker_fn, args=(m,))

                    rref.to_here()
            finally:
                rpc.shutdown()
                r.join()


if __name__ == "__main__":
    run_tests()