[rpc] Switch RPC agent check to TORCH_CHECK and add more descriptive error (#67882)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67882

I ran into a hard-to-interpret error message when trying to run the following script, which was missing an `init_rpc` call:

```
# $ torchrun --standalone --nnodes=1 --nproc_per_node=1 script.py
import os
rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])

import torch.distributed
# !!!!!! Uncomment the following and the script succeeds
# torch.distributed.rpc.init_rpc('worker', rank=rank, world_size=world_size)

import torch.distributed as dist
dist.init_process_group(backend='gloo')

import torchvision.models as models
import torch

rn50 = models.resnet50()
rn50.train()
rn50 = torch.nn.parallel.DistributedDataParallel(rn50)

from torch.distributed.rpc import RRef
from torch.distributed.optim import DistributedOptimizer

params = []
for param in rn50.parameters():
    params.append(RRef(param))

dist_optim = DistributedOptimizer(
        torch.optim.SGD,
        params,
        lr=0.05)

loss_func = torch.nn.CrossEntropyLoss()

with torch.distributed.autograd.context() as context_id:
    pred = rn50(torch.randn(50, 3, 224, 224))
    target = torch.randn(50, 1000).softmax(dim=1)
    loss = loss_func(pred, target)
    dist.autograd.backward(context_id, [loss])
    dist_optim.step(context_id)
```

Error:

```
Traceback (most recent call last):
  File "/xxx/torchrun_exp/script.py", line 23, in <module>
    params.append(RRef(param))
RuntimeError: agentINTERNAL ASSERT FAILED at "../torch/csrc/distributed/rpc/rpc_agent.cpp":237, please report a bug to PyTorch. Current RPC agent is not set!
```

Since this is a user-facing error, I've changed `TORCH_INTERNAL_ASSERT` to `TORCH_CHECK` and added a hint about how to resolve the issue. On the other hand, the fact that this was originally `TORCH_INTERNAL_ASSERT` may suggest that the author thought that this should be an internal-only error condition. If there is some other place that should be throwing an exception in this case that is failing, let me know and I can adapt the fix to change that location.

Question for reviewers:
* Is there a good test file where I can add a test for this error condition?

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D32190947

Pulled By: jamesr66a

fbshipit-source-id: 3621d755329fd524db68675c55b1daf20e716d43
This commit is contained in:
James Reed
2021-11-05 17:29:06 -07:00
committed by Facebook GitHub Bot
parent efdb17b984
commit 22afe82ce3
2 changed files with 29 additions and 1 deletions

View File

@ -234,7 +234,10 @@ bool RpcAgent::isCurrentRpcAgentSet() {
std::shared_ptr<RpcAgent> RpcAgent::getCurrentRpcAgent() {
std::shared_ptr<RpcAgent> agent = std::atomic_load(&currentRpcAgent_);
TORCH_INTERNAL_ASSERT(agent, "Current RPC agent is not set!");
TORCH_CHECK(
agent,
"Current RPC agent is not set! Did you initialize the RPC "
"framework (e.g. by calling `rpc.init_rpc`)?");
return agent;
}