# pytorch/test/test_proxy_tensor.py
# Owner(s): ["module: ProxyTensor"]
# ruff: noqa: F841
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch._dynamo
import unittest
import warnings
import operator
from collections.abc import Iterable
from torch.nn.utils import stateless
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch._decomp import decomposition_table
from torch.fx.experimental.symbolic_shapes import (
    eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
    guard_int, GuardOnDataDependentSymNode
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.hop_db import hop_db
from torch.testing._internal.common_device_type import ops
import torch.testing._internal.optests as optests
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
from torch.utils._pytree import tree_map
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch import nn
import torch._functorch.config
import re
import functools
import itertools
aten = torch.ops.aten
HAS_CUDA = torch.cuda.is_available()
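

# Helpers for rendering the guards accumulated by a traced GraphModule's
# ShapeEnv in a simplified form keyed by placeholder source names; used by the
# assertExpectedInline guard checks below.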
def strip_end(s, suffix):
    if suffix and s.endswith(suffix):
        return s[:-len(suffix)]
    else:
        return s


def show_guards(gm):
    names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)]
    return "\n".join(
        gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, input_contexts=None)
    )
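

# For example, a graph traced from a function that asserts x.shape[0] < 20
# renders as the single guard line "L['x'].size()[0] <= 19" (see test_unary
# below).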
def process_failures():
    """
    Takes file containing failures like

    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition  # noqa: B950

    and processes them into a list of opinfo xfails
    """
    f = open('pytest_failures')
    failures = f.readlines()
    failures = [i.strip() for i in failures]

    def process_failure_string(s, matcher):
        out = re.search(matcher, s)
        return out.groups()

    SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)'
    failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures]

    def create_normalized_name(op):
        if op.variant_test_name == '':
            s = op.name
        else:
            s = f"{op.name}.{op.variant_test_name}"
        return s.replace('.', '_')

    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}

    print("symbolic_tensor_failures = {")
    for failure, reason in failures:
        print(f"    xfail{remap_opinfo[failure]},  # {reason}")
    print("}")


USE_TORCHVISION = False
try:
    import torchvision
    USE_TORCHVISION = True
except ImportError:
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
                  "to install it with commands from pytorch.org, post-fixed with "
                  "`--no-deps` to avoid overwriting the pytorch installation",
                  UserWarning)
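

# Produce a fresh input "like" x: float tensors are replaced with new random
# data (so we can check that a traced graph computes correctly on values other
# than the ones it was traced with), while other tensors are just bumped by 1.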
def _create_new_input(x):
    if not isinstance(x, torch.Tensor):
        return x
    if x.dtype != torch.float:
        return x + 1
    if x.is_leaf:
        return torch.rand_like(x, requires_grad=x.requires_grad)
    else:
        return torch.rand_like(x)
"""
Delays a cos being executed on the unwraptensor until its used. Simulates a CommTensor used
"""
class UnwrapTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            dtype=tensor.dtype,
            device=tensor.device,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
        )
        r._tensor = tensor
        return r

    def __repr__(self):
        # TODO: consider all_gather the local tensors for better debugging
        return f"UnwrapTensor({self._tensor})"

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            ret = e
            if isinstance(e, UnwrapTensor):
                ret = e._tensor.cos()
            return ret

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
        return func(*args, **kwargs)
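

# An illustrative sketch of the behavior (exercised by test_trace_subclasses
# below): wrapping a tensor defers the cos() until the wrapper participates in
# an op, so UnwrapTensor(torch.zeros(3)) + 1 computes torch.zeros(3).cos() + 1,
# i.e. a tensor of 2.0s.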


class TestGenericProxyTensor(TestCase):
    # WARNING: if any of your inputs are index tensors, DO NOT use this
    # function (_create_new_input randomizes float inputs and bumps integer
    # inputs by 1, which would corrupt index tensors)
    def _test(self, f, inps):
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
        new_inps = tree_map(_create_new_input, inps)
        r1 = fx_f(*new_inps)
        r2 = f(*new_inps)
        self.assertEqual(r1, r2)

    def test_pre_dispatch_mode_stack(self):
        def f(a):
            b = torch.ones(4, 4)
            return torch.matmul(a, b)
        # We expect to see matmul in the trace - it should NOT be decomposed into mm.
        # Also, torch.ones() doesn't show up in the trace.
        # This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
        # so our mode never sees it - it goes directly to the BackendSelect key.
        inp = torch.ones(4, 4)
        # Test that make_fx(pre_dispatch=True) clears caches properly.
        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            out1 = f(inp)
            fx_g = make_fx(f, pre_dispatch=True)(inp)
            self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, a_1):
    ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
    matmul = torch.ops.aten.matmul.default(a_1, ones);  a_1 = ones = None
    return matmul""")

    def test_pre_dispatch_linear(self):
        def f(a, b, c):
            return torch.nn.functional.linear(a, b, c)
        a = torch.ones(4, 4)
        b = torch.ones(4, 4)
        c = torch.ones(4)
        fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
        out1 = f(a, b, c)
        out2 = fx_g(a, b, c)
        self.assertEqual(out1, out2)

    def test_pre_dispatch_no_grad(self):
        def f(a):
            b = a.sin()
            torch.set_grad_enabled(False)
            c = b.cos()
            torch.set_grad_enabled(True)
            return b + c.sin()
        a1 = torch.randn(4, requires_grad=True)
        a2 = a1.detach().clone().requires_grad_(True)
        a_tmp = a1.detach().clone().requires_grad_(True)
        fx_g = make_fx(f, pre_dispatch=True)(a_tmp)
        out1 = f(a1)
        out2 = fx_g(a2)
        self.assertEqual(out1, out2)
        out1.sum().backward()
        out2.sum().backward()
        self.assertEqual(a1.grad, a2.grad)

    def test_make_fx_simple(self):
        def f(x):
            return torch.sin(x)
        self._test(f, (torch.randn(3),))

    def test_scalar_device(self, device='cpu'):
        def f(a, b):
            return a + b
        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])

    def test_isolated_graphmodule(self):
        def is_any_sum(gm):
            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)

        def is_any_digamma(gm):
            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)

        def is_any_sigmoid(gm):
            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)

        def inner(x):
            return torch.sum(x)

        def f(x):
            gm = get_isolated_graphmodule(inner, (x,), {})
            self.assertTrue(is_any_sum(gm))
            return x + torch.randn(x.shape)

        # get_isolated_graphmodule uses make_fx internally, which shouldn't be
        # traced by the outer make_fx call
        traced = make_fx(f)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))

        # When factory functions are used, they should not be traced
        # by the outer make_fx call
        def inner_with_factory():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((10, 10), val).sum()

        def f1(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2(x):
            gm = get_isolated_graphmodule(f1, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify nested make_fx calls don't cause factory functions to be
        # leaked into the outer graph. Verify that `make_fx` itself does not
        # leak its execution.
        def f2(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify that the `forward` function of a graph module produced as a
        # side effect of an interior `make_fx` is still traced
        def f3(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            # `gm.forward` is still traced
            return torch.digamma(gm(x))

        traced = make_fx(f3)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertTrue(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with non-ProxyTensor modes
        from torch.testing._internal.logging_tensor import LoggingTensorMode

        def f1_logging(x):
            with LoggingTensorMode():
                gm = get_isolated_graphmodule(inner_with_factory, (), {})
                self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging(x):
            with LoggingTensorMode(), LoggingTensorMode():
                gm = get_isolated_graphmodule(f1_logging, (x,), {})
                self.assertFalse(is_any_sum(gm))
                self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with another tensor subclass
        # This case currently doesn't work and should raise an error
        # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
        from torch.testing._internal.logging_tensor import LoggingTensor

        def f1_logging_tensor(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging_tensor(x):
            x = LoggingTensor(x)
            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging_tensor)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
        self.assertTrue(is_any_digamma(traced))

    # See https://github.com/pytorch/pytorch/issues/97541
    def test_empty_like_doesnt_burn_in_defaults(self):
        def f(x):
            return torch.empty_like(x)
        out = make_fx(f)(torch.randn(3))
        self.assertExpectedInline(out.code.strip(), """\
def forward(self, x_1):
    empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False);  x_1 = None
    return empty_like""")

    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
        def f(x):
            y = x.new_zeros(x.size())
            y.copy_(x)
            return y

        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
            return torch.zeros(size, dtype=inp.dtype, device=inp.device)

        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}

        # When new_zeros() decomposes into torch.zeros(), we expect ProxyTensorMode
        # to still be (re-entrantly) enabled, so that the `torch.zeros()` call
        # returns a ProxyTensor.
        out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
        self.assertExpectedInline(out.code, """\



def forward(self, x_1):
    zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    copy_ = torch.ops.aten.copy_.default(zeros, x_1);  zeros = x_1 = None
    return copy_
    """)

    def test_make_fx_reentrant_dispatch(self):
        def f(x):
            return torch.ops.aten.norm.Scalar(x, 2.0)

        def norm_decomp(x, p=2.0):
            if p != 2.0:
                raise RuntimeError("can't handle with p != 2")
            return torch.sqrt(torch.sum(torch.square(x)))

        decomp = {torch.ops.aten.norm.Scalar: norm_decomp}
        traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3))
        for n in traced.graph.nodes:
            self.assertTrue("square" not in str(n.target))
            self.assertTrue("norm" not in str(n.target))

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self):
        mod = torchvision.models.resnet18()

        # An old version of this test called the module directly. This works
        # for tracing_mode == "real", but for fake tensors, we also have to
        # ensure that the parameters and buffers get wrapped in fake tensors
        # because free fake tensors are not supported. Fortunately functional_call
        # does precisely this for us.
        def f(x, params, buffers):
            for p in params.values():
                p.grad = None
            loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
            # I could have done this with the functional API, but there is
            # plenty of exercising this; I want to show mutating API still
            # works
            loss.backward()
            return [p.grad for p in params.values()]

        inp = torch.randn(3, 3, 250, 250)
        self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])

    def test_varargs(self):
        def f(*args):
            return sum(args)
        self._test(f, [torch.randn(2), torch.randn(2)])

    def test_proxy_tensor(self):
        def f_grad(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        def f_backward(x):
            val = x.cos().cos().sum()
            val.backward()
            return x.grad

        for f in [f_grad, f_backward]:
            self._test(f, [torch.randn(3, requires_grad=True)])

    def test_pickle_issue89626(self):
        import pickle
        x = torch.randn(2)
        make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x)
        pickle.dumps(x)

    def test_inplace_metadata(self):
        def f(x):
            x = x.clone()
            x.unsqueeze_(-1)
            assert x.shape[-1] == 1
            return x
        self._test(f, [torch.randn(5)])

    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # default behavior should trace factory functions
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
        self.assertTrue(
            any(
                node.target == aten.randn.default
                for node in traced.graph.nodes
            )
        )

    def test_pre_dispatch_functionalization(self):
        def f(x):
            a = FunctionalTensorMode(pre_dispatch=True, export=True)
            with a:
                x_unwrapped = FunctionalTensor.to_functional(x)
                y = torch.matmul(x_unwrapped, x_unwrapped)
                y = y + x_unwrapped
                y.mul_(5)
                y_unwrapped = torch._from_functional_tensor(y.elem)
                return y_unwrapped

        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            inp = torch.randn(4, 4)
            gm = make_fx(f, pre_dispatch=True)(inp)

        # TODO actually not decompose
        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
    add = torch.ops.aten.add.Tensor(matmul, x_1);  matmul = x_1 = None
    mul = torch.ops.aten.mul.Tensor(add, 5);  add = None
    return mul""")

    def test_pre_dispatch_functionalization_view_op(self):
        def f(x):
            a = FunctionalTensorMode(pre_dispatch=True, export=True)
            with a:
                x_unwrapped = FunctionalTensor.to_functional(x)
                y = torch.matmul(x_unwrapped, x_unwrapped)
                x_unwrapped = x_unwrapped.transpose(1, 0)
                y = y + x_unwrapped
                y = y.view(2, 8)
                y_unwrapped = torch._from_functional_tensor(y.elem)
                return y_unwrapped

        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            inp = torch.randn(4, 4)
            gm = make_fx(f, pre_dispatch=True)(inp)

        # TODO actually not decompose
        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
    transpose = torch.ops.aten.transpose.int(x_1, 1, 0);  x_1 = None
    add = torch.ops.aten.add.Tensor(matmul, transpose);  matmul = transpose = None
    view = torch.ops.aten.view.default(add, [2, 8]);  add = None
    return view""")

    def test_val_metadata_mutation(self):
        def f(x):
            y = x.clone()
            y.unsqueeze_(0)
            return y

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True))
        self.assertEqual([
            tuple(node.meta['val'].shape)
            for node in traced.graph.nodes
            if 'val' in node.meta
        ], [(3,), (3,), (1, 3)])

    def test_make_fx_overloads(self):
        def f(x):
            return x.cos() + torch.randn(x.shape)

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
        self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload)
                            for node in traced.graph.nodes if node.op == 'call_function'))

    def test_tensor_constants(self):
        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)
        self._test(f, [])

    def test_allclose(self):
        def f(a, b):
            return torch.allclose(a, b)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)(
                torch.zeros(3), torch.zeros(3)
            )

        if self.tracing_mode != "real":
            self.assertRaises(DataDependentOutputException, test_f)
        else:
            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_constant_proxy_tensor_mut(self):
        def f():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((100, 100), val)

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())
        # In case we mutated shared state in the g graph!
        self.assertEqual(g(), f())

    def test_constant_unbind(self):
        def f():
            val = torch.tensor([2])
            r, = torch.unbind(val, 0)
            return r.item()

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())

    def test_constant_blowup(self):
        def f():
            val = torch.tensor([2])
            blowup = val.repeat(1000)
            return bool(blowup.sum().item() == 2)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_constant_random(self):
        def f():
            val = torch.tensor([2.0])
            val.normal_()
            return bool(val.item() == 2.1)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_decomposition_interpreter(self):
        def fn(x):
            return torch.nn.functional.silu(x)

        x = torch.rand((4, 4))
        fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)

        found_silu = False
        for n in fx_module.graph.nodes:
            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
                found_silu = True

        self.assertTrue(found_silu)

        new_graph = torch.fx.Graph()
        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
        DecompositionInterpreter(
            fx_module,
            new_graph=new_graph,
            decomposition_table=silu_decomp_table,
        ).run(x)

        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)

        for n in decomposed_module.graph.nodes:
            self.assertTrue(n.target != torch.ops.aten.silu)
            self.assertTrue(n.target != torch.ops.aten.silu.default)

        self.assertEqual(fx_module(x), decomposed_module(x))

    def test_make_fx_model_fwd_bwd(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(x, params):
            out = torch.func.functional_call(model, params, x).sum()
            out.backward()
            return list(params.values())

        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
        # fx may change the order of parameters in the returned list, so
        # compare each traced output against both eager outputs
        self.assertTrue(
            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
            or
            torch.allclose(fx_f(input, params)[0], f(input, params)[1])
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
            or
            torch.allclose(fx_f(input, params)[1], f(input, params)[1])
        )

    def test_make_fx_model_double_param(self):
        class Emformer(torch.nn.Module):
            def __init__(
                self,
                input_dim: int = 256,
            ) -> None:
                super().__init__()
                self.layer_norm = torch.nn.LayerNorm(input_dim)

            def forward(mod_self, x):  # noqa: B902
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                y = mod_self.layer_norm(x)
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                z = mod_self.layer_norm(y)
                return z

        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
        self.assertEqual(len(ops), 2)

    def test_make_fx_model_fwd_bwd_wgtupdate(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(args, params, buffers):
            for p in params.values():
                p.grad = None
            if not isinstance(args, Iterable):
                args = [args]
            params_and_buffers = {**params, **buffers}
            out = torch.func.functional_call(model, params_and_buffers, args)
            out.sum().backward()
            return [p - 1e-4 * p.grad for p in params.values()]

        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        buffers = dict(model.named_buffers())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
        # fx may change the order of parameters in the returned list, so
        # compare each traced output against both eager outputs; there is also
        # a small numerical difference in results, so atol is loosened from
        # the default 1e-08 to 1e-03
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
            or
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
            or
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
        )

    def test_trace_subclasses(self):
        def f1(x):
            x = UnwrapTensor(x)
            y = x * 2
            return y

        def f2(x):
            wrapped = UnwrapTensor(x)
            y = x * wrapped
            return y

        inp = [torch.randn(5)]
        self._test(f1, inp)
        self._test(f2, inp)

    def test_partial_decomp(self):
        def f(a, b, c):
            x = torch.addmm(a, b, c)
            y = torch.addmm(a, b, c, beta=2, alpha=1)
            return x + y

        inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)]
        fx_g = make_fx(f)(*inps)

        def addmm(a, b, c, beta=1, alpha=1):
            if beta == 1 and alpha == 1:
                return NotImplemented
            return beta * a + alpha * (b @ c)

        decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps)

        self.assertEqual(fx_g(*inps), decomposed_fx(*inps))
        self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
        self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)

    def test_decomp_of_capture(self):
        val = torch.randn(5)

        def f(x):
            return x.t() + val.t()

        def nop(x):
            return x.cos()

        traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
        self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_amp_cache(self):
        layer = torch.nn.Conv2d(3, 3, 3).cuda()

        def f(x, w):
            return torch.nn.functional.conv2d(x, w, stride=layer.stride)

        inp = torch.randn(4, 3, 10, 10, device='cuda')
        with torch.autocast('cuda'):
            out_graph = make_fx(f)(inp, layer.weight).graph
            out_graph2 = make_fx(f)(inp, layer.weight).graph

        self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes))
        for a, b in zip(out_graph.nodes, out_graph2.nodes):
            self.assertEqual(a.op, b.op)

    def test_strides(self):
        def f(x):
            self.assertTrue(x.is_contiguous())
            self.assertFalse(x.is_contiguous(memory_format=torch.channels_last))
            x = x.permute(0, 3, 1, 2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
            return x

        make_fx(f)(torch.randn(2, 3, 4, 5))

        def f(x):
            self.assertTrue(x.is_contiguous())
            y = x[:, 1]
            self.assertFalse(y.is_contiguous())
            y = x[:, ::2]
            self.assertFalse(y.is_contiguous())
            return x.cos()

        make_fx(f)(torch.randn(2, 3, 4, 5))

    def test_pr_86917(self):
        # Tests the issue brought up here https://github.com/pytorch/pytorch/pull/86917#issuecomment-1283155344
        def f(a, b):
            return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10)
        self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])


class TestGenericProxyTensorReal(TestGenericProxyTensor):
    tracing_mode = "real"


class TestGenericProxyTensorFake(TestGenericProxyTensor):
    tracing_mode = "fake"


class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
    tracing_mode = "symbolic"
del TestGenericProxyTensor


class TestRealProxyTensor(TestCase):
    def test_error_on_data_dependent_ops(self):
        def f():
            x = torch.randn([])
            y = torch.randn([])
            assert torch.allclose(x * y, y * x)
            z = float(x)
            z2 = float(y)

        # Smoke tests
        make_fx(f, _error_on_data_dependent_ops=False)()
        make_fx(f, pre_dispatch=True, _error_on_data_dependent_ops=False)()


class TestFakeProxyTensor(TestCase):
    def test_issue82547(self):
        x = nn.Parameter(torch.randn(3, 3))

        def f():
            return torch.ops.aten.t.default(x)
        self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")())

        class A(torch.Tensor):
            pass

        x = A(torch.randn(3, 3))
        self.assertRaisesRegex(TypeError, "Multiple dispatch failed", lambda: make_fx(f, tracing_mode="fake")())

    def test_use_fake_and_tensor(self):
        def f(x, y):
            z = torch.tensor([2.0, 3.0])
            return x + y + z

        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
        x, y = torch.randn(2), torch.randn(2)
        self.assertEqual(g(x, y), f(x, y))

    def test_free_fake(self):
        def f(x):
            return torch.add(x, y)

        with FakeTensorMode() as fake_mode:
            y = torch.randn(2)
            make_fx(f, tracing_mode="real")(torch.randn(2))

    def test_fused_adam(self):
        # See https://github.com/pytorch/pytorch/issues/99356
        params = [torch.randn(10, 10) for _ in range(10)]
        grads = [torch.randn(10, 10) for _ in range(10)]
        exp_avgs = [torch.randn(10, 10) for _ in range(10)]
        exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        state_steps = [torch.tensor(0) for _ in range(10)]

        def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
            (new_params, _, _, _, _) = aten._fused_adam.default(
                params,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                lr=0.1,
                beta1=0.9,
                beta2=0.999,
                weight_decay=0.01,
                eps=1e-8,
                amsgrad=False,
                maximize=False,
            )

            for p, new_p in zip(params, new_params):
                p.copy_(new_p)

            return params

        gm = make_fx(fused_adam, tracing_mode='fake')(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
        )
        ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
        for n in gm.graph.nodes:
            if n.op == "call_function" and n.target in ensure_ops_have_val:
                self.assertIn('val', n.meta)

    def test_alias(self):
        def f(x):
            return torch.ops.aten.alias(x)

        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
        # NB: this should not have a detach call
        self.assertExpectedInline(r, """\
def forward(self, x_1):
    alias = torch.ops.aten.alias.default(x_1);  x_1 = None
    return alias""")

    def test_meta(self):
        def f(x):
            a = x.cos()
            b = torch.var_mean(a, dim=0)
            c = b * 2
            return c

        out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5))
        for n in out.graph.nodes:
            if n.op == 'output':
                continue
            self.assertTrue('val' in n.meta)

    def test_fake_tensor_mode(self):
        def f(a):
            d = a.cos()
            return d

        from torch._guards import detect_fake_mode

        existing_fake_mode = FakeTensorMode()
        with existing_fake_mode:
            out = make_fx(f, tracing_mode="real")(torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]))

        fake_mode = detect_fake_mode([node.meta.get('val', None) for node in out.graph.nodes])
        self.assertEqual(fake_mode, existing_fake_mode)
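

# Module-level helpers shared by the symbolic tracing tests below.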
def _get_node(fx_g, cond):
    for n in fx_g.graph.nodes:
        if cond(n):
            return n
    raise AssertionError


def _get_free_symbols(shape_env):
    vars = tuple(shape_env.var_to_val.keys())
    return len([var for var in vars if var not in shape_env.replacements])
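

# Trace `f` in "symbolic" mode on freshly created random inputs with the
# given shapes.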
def _trace(f, *args):
    inps = [torch.randn(arg) for arg in args]
    return make_fx(f, tracing_mode="symbolic")(*inps)


# TODO: Need to test the guards themselves specifically as well
class TestSymbolicTracing(TestCase):
    def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
        """
        Traces fn with trace_inputs and runs the traced graph on each entry of
        test_inputs, optionally checking the results against eager execution.
        Returns the traced GraphModule (whose shape_env callers can inspect).
        """
        trace_inputs = [torch.randn(shape) for shape in trace_inputs]
        traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs)
        for input in test_inputs:
            input = [torch.randn(shape) for shape in input]
            rx, ry = traced_f(*input), fn(*input)
            if assert_eq:
                self.assertEqual(rx, ry)
        return traced_f

    def test_debug_interpreter(self):
        import torch.library
        from torch.library import Library

        foo = Library("foo", "DEF")  # noqa: TOR901
        foo.define("foo(Tensor self) -> Tensor")

        # Operator where meta and cpu disagree on strides
        @torch.library.impl(foo, "foo", "CPU")
        def foo_cpu(x):
            return x.clone().T

        @torch.library.impl(foo, "foo", "Meta")
        def foo_meta(x):
            return x.clone()

        def f(x):
            return torch.ops.foo.foo.default(x)

        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2))
        from torch._functorch.compilers import DebugInterpreter

        interp = DebugInterpreter(gm)

        # input mismatch is caught (indicates guard problem)
        self.assertRaisesRegex(
            AssertionError, r"3 != 1",
            lambda: interp.run(torch.randn(3, 3).T),
        )

        # Catch the incorrect meta
        self.assertRaisesRegex(
            AssertionError, r"\(3, 1\) != \(1, 3\)",
            lambda: interp.run(torch.randn(3, 3))
        )

    def test_int_input(self):
        def f(x, y):
            return x.view(y)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 4), 12).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    view = torch.ops.aten.view.default(x_1, [y_1]);  x_1 = y_1 = None
    return view""")

    def test_resize_from_zero(self):
        def f(x, y):
            x.resize_(y.size(0))

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]);  x_1 = sym_size_int = resize_ = None
    return None""")

    def test_broadcast_shapes(self):
        def f(x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0])

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 1), torch.empty(5)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0);  x_1 = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
    return (sym_size_int, sym_size_int_1)""")

    def test_deduped_shape(self):
        def f(s0, s1, x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])

        x = torch.empty(3, 1)
        y = torch.empty(5)
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        shape_env = ShapeEnv()

        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
            x = fake_mode.from_tensor(x)
            y = fake_mode.from_tensor(y)
            r = str(make_fx(f, tracing_mode="real")(x.shape[0], y.shape[0], x, y).code).strip()
            self.assertExpectedInline(r, """\
def forward(self, s0_1, s1_1, x_1, y_1):
    empty = torch.ops.aten.empty.memory_format([s0_1], device = device(type='cpu'), pin_memory = False)
    return ((s0_1, s1_1), empty)""")

    def test_non_deduped_shape(self):
        def f(x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])

        x = torch.empty(3, 1)
        y = torch.empty(5)
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        shape_env = ShapeEnv()

        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
            x = fake_mode.from_tensor(x)
            y = fake_mode.from_tensor(y)
            r = str(make_fx(f, tracing_mode="real")(x, y).code).strip()
            self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0);  x_1 = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
    empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
    return ((sym_size_int, sym_size_int_1), empty)""")

    def test_unary(self):
        def f(x):
            assert x.shape[0] < 20
            return x.cos()

        test_inputs = []
        test_inputs.append([(2, 5)])
        test_inputs.append([(6, 8)])

        gm = self._test_dynamic(f, [(3, 4)], test_inputs)
        self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
        self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s75: 4, s96: 5}")
        self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
        self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""")

    def test_repeat_interleave(self):
        def f(src_tokens, beam_size_src):
            return src_tokens.repeat_interleave(beam_size_src.size(0), 0)

        prompt_size = 64
        vocab_size = 64
        batch_size = 4
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens, torch.randn(5))
        self.assertEqual(len(gm.shape_env.guards), 0)

    def test_non_symint_size_spec(self):
        # this isn't really a proxy tensor test, but it's the most convenient
        # way to get a fake tensor with symbolic sizes
        def f(x):
            torch._C._non_sym_sizes(x)
            return x + 1

        x = torch.randn(2, 3)
        make_fx(f, tracing_mode="symbolic")(x)

    # https://github.com/pytorch/pytorch/issues/108195
    def test_symbolic_repeat_interleave(self):
        def f(y, x):
            return y.repeat_interleave(x, dim=1)

        y = torch.tensor([[1, 2], [3, 4]])
        x = torch.tensor([2, 3])
        r = str(make_fx(f, tracing_mode="symbolic")(y, x).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, y_1, x_1):
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1);  x_1 = None
    index_select = torch.ops.aten.index_select.default(y_1, 1, repeat_interleave);  y_1 = repeat_interleave = None
    return index_select""")

    def test_mod_gcd_unbacked(self):
        def f(_a, _b, _stride):
            a = _a.item()
            b = _b.item()
            stride = _stride.item()
            torch._check_is_size(a)
            torch._check_is_size(b)
            torch._check_is_size(stride)
            ta = torch.randn(a * stride)
            tb = torch.randn(b * stride)
            r = torch.cat([ta, tb])
            return r.view(a + b, stride)

        _a = torch.tensor(30)
        _b = torch.tensor(20)
        _stride = torch.tensor(10)
        r = str(make_fx(f, tracing_mode="symbolic")(_a, _b, _stride).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, _a_1, _b_1, _stride_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(_a_1);  _a_1 = None
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(_b_1);  _b_1 = None
    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(_stride_1);  _stride_1 = None
    mul = _local_scalar_dense * _local_scalar_dense_2
    randn = torch.ops.aten.randn.default([mul], device = device(type='cpu'), pin_memory = False);  mul = None
    mul_1 = _local_scalar_dense_1 * _local_scalar_dense_2
    randn_1 = torch.ops.aten.randn.default([mul_1], device = device(type='cpu'), pin_memory = False);  mul_1 = None
    cat = torch.ops.aten.cat.default([randn, randn_1]);  randn = randn_1 = None
    add = _local_scalar_dense + _local_scalar_dense_1;  _local_scalar_dense = _local_scalar_dense_1 = None
    view = torch.ops.aten.view.default(cat, [add, _local_scalar_dense_2]);  cat = add = _local_scalar_dense_2 = None
    return view""")

    def test_cumsum_unbacked(self):
        def f(x):
            y = x.item()
            z = torch.randn((3, y, 3))
            return z.cumsum(0)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([5])).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    randn = torch.ops.aten.randn.default([3, _local_scalar_dense, 3], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    cumsum = torch.ops.aten.cumsum.default(randn, 0);  randn = None
    return cumsum"""  # noqa: B950
        )

    def test_repeat_interleave_unbacked_output_size(self):
        def f(x, y):
            s = x.sum().item()
            return y.repeat_interleave(x, dim=0, output_size=s)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([2, 3]), torch.randn(2)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1, y_1):
    sum_1 = torch.ops.aten.sum.default(x_1)
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(sum_1);  sum_1 = None
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1, output_size = _local_scalar_dense);  x_1 = _local_scalar_dense = None
    index_select = torch.ops.aten.index_select.default(y_1, 0, repeat_interleave);  y_1 = repeat_interleave = None
    return index_select"""  # noqa: B950
        )

    def test_arange_unbacked_output_size(self):
        def f(x):
            return torch.arange(0, x)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    arange = torch.ops.aten.arange.start(0, _local_scalar_dense, device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    return arange"""  # noqa: B950
        )

    def test_adv_index_batch(self):
        def f(src_tokens):
            bsz, src_len = src_tokens.size()[:2]
            start_step = src_tokens.shape[1]
            beam_size = 1
            generate_size = 64
            max_len = src_len + generate_size
            tokens = torch.zeros(bsz * beam_size, max_len).to(src_tokens).long().fill_(0)
            tokens[:, :start_step] = src_tokens.repeat_interleave(beam_size, 0)
            return tokens

        prompt_size = 64
        vocab_size = 64
        batch_size = 4
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
        # Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
        # 1 ok)
        self.assertEqual(len(gm.shape_env.guards), 0)

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_cpu_scalar_cuda(self):
        # Extracted from wave2vec2
        def f(a, b):
            return (a * b) @ b

        r = str(
            make_fx(f, tracing_mode="symbolic")(
                torch.tensor(1.0), torch.randn(2, 2, device='cuda')
            ).code
        ).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1, b_1):
    mul = torch.ops.aten.mul.Tensor(a_1, b_1);  a_1 = None
    mm = torch.ops.aten.mm.default(mul, b_1);  mul = b_1 = None
    return mm""")

    def test_binary_broadcast(self):
        def f(a, b):
            c = a * b
            return c

        test_inputs = []
        test_inputs.append([(1, 5), (3, 1)])
        test_inputs.append([(1, 4), (4, 1)])
        shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env
        assert len(shape_env.guards) == 0

    def test_multiply_shape(self):
        def f(a):
            return torch.empty(a.shape[0] * 2)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0);  a_1 = None
    mul = sym_size_int * 2;  sym_size_int = None
    empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False);  mul = None
    return empty""")

    def test_item(self):
        def f(a):
            r = a.item()
            return r * a

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1)
    mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense);  a_1 = _local_scalar_dense = None
    return mul""")

    def test_tensor_symfloat(self):
        def f(a):
            r = torch.tensor(a.size(0) ** 2.0)
            assert r.dtype is torch.float
            return r

        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2))
        r = str(gm.code).strip()
        # NB: this specializes, which is fine, the point is to make sure the
        # dtype inference is correct
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
    return lift_fresh_copy""")
        self.assertEqual(gm._tensor_constant0, torch.tensor(4.0))

    def test_item_to_constructor(self):
        def f(a):
            r = a.item()
            return torch.empty(r)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, a_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1);  a_1 = None
    empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    return empty"""  # noqa: B950
        )

    def test_setitem_symint(self):
        # from moco
        # https://github.com/pytorch/pytorch/issues/101939
        def f(x):
            x[0] = x.size(0)
            return x

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(10)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  sym_size_int = None
    select = torch.ops.aten.select.int(x_1, 0, 0)
    copy_ = torch.ops.aten.copy_.default(select, scalar_tensor);  select = scalar_tensor = copy_ = None
    return x_1"""  # noqa: B950
        )

    def test_dynamic_pointwise_scalar(self):
        def f(gravity, mask):
            gravity[mask, 0] = gravity[mask, 0] * -1

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 4)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, gravity_1, mask_1):
    select = torch.ops.aten.select.int(gravity_1, 1, 0)
    index = torch.ops.aten.index.Tensor(select, [mask_1]);  select = None
    mul = torch.ops.aten.mul.Tensor(index, -1);  index = None
    select_1 = torch.ops.aten.select.int(gravity_1, 1, 0);  gravity_1 = None
    index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul);  select_1 = mask_1 = mul = index_put_ = None
    return None""")

    def test_reflect_r_over_x(self):
        def reflect_R_over_x(R):
            reflect = torch.eye(3, device=R.device)
            reflect[0, 0] = -1
            return reflect @ R @ reflect

        def f(crop_camera, mask):
            crop_camera[mask] = reflect_R_over_x(crop_camera[mask])

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, crop_camera_1, mask_1):
    index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
    eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
    select = torch.ops.aten.select.int(eye, 0, 0)
    select_1 = torch.ops.aten.select.int(select, 0, 0);  select = None
    copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy);  select_1 = lift_fresh_copy = copy_ = None
    sym_size_int = torch.ops.aten.sym_size.int(index, 0)
    expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3])
    view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]);  expand = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(crop_camera_1, 1)
    sym_size_int_2 = torch.ops.aten.sym_size.int(crop_camera_1, 2)
    expand_1 = torch.ops.aten.expand.default(index, [sym_size_int, sym_size_int_1, sym_size_int_2]);  index = None
    view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  expand_1 = sym_size_int_1 = sym_size_int_2 = None
    bmm = torch.ops.aten.bmm.default(view, view_1);  view = view_1 = None
    view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]);  bmm = None
    mul_4 = sym_size_int * 3
    view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]);  view_2 = mul_4 = None
    mm = torch.ops.aten.mm.default(view_3, eye);  view_3 = eye = None
    _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]);  mm = sym_size_int = None
    index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view);  crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None
    return None""")  # noqa: B950

    def test_unbacked_slice(self):
        def f(x, m):
            x = x[m]
            return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]

        make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        )

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_unbacked_batch_resnet(self):
        mod = torchvision.models.resnet18()

        def f(x, mask, params, buffers):
            for p in itertools.chain([x, mask], params.values(), buffers.values()):
                for s in p.shape:
                    guard_int(s)
            x = x[mask]
            torch._check(x.shape[0] >= 1)
            for p in params.values():
                p.grad = None
            return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()

        make_fx(f, tracing_mode="symbolic")(
            torch.randn(3, 3, 250, 250),
            torch.randint(0, 2, (3,), dtype=torch.bool),
            dict(mod.named_parameters()),
            dict(mod.named_buffers()),
        )

    def test_boolean_index(self):
        def f(images, handedness, valid):
            images = images[valid]
            handedness = handedness[valid]
            right_hand_mask = handedness == 1
            images[right_hand_mask] = images[right_hand_mask].flip(-1)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randint(0, 256, (512, 1, 96, 96)),
            torch.randint(0, 1, (512,)),
            torch.randint(0, 2, (512,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, images_1, handedness_1, valid_1):
    index = torch.ops.aten.index.Tensor(images_1, [valid_1]);  images_1 = None
    index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]);  handedness_1 = valid_1 = None
    eq = torch.ops.aten.eq.Scalar(index_1, 1);  index_1 = None
    index_2 = torch.ops.aten.index.Tensor(index, [eq])
    flip = torch.ops.aten.flip.default(index_2, [-1]);  index_2 = None
    index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip);  index = eq = flip = index_put_ = None
    return None""")

    def test_neg_shape(self):
        def f(a):
            return torch.empty(-a.shape[0] + 10)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0);  a_1 = None
    neg = -sym_size_int;  sym_size_int = None
    add = neg + 10;  neg = None
    empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False);  add = None
    return empty""")

    def test_unbacked_unification(self):
        def f(x, y):
            z = torch.zeros(x.item())
            return z + y

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    add = torch.ops.aten.add.Tensor(zeros, y_1);  zeros = y_1 = None
    return add""")  # noqa: B950

    def test_reshape_divisibility_unbacked(self):
        def f(x):
            i0 = x.item()
            r = torch.zeros(i0, 4, 20)
            r = r.transpose(2, 1)
            return r.reshape(-1, 80)
        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))

    def test_view_divisibility_unbacked(self):
        def f(x):
            i0 = x.item()
            r = torch.zeros(i0, 192)
            return r.view(12, -1, 192)
        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_view_divisibility_unbacked_relatively_prime(self):
        # See https://github.com/pytorch/pytorch/issues/123651
        def f(x):
            i0 = x.item()
            torch._check_is_size(i0)
            # To trigger the original issue, the max bound has to
            # be chosen such that 448 / 447 < 2 (which it is.)
            torch._check(i0 <= 448)
            return torch.zeros(256 * i0).view(-1, 447)
        make_fx(f, tracing_mode="symbolic")(torch.tensor(256 * 447, device="cuda"))

    def test_unbacked_unify_guard(self):
        def f(x, y):
            z = torch.zeros(x.item())
            torch._check(z.size(0) == y.size(0))  # refines i0 = s0
            if z.size(0) == 4:
                return y * 2
            else:
                return y + 2

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = zeros = None
    add = torch.ops.aten.add.Tensor(y_1, 2);  y_1 = None
    return add""")  # noqa: B950

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    @unittest.expectedFailure
    def test_unbacked_unify_guard_transitivity(self):
        def f(x1, x2, y):
            z1 = torch.zeros(x1.item())
            z2 = torch.zeros(x2.item())
            torch._check(z1.size(0) == z2.size(0))  # refines i0 = i1
            torch._check(z2.size(0) == y.size(0))  # refines i0 = s0
            if z1.size(0) == 4:
                return y * 2
            else:
                return y + 2

        gm = make_fx(f, tracing_mode="symbolic")(
            torch.tensor(10, device="cuda"),
            torch.tensor(10, device="cuda"),
            torch.randn(10, device="cuda")
        )
        insert_deferred_runtime_asserts(gm, gm.shape_env, "test")
        gm.recompile()
        r = str(gm.code).strip()
        # self.assertExpectedInline(
        #     r, """"""  # noqa: B950
        # )

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_unbacked_unify_dependency_violation(self):
        def f(x1, x2, x3, y):
            z1 = x1.item()
            torch._check(z1 // 9 == 1)
            z2 = x2.item()
            z3 = x3.item()
            torch._check(z1 == z2 + z3)
            return y * 2

        # NB: inputs are done as CUDA to ensure they aren't queried to be
        # backed
        gm = make_fx(f, tracing_mode="symbolic")(
            torch.tensor(10, device="cuda"), torch.tensor(5, device="cuda"),
            torch.tensor(5, device="cuda"), torch.randn(1, device="cuda")
        )
        insert_deferred_runtime_asserts(gm, gm.shape_env, "test")
        gm.recompile()
        self.assertEqual(gm(
            torch.tensor(12, device="cuda"), torch.tensor(6, device="cuda"),
            torch.tensor(6, device="cuda"), torch.tensor([1.0], device="cuda")),
            torch.tensor([2.0], device="cuda")
        )
        with self.assertRaises(RuntimeError):
            gm(
                torch.tensor(20, device="cuda"), torch.tensor(10, device="cuda"),
                torch.tensor(10, device="cuda"), torch.tensor([1.0], device="cuda")
            )

    def test_split_unbacked_sizes(self):
        def f(lengths, values):
            # tolist not directly supported atm
            sizes = [lengths[i].item() for i in range(lengths.size(0))]
            for s in sizes:
                # TODO(avik): no assertion generated with torch._check_is_size?
                torch._constrain_as_size(s)
            return torch.split(values, sizes)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.tensor([2, 3, 4]),
            torch.randn(9)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, lengths_1, values_1):
    select = torch.ops.aten.select.int(lengths_1, 0, 0)
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select);  select = None
    select_1 = torch.ops.aten.select.int(lengths_1, 0, 1)
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1);  select_1 = None
    select_2 = torch.ops.aten.select.int(lengths_1, 0, 2);  lengths_1 = None
    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2);  select_2 = None
    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense);  sym_constrain_range_for_size = None
    sym_constrain_range_for_size_1 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_1);  sym_constrain_range_for_size_1 = None
    sym_constrain_range_for_size_2 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_2);  sym_constrain_range_for_size_2 = None
    split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]);  values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
    getitem = split_with_sizes[0]
    getitem_1 = split_with_sizes[1]
    getitem_2 = split_with_sizes[2];  split_with_sizes = None
    return (getitem, getitem_1, getitem_2)""")  # noqa: B950

    def test_invalidate_nonzero(self):
        ok = False

        def f(a):
            nonlocal ok
            b = a.clone()
            x = b.nonzero()
            x1 = b.nonzero()
            x2 = b.nonzero()
            assert x1.shape[0] == x2.shape[0]
            ok = True
            b.normal_()
            y = b.nonzero()
            try:
                bool(x1.shape[0] == y.shape[0])
                self.fail("didn't raise exception")
            except GuardOnDataDependentSymNode:
                pass

        make_fx(f, tracing_mode="symbolic")(torch.randn(4))

    @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
    def test_invalidate_nonzero_propagate_real_tensors(self):
        def f(a):
            b = a.clone()
            x = b.nonzero()
            x1 = b.nonzero()
            x2 = b.nonzero()
            assert x1.shape[0] == x2.shape[0]
            b.normal_()
            y = b.nonzero()
            # Because you're not actually going to generate exactly zero with
            # normal_ lol
            assert x1.shape[0] == y.shape[0]

        make_fx(f, tracing_mode="symbolic")(torch.randn(4))

    def test_sqrt_size(self):
        def f(a):
            return a / a.size(-1) ** 0.5

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    sym_float = torch.sym_float(sym_size_int);  sym_size_int = None
    pow_1 = sym_float ** 0.5;  sym_float = None
    div = torch.ops.aten.div.Tensor(a_1, pow_1);  a_1 = pow_1 = None
    return div""")

    def test_make_fx_with_custom_tracer_preserving_nn_module_stack(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + 1

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = Bar()

            def forward(self, x):
                return x + self.bar(x)

        gm = make_fx(Foo())(torch.randn(4, 4))
        for node in gm.graph.nodes:
            self.assertTrue("nn_module_stack" not in node.meta)

        foo = Foo()

        def functional_call(*args, **kwargs):
            with stateless._reparametrize_module(foo, {}):
                return foo(*args, **kwargs)

        functional_call._orig_mod = foo

        gm_with_stack = make_fx(functional_call, record_module_stack=True)(torch.randn(4, 4))
        found = False
        for node in gm_with_stack.graph.nodes:
            if "nn_module_stack" in node.meta:
                if len(node.meta["nn_module_stack"]) == 1:
                    self.assertTrue("custom_tracer_preserving_nn_module_stack.<locals>.Foo" in str(node.meta["nn_module_stack"]))
                    found = True
                elif len(node.meta["nn_module_stack"]) == 2:
                    self.assertTrue("preserving_nn_module_stack.<locals>.Bar" in str(node.meta["nn_module_stack"]))
                    found = True
                else:
                    # there can be at most 2 levels
                    self.assertTrue(False)
        self.assertTrue(found)

        gm_without_stack = make_fx(functional_call)(torch.randn(4, 4))
        for node in gm_without_stack.graph.nodes:
            self.assertTrue("nn_module_stack" not in node.meta)

    def test_symint_to_tensor(self):
        def f(a):
            return a / a.shape[0]

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    div = torch.ops.aten.div.Tensor(a_1, sym_size_int);  a_1 = sym_size_int = None
    return div""")

        r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    sym_float = torch.sym_float(sym_size_int);  sym_size_int = None
    div = torch.ops.prims.div.default(a_1, sym_float);  a_1 = sym_float = None
    return div""")

    def test_cat(self):
        def f(a, b):
            val = torch.mul(a, b)
            out = torch.cat([val, val])
            if out.shape[0] * out.shape[1] > 20:
                out = out.cos()
            return out

        test_inputs = []
        test_inputs.append([(1, 5), (6, 1)])
        test_inputs.append([(1, 4), (3, 1)])
        gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
        self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
        self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
        self.assertExpectedInline(show_guards(gm), """2*L['b'].size()[0]*L['a'].size()[1] > 20""")

    def test_new_empty(self):
        def f(a, b):
            return a.new_empty(b.shape[0], b.shape[1] * 2)

        self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False).shape_env

    def test_size_with_tensor(self):
        # I think I messed up writing this test case originally, I think
        # I'm supposed to hit an error case, but the code here works in both
        # eager and tracing
        def f(tensor):
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
            batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
            return tensor.new_empty(batch_shape)

        a = torch.randn(3, 800, 1199)
        f(a)
        make_fx(f, tracing_mode="symbolic")(a)

    def test_fake_tensor_as_size(self):
        def f(x):
            r = torch.zeros([x])
            return r

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.tensor(4))
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    return zeros""")  # noqa: B950

    def test_expand(self):
        def f(a):
            b = torch.mul(a, a)
            c = b.expand(a.shape)
            return c

        self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
        self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])

    def test_metadata(self):
        def f(a, b):
            d = a.new_empty(a.shape[0] + b.shape[0])
            return d

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
        meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
        meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
        self.assertTrue(meta_c.meta['val'].shape[0].node.expr == meta_d.meta['val'].node.expr)

    def test_metadata_fresh(self):
        def f(x):
            assert x.shape[0] == 3
            return x.cos()

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
        self.assertTrue(meta_cos.meta['val'].shape[0] == 3)
        # Checks if the input expr has been updated even though the constraint
        # happened afterwards
        self.assertTrue(meta_inp.meta['val'].shape[0] == 3)

    def test_elementwise_meta_with_sym_numbers(self):
        def f(x, offset, as_sym_float=False):
            x0 = x.size()[0]
            if as_sym_float:
                x0 = torch.sym_float(x0)
            return torch.add(x0, offset)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.int64)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)
def test_return_symint(self):
def f(x):
return x.shape[0], x.cos(), x.shape[0] / 5
self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])
def f(x):
return x.shape
self._test_dynamic(f, [(5, 3)], [[(4, 6)]])
def test_rmethod(self):
def f(x):
return x.size(0) + x
self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])
def test_mega_guard(self):
def f(a, b):
assert a.shape[0] == b.shape[0] * 2
return a.cos()
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
from torch._dynamo.source import LocalSource
self.assertExpectedInline(
str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)), # noqa: B950
"""["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]""" # noqa: B950
)
self.assertExpectedInline(
str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)), # noqa: B950
"""["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]""" # noqa: B950
)
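        # The two produce_guards calls differ only in ignore_static: with
        # ignore_static=True the facts that were statically decided at trace
        # time (unit strides, zero storage offsets) are dropped, leaving only
        # the genuinely symbolic relations.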
def test_guard_upperbound_range_refinement(self):
def f(a):
assert a.shape[0] > 5 and a.shape[0] > 12
return a.cos()
tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
self.assertExpectedInline(show_guards(tensor), """13 <= L['a'].size()[0]""")
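        # Worked reading of the guard: over integer sizes, a.shape[0] > 12 is
        # refined to 13 <= a.shape[0], which subsumes the weaker
        # a.shape[0] > 5; the lowerbound test below is the mirror image
        # (a.shape[0] < 20 becomes a.shape[0] <= 19).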
def test_guard_lowerbound_range_refinement(self):
def f(a):
assert a.shape[0] < 20 and a.shape[0] < 30
return a.cos()
tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
self.assertExpectedInline(show_guards(tensor), """L['a'].size()[0] <= 19""")
def test_guard_upperbound_range_refinement_multivariate(self):
def f(a):
assert a.shape[0] > 5 and a.shape[0] > 12
assert a.shape[1] > 5 and a.shape[1] > a.shape[0]
return a.cos()
tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 20)))
self.assertExpectedInline(show_guards(tensor), """\
L['a'].size()[1] > L['a'].size()[0]
13 <= L['a'].size()[0]
14 <= L['a'].size()[1]""")
def test_guard_lowerbound_range_refinement_multivariate(self):
def f(a):
assert a.shape[0] < 20 and a.shape[0] < 30
assert a.shape[1] < 30 and a.shape[1] < a.shape[0]
return a.cos()
tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 5)))
self.assertExpectedInline(
show_guards(tensor),
"""\
L['a'].size()[1] < L['a'].size()[0]
L['a'].size()[0] <= 19
L['a'].size()[1] <= 18""")
def test_sym_storage_offset(self):
def f(x, y):
return x + y
inp = (torch.randn(8)[3:], torch.randn(5))
fx_g = make_fx(f, tracing_mode="symbolic")(*inp)
inp = (torch.randn(8)[3:], torch.randn(5))
self.assertEqual(fx_g(*inp), f(*inp))
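        # torch.randn(8)[3:] carries a nonzero storage offset, so both tracing
        # and re-running the traced graph here exercise a symbolic (rather
        # than trivially zero) storage offset.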
def _assert_no_guards(self, fx_g, free_symbols):
assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()
def test_guards_equal(self):
def f(a, b):
return a * b
        # NB: Numbers are carefully chosen to keep duck shaping from applying
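        # (duck shaping assigns the same symbol to input dimensions that share
        # a concrete hint value at trace time, so repeated sizes would
        # collapse symbols and mask the guard behavior probed here).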
fx_g = _trace(f, (5, 6), (5, 6))
self._assert_no_guards(fx_g, 2)
fx_g = _trace(f, (5, 6, 7), (5, 6, 7))
self._assert_no_guards(fx_g, 3)
fx_g = _trace(f, (5, 1), (1, 6))
self._assert_no_guards(fx_g, 2)
def f(a, b, c, d):
a = a + b
cat = torch.cat([c, d])
return a + cat
fx_g = _trace(f, 7, 7, 4, 3)
self._assert_no_guards(fx_g, 2)
def f(a, b, c, d, e):
vals = [a, b, c, d, e]
x = a
for idx in range(len(vals) - 1):
x = torch.cat([x, vals[idx]]) + vals[idx + 1]
return x
fx_g = _trace(f, 2, 4, 8, 16, 32)
self._assert_no_guards(fx_g, 1)
def f(a, b):
a = a.view(b.shape[0])
return a + b.sum()
fx_g = _trace(f, (4, 2), 8)
self._assert_no_guards(fx_g, 2)
fx_g = _trace(f, (4, 2), (8, 5))
self._assert_no_guards(fx_g, 3)
fx_g = _trace(f, (2, 3, 4), 24)
self._assert_no_guards(fx_g, 3)
def test_nonidentity_transitive_guards(self):
def f(a, b, c, d, e):
vals = [a, b, c, d, e]
cat_vals = []
for idx in range(len(vals) - 1):
cat_vals.append(torch.cat([vals[idx], vals[idx]]))
final_vals = []
for a, b in reversed(list(zip(cat_vals, vals[1:]))):
final_vals.append(a + b)
return final_vals
fx_g = _trace(f, 2, 4, 8, 16, 32)
self.assertExpectedInline(show_guards(fx_g), """""")
@torch.fx.experimental._config.patch(translation_validation=True)
def test_constant_specialization(self):
def f(t):
assert t.shape[0] == 10
return t
tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(10))
self.assertExpectedInline(show_guards(tensor), """""")
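        # Equality with a literal specializes the symbol: t.shape[0] is pinned
        # to 10 during tracing, so the simplified guard output is empty. The
        # translation_validation patch additionally cross-checks the
        # ShapeEnv's reasoning against an SMT solver (Z3, when installed).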
make_fx_failures = {
# unknown
xfail('allclose'),
xfail('equal'),
# empty
skip('new_empty'),
skip('empty_like'),
skip('empty'),
skip('empty_permuted'),
# flaky
skip('linalg.lstsq', 'grad_oriented'),
skip('nn.functional.max_unpool1d', '', device_type='cpu'),
skip('nn.functional.max_unpool2d', '', device_type='cpu'),
skip('nn.functional.max_unpool3d', '', device_type='cpu'),
skip('linalg.lstsq'), # flaky, probably just a precision issue
# data-dependent control flow
skip('item'),
xfail('cov'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('tensor_split'),
xfail('corrcoef'),
xfail('quantile'),
xfail('nanquantile'),
# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
xfail('sparse.mm', 'reduce'),
# proxy tensor doesn't support sparse correctly right now
skip('to_sparse'),
# segfaults
skip('block_diag'),
# AssertionError: Tensor-likes are not close!
skip('empty_strided', '', device_type='cpu'),
}
only_real_tensor_failures = {
xfail('narrow'),
}
only_fake_tensor_failures = {
xfail('narrow'),
}
fake_tensor_failures = set()
symbolic_tensor_failures = {
xfail('combinations', ''),
xfail('geqrf', ''), # aten.geqrf.default - couldn't find symbolic meta function/decomposition
xfail('histogram', ''), # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
xfail('histogramdd', ''), # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but...
}
symbolic_tensor_segfaults = {
skip('nn.functional.batch_norm') # Segfault??
}
symbolic_tensor_failures.update(symbolic_tensor_segfaults)
inplace_symbolic_tensor_failures = {
# bugs
xfail('float_power', ''), # base given to float_power_ has dtype Float but the operation's result requires dtype Double
}
out_symbolic_tensor_failures = {
# Cast error details: Unable to cast (...) to Tensor
#
# This happens because the test is set up to call the out variant using the `out` kwarg:
# torch._some_op(arg1, arg2, out=(out1, out2, out3))
#
# However, this only works on torch ops, not aten ops. For `_batch_norm_with_update`,
# this fails because the op has no python bindings, so it doesn't support the `out` kwarg
# way of calling its out variant.
xfail('_batch_norm_with_update', ''),
xfail('_native_batch_norm_legit', ''),
xfail('angle', ''),
xfail('argmax', ''),
xfail('argmin', ''),
xfail('gather', ''),
xfail('linalg.pinv', ''),
xfail('linalg.pinv', 'hermitian'),
xfail('scatter_add', ''),
xfail('scatter', ''),
xfail('take_along_dim', ''),
# SymIntArrayRef expected to contain only concrete
xfail('randn', ''),
# RuntimeError: Cannot call numel() on tensor with symbolic sizes/strides
xfail('index_reduce', 'prod'),
xfail('index_reduce', 'mean'),
xfail('index_reduce', 'amax'),
xfail('index_reduce', 'amin'),
}
out_symbolic_tensor_segfaults = {
skip('nanmean', ''),
}
out_symbolic_tensor_failures.update(out_symbolic_tensor_segfaults)
# Wraps an inplace variant so it runs on a clone of its input, avoiding
# inplace modification of leaf tensors that require gradient
def _get_safe_inplace(inplace_variant):
@functools.wraps(inplace_variant)
def _fn(t, *args, **kwargs):
return inplace_variant(t.clone(), *args, **kwargs)
return _fn
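# A minimal usage sketch (hypothetical; not executed here): wrapping
# torch.Tensor.add_ makes it safe to apply to a leaf tensor that requires
# grad, which the raw in-place op would reject with a RuntimeError:
#
#   safe_add_ = _get_safe_inplace(torch.Tensor.add_)
#   safe_add_(torch.ones(3, requires_grad=True), 1.0)  # mutates a clone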
def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False, out=False):
fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
    # Limit ourselves to the first 100 inputs so symbolic tracing tests don't take too long
count = 100
if out:
count = 5
for sample_input in itertools.islice(sample_inputs_itr, count):
if inplace and sample_input.broadcasts_input:
continue
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
if out:
expected = fn(*args, **kwargs)
kwargs['out'] = expected
try:
optests.make_fx_check(fn, args, kwargs, tracing_mode, self.assertEqual,
randomize_data=True)
except DynamicOutputShapeException:
self.skipTest("Dynamic output shape operation in trace")
def skipIfNameMatches(pattern):
    """
    Decorator to skip a test if its name matches the given pattern.
    """
    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            if re.match(pattern, test_func.__name__):
                raise unittest.SkipTest(
                    f"Test '{test_func.__name__}' skipped because its name matches the pattern '{pattern}'"
                )
            return test_func(*args, **kwargs)
        return wrapper
    return decorator
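# Hypothetical usage sketch: applied to a test method, a name match is turned
# into a skip at call time:
#
#   @skipIfNameMatches(r"test_flaky_.*")
#   def test_flaky_thing(self):
#       ...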
# Auto functionalize shouldn't work with make_fx directly
filtered_hop_db = [op for op in hop_db if op.name != "auto_functionalize"]
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "Cond requires dynamo")
class TestProxyTensorOpInfo(TestCase):
@ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures.union(only_real_tensor_failures))
def test_make_fx_exhaustive(self, device, dtype, op):
_test_make_fx_helper(self, device, dtype, op, "real")
@ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive',
make_fx_failures.union(fake_tensor_failures, only_fake_tensor_failures))
def test_make_fx_fake_exhaustive(self, device, dtype, op):
_test_make_fx_helper(self, device, dtype, op, "fake")
@ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
make_fx_failures | fake_tensor_failures | symbolic_tensor_failures)
def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
_test_make_fx_helper(self, device, dtype, op, "symbolic")
@ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace',
make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures)
def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op):
if not op.get_inplace():
self.skipTest("No inplace variable for this op")
_test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True)
@ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_out',
make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | out_symbolic_tensor_failures)
def test_make_fx_symbolic_exhaustive_out(self, device, dtype, op):
if not op.supports_out:
self.skipTest("Op doesn't support out")
_test_make_fx_helper(self, device, dtype, op, "symbolic", out=True)
only_for = ("cpu",)
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)
if __name__ == '__main__':
run_tests()