[functorch] fixed some python key tracing issues

This commit is contained in:
Horace He
2021-05-04 02:12:12 -07:00
committed by Jon Janzen
parent 7224611cd9
commit 6ecf169a07

View File

@ -9,7 +9,7 @@ from torch.fx import Tracer, GraphModule
class PythonTensor(object):
def __init__(self, out, proxy):
if isinstance(out, torch.Tensor):
self.value = torch.empty_like(out)
self.value = torch.clone(out)
else:
self.value = torch.empty(out)
self.proxy = proxy
@ -40,8 +40,35 @@ class PythonKeyTracer(Tracer):
def __init__(self):
super().__init__()
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
return False
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
"""
Method that specifies the behavior of this ``Tracer`` when it encounters
a call to an ``nn.Module`` instance.
By default, the behavior is to check if the called module is a leaf module
via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
the operations in its ``forward`` function.
This method can be overridden to--for example--create nested traced
GraphModules, or any other behavior you would want while tracing across
``Module`` boundaries.
Args:
m (Module): The module for which a call is being emitted
forward (Callable): The forward() method of the ``Module`` to be invoked
args (Tuple): args of the module callsite
kwargs (Dict): kwargs of the module callsite
Return:
The return value from the Module call. In the case that a ``call_module``
node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
value was returned from the ``Module`` invocation.
"""
return forward(*args, **kwargs)
def module_getattr(self, attr, attr_val):
if isinstance(attr_val, torch.nn.Parameter):
@ -92,7 +119,7 @@ def wrap_key(f, inps):
assert(len(flat_args) == len(flat_inps))
for idx, arg in enumerate(flat_args):
if isinstance(flat_inps[idx], torch.Tensor):
flat_args[idx] = addPythonKey(PythonTensor(flat_inps[idx].shape, arg))
flat_args[idx] = addPythonKey(PythonTensor(flat_inps[idx], arg))
else:
flat_args[idx] = flat_inps[idx]
tree_args = pytree.tree_unflatten(flat_args, args_spec)