mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add strictness check and made tensors into leaves if input tensors were leaves (#77474)
I think this makes sense to do? Otherwise, if you call `backward()` in your traced function, you can't get gradients out of any tensors that should have been leaves. Pull Request resolved: https://github.com/pytorch/pytorch/pull/77474 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
ffa3cce100
commit
50cadfae10
@ -702,13 +702,21 @@ class TestFXExperimental(JitTestCase):
|
||||
torch.testing.assert_close(loaded(x), mttm(x))
|
||||
|
||||
def test_proxy_tensor(self):
|
||||
def f(x):
|
||||
def f_grad(x):
|
||||
val = x.cos().cos().sum()
|
||||
return torch.autograd.grad(val, x)
|
||||
|
||||
traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
|
||||
inp = torch.randn(3, requires_grad=True)
|
||||
torch.testing.assert_close(traced_graph(inp), f(inp))
|
||||
def f_backward(x):
|
||||
val = x.cos().cos().sum()
|
||||
val.backward()
|
||||
return x.grad
|
||||
|
||||
for f in [f_grad, f_backward]:
|
||||
traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
|
||||
inp = torch.randn(3, requires_grad=True)
|
||||
traced_graph_out = traced_graph(inp)
|
||||
assert inp.grad is None
|
||||
torch.testing.assert_close(traced_graph_out, f(inp))
|
||||
|
||||
def test_mode_tracing_factory_function(self):
|
||||
def f(x):
|
||||
|
@ -15,7 +15,7 @@ from contextlib import contextmanager
|
||||
|
||||
from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode
|
||||
|
||||
__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx"]
|
||||
__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict"]
|
||||
aten = torch.ops.aten
|
||||
|
||||
CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
|
||||
@ -40,6 +40,11 @@ def decompose(decomposition_table):
|
||||
finally:
|
||||
CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
|
||||
|
||||
# Checks whether we try to convert the tensor into a scalar
|
||||
IS_STRICT = True
|
||||
def enable_strict(val):
|
||||
global IS_STRICT
|
||||
IS_STRICT = val
|
||||
|
||||
def wrap_output(real_out, proxy_out):
|
||||
def wrap_with_proxy(e, proxy):
|
||||
@ -68,7 +73,8 @@ def proxy_call(func_overload, args, kwargs=None):
|
||||
return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
|
||||
if func_overload == aten._local_scalar_dense.default:
|
||||
raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
|
||||
"It's likely that this is caused by data-dependent control flow or similar.")
|
||||
"It's likely that this is caused by data-dependent control flow or similar."
|
||||
"Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")
|
||||
|
||||
def unwrap_proxy(e):
|
||||
return e.proxy if isinstance(e, ProxyTensor) else e
|
||||
@ -92,18 +98,24 @@ class ProxyTensor(torch.Tensor):
|
||||
proxy: fx.Proxy
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem, proxy):
|
||||
def __new__(cls, elem, proxy, *, requires_grad=None):
|
||||
# Hack to deal with super().__new__ not working for sparse tensors
|
||||
if elem.is_sparse:
|
||||
proxy.node.meta['tensor_meta'] = {}
|
||||
r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
||||
if elem.is_sparse or requires_grad is not None:
|
||||
r = torch.Tensor._make_subclass(cls, elem, requires_grad)
|
||||
else:
|
||||
r = super().__new__(cls, elem) # type: ignore[call-arg]
|
||||
|
||||
if elem.is_sparse:
|
||||
proxy.node.meta['tensor_meta'] = {}
|
||||
else:
|
||||
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
|
||||
r.proxy = proxy # type: ignore[attr-defined]
|
||||
|
||||
return r
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
return self.clone()
|
||||
|
||||
def __repr__(self):
|
||||
with no_dispatch():
|
||||
return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})" # type: ignore[arg-type]
|
||||
@ -172,7 +184,7 @@ def wrap_key(f, inps):
|
||||
for idx, arg in enumerate(flat_args):
|
||||
if isinstance(flat_inps[idx], torch.Tensor):
|
||||
with no_dispatch():
|
||||
flat_args[idx] = ProxyTensor(flat_inps[idx], arg)
|
||||
flat_args[idx] = ProxyTensor(flat_inps[idx], arg, requires_grad=flat_inps[idx].is_leaf)
|
||||
else:
|
||||
flat_args[idx] = flat_inps[idx]
|
||||
|
||||
|
Reference in New Issue
Block a user