Add strictness check and made tensors into leaves if input tensors were leaves (#77474)

I think this makes sense to do? Otherwise, if you call `backward()` in your traced function, you can't get gradients out of any tensors that should have been leaves.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77474
Approved by: https://github.com/ezyang
This commit is contained in:
Horace He
2022-05-21 01:16:39 +00:00
committed by PyTorch MergeBot
parent ffa3cce100
commit 50cadfae10
2 changed files with 31 additions and 11 deletions

View File

@ -702,13 +702,21 @@ class TestFXExperimental(JitTestCase):
torch.testing.assert_close(loaded(x), mttm(x))
def test_proxy_tensor(self):
def f(x):
def f_grad(x):
val = x.cos().cos().sum()
return torch.autograd.grad(val, x)
traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
inp = torch.randn(3, requires_grad=True)
torch.testing.assert_close(traced_graph(inp), f(inp))
def f_backward(x):
val = x.cos().cos().sum()
val.backward()
return x.grad
for f in [f_grad, f_backward]:
traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
inp = torch.randn(3, requires_grad=True)
traced_graph_out = traced_graph(inp)
assert inp.grad is None
torch.testing.assert_close(traced_graph_out, f(inp))
def test_mode_tracing_factory_function(self):
def f(x):

View File

@ -15,7 +15,7 @@ from contextlib import contextmanager
from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode
__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx"]
__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict"]
aten = torch.ops.aten
CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
@ -40,6 +40,11 @@ def decompose(decomposition_table):
finally:
CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
# Checks whether we try to convert the tensor into a scalar
IS_STRICT = True
def enable_strict(val):
global IS_STRICT
IS_STRICT = val
def wrap_output(real_out, proxy_out):
def wrap_with_proxy(e, proxy):
@ -68,7 +73,8 @@ def proxy_call(func_overload, args, kwargs=None):
return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
if func_overload == aten._local_scalar_dense.default:
raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
"It's likely that this is caused by data-dependent control flow or similar.")
"It's likely that this is caused by data-dependent control flow or similar."
"Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")
def unwrap_proxy(e):
return e.proxy if isinstance(e, ProxyTensor) else e
@ -92,18 +98,24 @@ class ProxyTensor(torch.Tensor):
proxy: fx.Proxy
@staticmethod
def __new__(cls, elem, proxy):
def __new__(cls, elem, proxy, *, requires_grad=None):
# Hack to deal with super().__new__ not working for sparse tensors
if elem.is_sparse:
proxy.node.meta['tensor_meta'] = {}
r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
if elem.is_sparse or requires_grad is not None:
r = torch.Tensor._make_subclass(cls, elem, requires_grad)
else:
r = super().__new__(cls, elem) # type: ignore[call-arg]
if elem.is_sparse:
proxy.node.meta['tensor_meta'] = {}
else:
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
r.proxy = proxy # type: ignore[attr-defined]
return r
def __deepcopy__(self, memo):
return self.clone()
def __repr__(self):
with no_dispatch():
return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})" # type: ignore[arg-type]
@ -172,7 +184,7 @@ def wrap_key(f, inps):
for idx, arg in enumerate(flat_args):
if isinstance(flat_inps[idx], torch.Tensor):
with no_dispatch():
flat_args[idx] = ProxyTensor(flat_inps[idx], arg)
flat_args[idx] = ProxyTensor(flat_inps[idx], arg, requires_grad=flat_inps[idx].is_leaf)
else:
flat_args[idx] = flat_inps[idx]