mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48339 Closes https://github.com/pytorch/pytorch/issues/48294 https://github.com/pytorch/pytorch/pull/48293 added creation and transfer of ScriptModule over RPC in Python, but the RRef proxy helpers (which invoke methods on the value held by an RRef) did not work when that value was a ScriptModule. This PR makes the above work with ScriptModule, as per a discussion with mrshenli: 1) We remove the `hasattr()` check and just let Python throw the exception, as it would when accessing the Python function with `getattr`. 2) We condition on `issubclass(type, ScriptModule)` when checking whether the function is wrapped with async_function, because `ScriptModule` does not have getattr implemented (ScriptModule forward/methods are not Python functions but TorchScript-specific functions): ``` File "torch/jit/_script.py", line 229, in __get__ return self.__getattr__("forward") # type: ignore AttributeError: '_CachedForward' object has no attribute '__getattr__' ``` ghstack-source-id: 117631795 Test Plan: Modified ut Reviewed By: wanchaol Differential Revision: D25134423 fbshipit-source-id: 918ca88891c7b0531325f046b61f28947575cff0
41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
from functools import partial
|
|
|
|
from . import functions
|
|
|
|
import torch
|
|
|
|
def _local_invoke(rref, func_name, args, kwargs):
|
|
return getattr(rref.local_value(), func_name)(*args, **kwargs)
|
|
|
|
@functions.async_execution
def _local_invoke_async_execution(rref, func_name, args, kwargs):
    """Call ``func_name`` on ``rref``'s local value under ``async_execution``.

    Does the same lookup-and-call as ``_local_invoke``, but is decorated with
    ``functions.async_execution``. ``_invoke_rpc`` selects it when the target
    method carries the ``_wrapped_async_rpc_function`` marker.
    NOTE(review): the decorated call presumably must return a Future, per
    ``async_execution`` semantics -- confirm against ``functions``.

    Args:
        rref: an RRef-like object; its ``local_value()`` supplies the target.
        func_name: name of the attribute to look up on the target and call.
        args: positional arguments forwarded to the call.
        kwargs: keyword arguments forwarded to the call.
    """
    held_value = rref.local_value()
    bound_method = getattr(held_value, func_name)
    return bound_method(*args, **kwargs)
|
|
|
|
def _invoke_rpc(rref, rpc_api, func_name, *args, **kwargs):
    """Invoke ``func_name`` on the value held by ``rref`` at the RRef's owner.

    Args:
        rref: the RRef whose held value is the call target. ``_get_type()``
            supplies the value's type and ``owner()`` the destination.
        rpc_api: RPC entry point, called as
            ``rpc_api(owner, invoker, args=(rref, func_name, args, kwargs))``.
        func_name: name of the method to call on the held value.
        *args: positional arguments for the remote method.
        **kwargs: keyword arguments for the remote method.

    Returns:
        Whatever ``rpc_api`` returns for the call.

    Raises:
        AttributeError: for non-ScriptModule types, if ``func_name`` is not an
            attribute of the held value's type (raised locally by ``getattr``).
    """
    rref_type = rref._get_type()

    # Looking up a method on a ScriptModule class fails, because TorchScript
    # methods are not plain Python functions. Skip the async-execution marker
    # check for those types and always use the plain invoker.
    is_script_module = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
        rref_type, torch._C.ScriptModule
    )

    use_async_execution = False
    if not is_script_module:
        method = getattr(rref_type, func_name)
        use_async_execution = hasattr(method, "_wrapped_async_rpc_function")

    invoker = (
        _local_invoke_async_execution if use_async_execution else _local_invoke
    )

    return rpc_api(
        rref.owner(),
        invoker,
        args=(rref, func_name, args, kwargs),
    )
|
|
|
|
|
|
class RRefProxy:
    """Turn attribute access into remote method calls on an RRef's value.

    ``proxy.some_method`` returns a callable. Calling it with
    ``(*args, **kwargs)`` runs ``some_method`` on the value held by ``rref``,
    issued through ``rpc_api`` via ``_invoke_rpc``.

    Args:
        rref: the RRef whose held value receives the calls.
        rpc_api: RPC entry point, called as ``rpc_api(owner, func, args=...)``.
            NOTE(review): presumably one of the ``torch.distributed.rpc`` call
            APIs -- confirm at the construction sites.
    """

    def __init__(self, rref, rpc_api):
        self.rref = rref
        self.rpc_api = rpc_api

    def __getattr__(self, func_name):
        """Return a callable that runs ``func_name`` remotely on ``self.rref``.

        Python only calls this when normal attribute lookup fails, so the
        proxy's own fields and class attributes are never forwarded.

        Raises:
            AttributeError: if the proxy's own ``rref``/``rpc_api`` fields are
                missing, i.e. ``__init__`` has not run on this instance.
        """
        # On an instance built via ``__new__`` without ``__init__`` (e.g. by
        # ``copy``/``pickle``, which then probe ``hasattr(obj, "__setstate__")``),
        # ``self.rref`` is itself missing and would re-enter this method without
        # end (RecursionError). Report these names as ordinary missing
        # attributes instead. Initialized instances keep both names in
        # ``__dict__``, so this branch never runs for them.
        if func_name in ("rref", "rpc_api"):
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(
                    type(self).__name__, func_name
                )
            )
        # Read the fields first so a half-initialized instance surfaces the
        # AttributeError above before anything else is evaluated.
        rref = self.rref
        rpc_api = self.rpc_api
        return partial(_invoke_rpc, rref, rpc_api, func_name)
|