From 6ecf169a0779c43526c1a25637a0bfe067a06274 Mon Sep 17 00:00:00 2001 From: Horace He Date: Tue, 4 May 2021 02:12:12 -0700 Subject: [PATCH] [functorch] fixed some python key tracing issues --- functorch/functorch/_src/python_key.py | 35 +++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/functorch/functorch/_src/python_key.py b/functorch/functorch/_src/python_key.py index 8cd5b3f9e2fd..00f4e444607d 100644 --- a/functorch/functorch/_src/python_key.py +++ b/functorch/functorch/_src/python_key.py @@ -9,7 +9,7 @@ from torch.fx import Tracer, GraphModule class PythonTensor(object): def __init__(self, out, proxy): if isinstance(out, torch.Tensor): - self.value = torch.empty_like(out) + self.value = torch.clone(out) else: self.value = torch.empty(out) self.proxy = proxy @@ -40,8 +40,35 @@ class PythonKeyTracer(Tracer): def __init__(self): super().__init__() - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: - return False + + def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any: + """ + Method that specifies the behavior of this ``Tracer`` when it encounters + a call to an ``nn.Module`` instance. + + By default, the behavior is to check if the called module is a leaf module + via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to + ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through + the operations in its ``forward`` function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a ``call_module`` + node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever + value was returned from the ``Module`` invocation. + """ + return forward(*args, **kwargs) def module_getattr(self, attr, attr_val): if isinstance(attr_val, torch.nn.Parameter): @@ -92,7 +119,7 @@ def wrap_key(f, inps): assert(len(flat_args) == len(flat_inps)) for idx, arg in enumerate(flat_args): if isinstance(flat_inps[idx], torch.Tensor): - flat_args[idx] = addPythonKey(PythonTensor(flat_inps[idx].shape, arg)) + flat_args[idx] = addPythonKey(PythonTensor(flat_inps[idx], arg)) else: flat_args[idx] = flat_inps[idx] tree_args = pytree.tree_unflatten(flat_args, args_spec)