mirror of https://github.com/pytorch/pytorch.git
[functorch] fixed some python key tracing issues

@@ -9,7 +9,7 @@ from torch.fx import Tracer, GraphModule
 class PythonTensor(object):
     def __init__(self, out, proxy):
         if isinstance(out, torch.Tensor):
-            self.value = torch.empty_like(out)
+            self.value = torch.clone(out)
         else:
             self.value = torch.empty(out)
         self.proxy = proxy
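
The change above swaps torch.empty_like for torch.clone, so the wrapper carries the real input values rather than an uninitialized buffer of the same shape. A minimal, self-contained sketch of that idea (ValueCarryingWrapper is a hypothetical stand-in for PythonTensor, not the functorch class):

import torch

class ValueCarryingWrapper:  # hypothetical stand-in for PythonTensor, for illustration only
    def __init__(self, out, proxy):
        if isinstance(out, torch.Tensor):
            # before the fix: torch.empty_like(out) -> right shape/dtype, garbage values
            # after the fix:  torch.clone(out)      -> shape/dtype plus the actual values
            self.value = torch.clone(out)
        else:
            # non-tensor inputs (e.g. a shape) still get a placeholder buffer
            self.value = torch.empty(out)
        self.proxy = proxy

x = torch.tensor([1.0, 2.0, 3.0])
wrapped = ValueCarryingWrapper(x, proxy=None)
assert torch.equal(wrapped.value, x)  # the concrete values survive wrapping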

@@ -40,8 +40,35 @@ class PythonKeyTracer(Tracer):
     def __init__(self):
         super().__init__()
 
+    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
+        return False
+
+    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+        """
+        Method that specifies the behavior of this ``Tracer`` when it encounters
+        a call to an ``nn.Module`` instance.
+
+        By default, the behavior is to check if the called module is a leaf module
+        via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
+        ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
+        the operations in its ``forward`` function.
+
+        This method can be overridden to--for example--create nested traced
+        GraphModules, or any other behavior you would want while tracing across
+        ``Module`` boundaries.
+
+        Args:
+
+            m (Module): The module for which a call is being emitted
+            forward (Callable): The forward() method of the ``Module`` to be invoked
+            args (Tuple): args of the module callsite
+            kwargs (Dict): kwargs of the module callsite
+
+        Return:
+
+            The return value from the Module call. In the case that a ``call_module``
+            node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
+            value was returned from the ``Module`` invocation.
+        """
+        return forward(*args, **kwargs)
+
     def module_getattr(self, attr, attr_val):
         if isinstance(attr_val, torch.nn.Parameter):
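
The new is_leaf_module override makes the tracer treat no module as a leaf, and the call_module override simply runs the module's forward, so every nn.Module gets traced through instead of being recorded as an opaque call_module node. A hedged torch.fx illustration of that difference (InlineEverythingTracer and Net are made-up names for demonstration, not functorch code):

import torch
from torch.fx import Tracer

class InlineEverythingTracer(Tracer):  # made-up name; mirrors the override above
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        return False  # never emit call_module nodes; always trace into forward()

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.lin(x))

default_graph = Tracer().trace(Net())                   # nn.Linear stays a 'call_module' node
inlined_graph = InlineEverythingTracer().trace(Net())   # nn.Linear is traced through
print([node.op for node in default_graph.nodes])
print([node.op for node in inlined_graph.nodes])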

@@ -92,7 +119,7 @@ def wrap_key(f, inps):
         assert(len(flat_args) == len(flat_inps))
         for idx, arg in enumerate(flat_args):
             if isinstance(flat_inps[idx], torch.Tensor):
-                flat_args[idx] = addPythonKey(PythonTensor(flat_inps[idx].shape, arg))
+                flat_args[idx] = addPythonKey(PythonTensor(flat_inps[idx], arg))
             else:
                 flat_args[idx] = flat_inps[idx]
         tree_args = pytree.tree_unflatten(flat_args, args_spec)
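
This last change passes the whole concrete input tensor into PythonTensor instead of only its .shape, which is what lets the torch.clone above capture real values. A rough sketch of the wrapping loop under that reading (wrap_concrete_inputs and the wrap callback are illustrative names, not the functorch API):

import torch
import torch.utils._pytree as pytree

def wrap_concrete_inputs(flat_args, flat_inps, args_spec, wrap):
    # Pair each flattened argument with its concrete input; tensors get wrapped,
    # everything else passes through, then the original pytree structure is rebuilt.
    assert len(flat_args) == len(flat_inps)
    for idx, arg in enumerate(flat_args):
        if isinstance(flat_inps[idx], torch.Tensor):
            # before: wrap(flat_inps[idx].shape, arg) -- only the shape reached the wrapper
            # after:  wrap(flat_inps[idx], arg)       -- the full concrete tensor does
            flat_args[idx] = wrap(flat_inps[idx], arg)
        else:
            flat_args[idx] = flat_inps[idx]
    return pytree.tree_unflatten(flat_args, args_spec)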