Files
pytorch/torch/distributed/rpc/rref_proxy.py
Rohan Varma a5fb12d168 RRef proxy support for ScriptModule methods (#48339)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48339

Closes https://github.com/pytorch/pytorch/issues/48294
https://github.com/pytorch/pytorch/pull/48293 added creation and transfer of ScriptModule over RPC in Python, but calling ScriptModule methods through the RRef proxy helpers did not work.

This PR makes the above work with ScriptModule as per a discussion with mrshenli:
1) We remove the `hasattr()` check and just let Python throw the exception as it would when accessing the py function with `getattr`
2) We condition on `issubclass(type, ScriptModule)` when checking whether the method is wrapped with `async_execution`, and skip that check for ScriptModules. A ScriptModule method cannot be looked up with `getattr` on the class, because ScriptModule `forward`/methods are TorchScript-specific functions rather than Python functions:
```
torch/jit/_script.py", line 229, in __get__
    return self.__getattr__("forward")  # type: ignore
AttributeError: '_CachedForward' object has no attribute '__getattr__'
```
ghstack-source-id: 117631795

Test Plan: Modified ut

Reviewed By: wanchaol

Differential Revision: D25134423

fbshipit-source-id: 918ca88891c7b0531325f046b61f28947575cff0
2020-12-04 11:33:16 -08:00

41 lines
1.2 KiB
Python

from functools import partial
from . import functions
import torch
def _local_invoke(rref, func_name, args, kwargs):
return getattr(rref.local_value(), func_name)(*args, **kwargs)
@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
    """Call ``func_name`` on ``rref``'s local value, marked for async execution.

    Performs the same lookup-and-call as the plain local dispatcher. It is
    wrapped with ``functions.async_execution`` so that ``_invoke_rpc`` can
    pick it when the target method carries the
    ``_wrapped_async_rpc_function`` marker.
    NOTE(review): presumably the target method returns a Future in that
    case; confirm against ``functions.async_execution``.

    Args:
        rref: the RRef whose local value is the call target.
        func_name: name of the method to look up on that value.
        args: positional arguments forwarded to the method.
        kwargs: keyword arguments forwarded to the method.

    Returns:
        Whatever the invoked method returns.
    """
    owned = rref.local_value()
    bound_method = getattr(owned, func_name)
    return bound_method(*args, **kwargs)
def _invoke_rpc(rref, rpc_api, func_name, *args, **kwargs):
    """Run method ``func_name`` on the object behind ``rref`` via ``rpc_api``.

    Picks the local dispatcher to ship to the RRef's owner. The
    async-execution variant is used when the method found on the RRef's type
    has a ``_wrapped_async_rpc_function`` attribute; otherwise the plain
    dispatcher is used.

    Args:
        rref: the RRef whose value the method is called on.
        rpc_api: callable invoked as ``rpc_api(owner, fn, args=...)``.
            Presumably one of ``rpc_sync`` / ``rpc_async`` / ``remote``;
            confirm with callers.
        func_name: name of the method to call.
        *args: positional arguments forwarded to the method.
        **kwargs: keyword arguments forwarded to the method.

    Returns:
        Whatever ``rpc_api`` returns.

    Raises:
        AttributeError: if ``func_name`` is not an attribute of a
            non-ScriptModule RRef type. The lookup happens locally, before
            any RPC is issued.
    """
    rref_type = rref._get_type()

    # ScriptModule methods are TorchScript functions, not Python functions,
    # so looking them up on the class fails. Skip the async-marker check
    # entirely for both ScriptModule flavours.
    is_script_module = issubclass(
        rref_type, (torch.jit.ScriptModule, torch._C.ScriptModule)
    )

    wants_async = False
    if not is_script_module:
        method = getattr(rref_type, func_name)
        wants_async = hasattr(method, "_wrapped_async_rpc_function")

    dispatcher = _local_invoke_async_execution if wants_async else _local_invoke
    return rpc_api(
        rref.owner(),
        dispatcher,
        args=(rref, func_name, args, kwargs),
    )
class RRefProxy:
def __init__(self, rref, rpc_api):
self.rref = rref
self.rpc_api = rpc_api
def __getattr__(self, func_name):
return partial(_invoke_rpc, self.rref, self.rpc_api, func_name)