# Owner(s): ["module: ProxyTensor"] # ruff: noqa: F841 from torch.testing._internal.common_utils import TestCase, run_tests import torch import torch._dynamo import unittest import warnings import operator from collections.abc import Iterable from torch.nn.utils import stateless from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode from torch._decomp import decomposition_table from torch.fx.experimental.symbolic_shapes import ( eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets, guard_int, GuardOnDataDependentSymNode ) from torch.testing._internal.custom_op_db import custom_op_db from torch.testing._internal.hop_db import hop_db from torch.testing._internal.common_device_type import ops import torch.testing._internal.optests as optests from torch._C import _disabled_torch_function_impl from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule from torch.utils._pytree import tree_map from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts from torch import nn import torch._functorch.config import re import functools import itertools aten = torch.ops.aten HAS_CUDA = torch.cuda.is_available() def strip_end(s, suffix): if suffix and s.endswith(suffix): return s[:-len(suffix)] else: return s def show_guards(gm): names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)] return "\n".join( gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, input_contexts=None) ) def process_failures(): """ Takes file containing failures like FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition # noqa: B950 and processes them into a list of opinfo xfails """ f = open('pytest_failures') failures = f.readlines() failures = [i.strip() for i in failures] def process_failure_string(s, matcher): out = re.search(matcher, s) return out.groups() SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)' failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures] def create_normalized_name(op): if op.variant_test_name == '': s = op.name else: s = f"{op.name}.{op.variant_test_name}" return s.replace('.', '_') remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db} print("symbolic_tensor_failures = {") for failure, reason in failures: print(f" xfail{remap_opinfo[failure]}, # {reason}") print("}") USE_TORCHVISION = False try: import torchvision USE_TORCHVISION = True except ImportError: warnings.warn("Couldn't import torchvision. Some of our tests use it, try " "to install it with commands from pytorch.org, post-fixed with " "`--no-deps` to avoid overwriting the pytorch installation", UserWarning) def _create_new_input(x): if not isinstance(x, torch.Tensor): return x if x.dtype != torch.float: return x + 1 if x.is_leaf: return torch.rand_like(x, requires_grad=x.requires_grad) else: return torch.rand_like(x) """ Delays a cos being executed on the unwraptensor until its used. 
def process_failures():
    """
    Takes file containing failures like

    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition  # noqa: B950

    and processes them into a list of opinfo xfails
    """
    f = open('pytest_failures')
    failures = f.readlines()
    failures = [i.strip() for i in failures]

    def process_failure_string(s, matcher):
        out = re.search(matcher, s)
        return out.groups()

    SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)'
    failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures]

    def create_normalized_name(op):
        if op.variant_test_name == '':
            s = op.name
        else:
            s = f"{op.name}.{op.variant_test_name}"
        return s.replace('.', '_')

    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}

    print("symbolic_tensor_failures = {")
    for failure, reason in failures:
        print(f"    xfail{remap_opinfo[failure]},  # {reason}")
    print("}")


USE_TORCHVISION = False
try:
    import torchvision
    USE_TORCHVISION = True
except ImportError:
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
                  "to install it with commands from pytorch.org, post-fixed with "
                  "`--no-deps` to avoid overwriting the pytorch installation",
                  UserWarning)


def _create_new_input(x):
    if not isinstance(x, torch.Tensor):
        return x
    if x.dtype != torch.float:
        return x + 1
    if x.is_leaf:
        return torch.rand_like(x, requires_grad=x.requires_grad)
    else:
        return torch.rand_like(x)

"""
Delays a cos being executed on the UnwrapTensor until it's used.
Simulates a CommTensor use case.
"""
class UnwrapTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            dtype=tensor.dtype,
            device=tensor.device,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
        )
        r._tensor = tensor
        return r

    def __repr__(self):
        # TODO: consider all_gather the local tensors for better debugging
        return f"UnwrapTensor({self._tensor})"

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            ret = e
            if isinstance(e, UnwrapTensor):
                ret = e._tensor.cos()
            return ret

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
        return func(*args, **kwargs)
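
# Sketch of the dispatch behavior above (for illustration; the example values
# are hypothetical, not part of any test):
#
#   x = torch.zeros(3)
#   y = UnwrapTensor(x) * 2
#
# __torch_dispatch__ intercepts the mul, swaps the wrapper for x.cos()
# (cos(0) == 1), and re-dispatches the plain aten op, so y equals
# torch.full((3,), 2.0). test_trace_subclasses below exercises this wrapper.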
class TestGenericProxyTensor(TestCase):
    # WARNING: if any of your inputs are index tensors, DO NOT use this
    # function
    def _test(self, f, inps):
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
        new_inps = tree_map(_create_new_input, inps)
        r1 = fx_f(*new_inps)
        r2 = f(*new_inps)
        self.assertEqual(r1, r2)

    def test_pre_dispatch_mode_stack(self):
        def f(a):
            b = torch.ones(4, 4)
            return torch.matmul(a, b)
        # We expect to see matmul in the trace - it should NOT be decomposed into mm.
        # Also, torch.ones() doesn't show up in the trace.
        # This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
        # so our mode never sees it - it goes directly to the BackendSelect key.
        inp = torch.ones(4, 4)
        # Test that make_fx(pre_dispatch=True) clears caches properly.
        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            out1 = f(inp)
        fx_g = make_fx(f, pre_dispatch=True)(inp)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, a_1):
    ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
    matmul = torch.ops.aten.matmul.default(a_1, ones);  a_1 = ones = None
    return matmul""")

    def test_pre_dispatch_linear(self):
        def f(a, b, c):
            return torch.nn.functional.linear(a, b, c)
        a = torch.ones(4, 4)
        b = torch.ones(4, 4)
        c = torch.ones(4)
        fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
        out1 = f(a, b, c)
        out2 = fx_g(a, b, c)
        self.assertEqual(out1, out2)

    def test_pre_dispatch_no_grad(self):
        def f(a):
            b = a.sin()
            torch.set_grad_enabled(False)
            c = b.cos()
            torch.set_grad_enabled(True)
            return b + c.sin()
        a1 = torch.randn(4, requires_grad=True)
        a2 = a1.detach().clone().requires_grad_(True)
        a_tmp = a1.detach().clone().requires_grad_(True)
        fx_g = make_fx(f, pre_dispatch=True)(a_tmp)
        out1 = f(a1)
        out2 = fx_g(a2)
        self.assertEqual(out1, out2)
        out1.sum().backward()
        out2.sum().backward()
        self.assertEqual(a1.grad, a2.grad)

    def test_make_fx_simple(self):
        def f(x):
            return torch.sin(x)
        self._test(f, (torch.randn(3),))

    def test_scalar_device(self, device='cpu'):
        def f(a, b):
            return a + b
        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])
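
    # Background for the next test (an informal sketch, assuming default modes):
    # get_isolated_graphmodule(fn, args, kwargs) traces fn with a fresh,
    # isolated make_fx call, so ops recorded in the inner GraphModule should
    # not leak into an enclosing make_fx trace (or vice versa). Roughly:
    #
    #   gm = get_isolated_graphmodule(torch.sum, (torch.randn(3),), {})
    #   # gm.graph contains aten.sum.default; the outer trace, if any, does not.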
    def test_isolated_graphmodule(self):
        def is_any_sum(gm):
            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)

        def is_any_digamma(gm):
            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)

        def is_any_sigmoid(gm):
            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)

        def inner(x):
            return torch.sum(x)

        def f(x):
            gm = get_isolated_graphmodule(inner, (x,), {})
            self.assertTrue(is_any_sum(gm))
            return x + torch.randn(x.shape)

        # get_isolated_graphmodule uses make_fx internally that shouldn't be traced
        # by the outer make_fx call
        traced = make_fx(f)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))

        # When factory functions are used, they should not be traced
        # by the outer make_fx call
        def inner_with_factory():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((10, 10), val).sum()

        def f1(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2(x):
            gm = get_isolated_graphmodule(f1, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify that nested make_fx calls don't cause factory functions to be
        # leaked into the outer graph. Verify that `make_fx` itself does not
        # leak its execution.
        def f2(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify that the `forward` function of a graph module produced as a
        # side effect of an interior `make_fx` is still traced
        def f3(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            # `gm.forward` is still traced
            return torch.digamma(gm(x))

        traced = make_fx(f3)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertTrue(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with non-ProxyTensor modes
        from torch.testing._internal.logging_tensor import LoggingTensorMode

        def f1_logging(x):
            with LoggingTensorMode():
                gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging(x):
            with LoggingTensorMode(), LoggingTensorMode():
                gm = get_isolated_graphmodule(f1_logging, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with another tensor subclass
        # This case currently doesn't work and should raise an error
        # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
        from torch.testing._internal.logging_tensor import LoggingTensor

        def f1_logging_tensor(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging_tensor(x):
            x = LoggingTensor(x)
            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging_tensor)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
        self.assertTrue(is_any_digamma(traced))

    # See https://github.com/pytorch/pytorch/issues/97541
    def test_empty_like_doesnt_burn_in_defaults(self):
        def f(x):
            return torch.empty_like(x)
        out = make_fx(f)(torch.randn(3))
        self.assertExpectedInline(out.code.strip(), """\
def forward(self, x_1):
    empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False);  x_1 = None
    return empty_like""")

    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
        def f(x):
            y = x.new_zeros(x.size())
            y.copy_(x)
            return y

        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
            return torch.zeros(size, dtype=inp.dtype, device=inp.device)

        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}

        # When new_zeros() decomposes into torch.zero(), we expect ProxyTensorMode
        # to still be (re-entrantly) enabled, so that the `torch.zero()` call
        # returns a ProxyTensor.
        out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
        self.assertExpectedInline(out.code, """\
def forward(self, x_1):
    zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    copy_ = torch.ops.aten.copy_.default(zeros, x_1);  zeros = x_1 = None
    return copy_
    """)

    def test_make_fx_reentrant_dispatch(self):
        def f(x):
            return torch.ops.aten.norm.Scalar(x, 2.0)

        def norm_decomp(x, p=2.0):
            if p != 2.0:
                raise RuntimeError("can't handle with p != 2")
            return torch.sqrt(torch.sum(torch.square(x)))

        decomp = {torch.ops.aten.norm.Scalar: norm_decomp}

        traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3))

        for n in traced.graph.nodes:
            self.assertTrue("square" not in str(n.target))
            self.assertTrue("norm" not in str(n.target))
    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self):
        mod = torchvision.models.resnet18()

        # An old version of this test called the module directly. This works
        # for tracing_mode == "real", but for fake tensors, we also have to
        # ensure that the parameters and buffers get wrapped in fake tensors
        # because free fake tensors are not supported. Fortunately functional_call
        # does precisely this for us.
        def f(x, params, buffers):
            for p in params.values():
                p.grad = None
            loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
            # I could have done this with the functional API, but there is
            # plenty of exercising this; I want to show mutating API still
            # works
            loss.backward()
            return [p.grad for p in params.values()]

        inp = torch.randn(3, 3, 250, 250)
        self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])

    def test_varargs(self):
        def f(*args):
            return sum(args)
        self._test(f, [torch.randn(2), torch.randn(2)])

    def test_proxy_tensor(self):
        def f_grad(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        def f_backward(x):
            val = x.cos().cos().sum()
            val.backward()
            return x.grad

        for f in [f_grad, f_backward]:
            self._test(f, [torch.randn(3, requires_grad=True)])

    def test_pickle_issue89626(self):
        import pickle
        x = torch.randn(2)
        make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x)
        pickle.dumps(x)

    def test_inplace_metadata(self):
        def f(x):
            x = x.clone()
            x.unsqueeze_(-1)
            assert x.shape[-1] == 1
            return x
        self._test(f, [torch.randn(5)])

    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # default behavior should trace factory functions
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
        self.assertTrue(
            any(
                node.target == aten.randn.default
                for node in traced.graph.nodes
            )
        )

    def test_pre_dispatch_functionalization(self):
        def f(x):
            a = FunctionalTensorMode(pre_dispatch=True, export=True)
            with a:
                x_unwrapped = FunctionalTensor.to_functional(x)
                y = torch.matmul(x_unwrapped, x_unwrapped)
                y = y + x_unwrapped
                y.mul_(5)
                y_unwrapped = torch._from_functional_tensor(y.elem)
                return y_unwrapped

        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            inp = torch.randn(4, 4)
            gm = make_fx(f, pre_dispatch=True)(inp)

        # TODO actually not decompose
        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
    add = torch.ops.aten.add.Tensor(matmul, x_1);  matmul = x_1 = None
    mul = torch.ops.aten.mul.Tensor(add, 5);  add = None
    return mul""")

    def test_pre_dispatch_functionalization_view_op(self):
        def f(x):
            a = FunctionalTensorMode(pre_dispatch=True, export=True)
            with a:
                x_unwrapped = FunctionalTensor.to_functional(x)
                y = torch.matmul(x_unwrapped, x_unwrapped)
                x_unwrapped = x_unwrapped.transpose(1, 0)
                y = y + x_unwrapped
                y = y.view(2, 8)
                y_unwrapped = torch._from_functional_tensor(y.elem)
                return y_unwrapped

        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            inp = torch.randn(4, 4)
            gm = make_fx(f, pre_dispatch=True)(inp)

        # TODO actually not decompose
        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
    transpose = torch.ops.aten.transpose.int(x_1, 1, 0);  x_1 = None
    add = torch.ops.aten.add.Tensor(matmul, transpose);  matmul = transpose = None
    view = torch.ops.aten.view.default(add, [2, 8]);  add = None
    return view""")

    def test_val_metadata_mutation(self):
        def f(x):
            y = x.clone()
            y.unsqueeze_(0)
            return y

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True))
        self.assertEqual([
            tuple(node.meta['val'].shape)
            for node in traced.graph.nodes
            if 'val' in node.meta
        ], [(3,), (3,), (1, 3)])

    def test_make_fx_overloads(self):
        def f(x):
            return x.cos() + torch.randn(x.shape)

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))

        self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload)
                            for node in traced.graph.nodes if node.op == 'call_function'))

    def test_tensor_constants(self):
        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)

        self._test(f, [])

    def test_allclose(self):
        def f(a, b):
            return torch.allclose(a, b)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)(
                torch.zeros(3), torch.zeros(3)
            )

        if self.tracing_mode != "real":
            self.assertRaises(DataDependentOutputException, test_f)
        else:
            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
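
    # Note on the test above: allclose must inspect tensor *values* to produce
    # a Python bool, which is impossible under fake tensors, hence the
    # DataDependentOutputException in fake/symbolic modes. The same mechanism
    # drives the "data-dependent" failures in the constant-tracing tests below.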
    def test_constant_proxy_tensor_mut(self):
        def f():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((100, 100), val)

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())
        # In case we mutated shared state in the g graph!
        self.assertEqual(g(), f())

    def test_constant_unbind(self):
        def f():
            val = torch.tensor([2])
            r, = torch.unbind(val, 0)
            return r.item()

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())

    def test_constant_blowup(self):
        def f():
            val = torch.tensor([2])
            blowup = val.repeat(1000)
            return bool(blowup.sum().item() == 2)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_constant_random(self):
        def f():
            val = torch.tensor([2.0])
            val.normal_()
            return bool(val.item() == 2.1)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_decomposition_interpreter(self):
        def fn(x):
            return torch.nn.functional.silu(x)

        x = torch.rand((4, 4))
        fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)

        found_silu = False
        for n in fx_module.graph.nodes:
            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
                found_silu = True

        self.assertTrue(found_silu)

        new_graph = torch.fx.Graph()
        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
        DecompositionInterpreter(
            fx_module,
            new_graph=new_graph,
            decomposition_table=silu_decomp_table,
        ).run(x)

        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)

        for n in decomposed_module.graph.nodes:
            self.assertTrue(n.target != torch.ops.aten.silu)
            self.assertTrue(n.target != torch.ops.aten.silu.default)

        self.assertEqual(fx_module(x), decomposed_module(x))

    def test_make_fx_model_fwd_bwd(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(x, params):
            out = torch.func.functional_call(model, params, x).sum()
            out.backward()
            return list(params.values())

        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
        # fx may change the order of parameters in list, so using set() to compare
        self.assertTrue(
            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
            or torch.allclose(fx_f(input, params)[0], f(input, params)[1])
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
            or torch.allclose(fx_f(input, params)[1], f(input, params)[1])
        )

    def test_make_fx_model_double_param(self):
        class Emformer(torch.nn.Module):
            def __init__(
                self,
                input_dim: int = 256,
            ) -> None:
                super().__init__()
                self.layer_norm = torch.nn.LayerNorm(input_dim)

            def forward(mod_self, x):  # noqa: B902
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                y = mod_self.layer_norm(x)
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                z = mod_self.layer_norm(y)
                return z

        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
        self.assertEqual(len(ops), 2)

    def test_make_fx_model_fwd_bwd_wgtupdate(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(args, params, buffers):
            for p in params.values():
                p.grad = None
            if not isinstance(args, Iterable):
                args = [args]
            params_and_buffers = {**params, **buffers}
            out = torch.func.functional_call(model, params_and_buffers, args)
            out.sum().backward()
            return [p - 1e-4 * p.grad for p in params.values()]

        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        buffers = dict(model.named_buffers())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
        # fx may change the order of parameters in list, so using set() to compare
        # also there is a numerical difference in results so changing atol from 1e-08 to 1e-03
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
            or torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
            or torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
        )
    def test_trace_subclasses(self):
        def f1(x):
            x = UnwrapTensor(x)
            y = x * 2
            return y

        def f2(x):
            wrapped = UnwrapTensor(x)
            y = x * wrapped
            return y

        inp = [torch.randn(5)]
        self._test(f1, inp)
        self._test(f2, inp)

    def test_partial_decomp(self):
        def f(a, b, c):
            x = torch.addmm(a, b, c)
            y = torch.addmm(a, b, c, beta=2, alpha=1)
            return x + y

        inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)]
        fx_g = make_fx(f)(*inps)

        def addmm(a, b, c, beta=1, alpha=1):
            if beta == 1 and alpha == 1:
                return NotImplemented
            return beta * a + alpha * (b @ c)

        decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps)

        self.assertEqual(fx_g(*inps), decomposed_fx(*inps))
        self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
        self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)

    def test_decomp_of_capture(self):
        val = torch.randn(5)

        def f(x):
            return x.t() + val.t()

        def nop(x):
            return x.cos()

        traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
        self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_amp_cache(self):
        layer = torch.nn.Conv2d(3, 3, 3).cuda()

        def f(x, w):
            return torch.nn.functional.conv2d(x, w, stride=layer.stride)

        inp = torch.randn(4, 3, 10, 10, device='cuda')
        with torch.autocast('cuda'):
            out_graph = make_fx(f)(inp, layer.weight).graph
            out_graph2 = make_fx(f)(inp, layer.weight).graph

        self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes))
        for a, b in zip(out_graph.nodes, out_graph2.nodes):
            self.assertEqual(a.op, b.op)

    def test_strides(self):
        def f(x):
            self.assertTrue(x.is_contiguous())
            self.assertFalse(x.is_contiguous(memory_format=torch.channels_last))
            x = x.permute(0, 3, 1, 2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
            return x
        make_fx(f)(torch.randn(2, 3, 4, 5))

        def f(x):
            self.assertTrue(x.is_contiguous())
            y = x[:, 1]
            self.assertFalse(y.is_contiguous())
            y = x[:, ::2]
            self.assertFalse(y.is_contiguous())
            return x.cos()

        make_fx(f)(torch.randn(2, 3, 4, 5))

    def test_pr_86917(self):
        # Tests the issue brought up here https://github.com/pytorch/pytorch/pull/86917#issuecomment-1283155344
        def f(a, b):
            return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10)

        self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])


class TestGenericProxyTensorReal(TestGenericProxyTensor):
    tracing_mode = "real"


class TestGenericProxyTensorFake(TestGenericProxyTensor):
    tracing_mode = "fake"


class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
    tracing_mode = "symbolic"


del TestGenericProxyTensor


class TestRealProxyTensor(TestCase):
    def test_error_on_data_dependent_ops(self):
        def f():
            x = torch.randn([])
            y = torch.randn([])
            assert torch.allclose(x * y, y * x)
            z = float(x)
            z2 = float(y)

        # Smoke tests
        make_fx(f, _error_on_data_dependent_ops=False)()
        make_fx(f, pre_dispatch=True, _error_on_data_dependent_ops=False)()
class TestFakeProxyTensor(TestCase):
    def test_issue82547(self):
        x = nn.Parameter(torch.randn(3, 3))

        def f():
            return torch.ops.aten.t.default(x)
        self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")())

        class A(torch.Tensor):
            pass

        x = A(torch.randn(3, 3))
        self.assertRaisesRegex(TypeError, "Multiple dispatch failed", lambda: make_fx(f, tracing_mode="fake")())

    def test_use_fake_and_tensor(self):
        def f(x, y):
            z = torch.tensor([2.0, 3.0])
            return x + y + z

        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
        x, y = torch.randn(2), torch.randn(2)
        self.assertEqual(g(x, y), f(x, y))

    def test_free_fake(self):
        def f(x):
            return torch.add(x, y)

        with FakeTensorMode() as fake_mode:
            y = torch.randn(2)
            make_fx(f, tracing_mode="real")(torch.randn(2))

    def test_fused_adam(self):
        # See https://github.com/pytorch/pytorch/issues/99356
        params = [torch.randn(10, 10) for _ in range(10)]
        grads = [torch.randn(10, 10) for _ in range(10)]
        exp_avgs = [torch.randn(10, 10) for _ in range(10)]
        exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        state_steps = [torch.tensor(0) for _ in range(10)]

        def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
            (new_params, _, _, _, _) = aten._fused_adam.default(
                params,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                lr=0.1,
                beta1=0.9,
                beta2=0.999,
                weight_decay=0.01,
                eps=1e-8,
                amsgrad=False,
                maximize=False,
            )

            for p, new_p in zip(params, new_params):
                p.copy_(new_p)

            return params

        gm = make_fx(fused_adam, tracing_mode='fake')(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
        )
        ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
        for n in gm.graph.nodes:
            if n.op == "call_function" and n.target in ensure_ops_have_val:
                self.assertIn('val', n.meta)

    def test_alias(self):
        def f(x):
            return torch.ops.aten.alias(x)

        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
        # NB: this should not have a detach call
        self.assertExpectedInline(r, """\
def forward(self, x_1):
    alias = torch.ops.aten.alias.default(x_1);  x_1 = None
    return alias""")

    def test_meta(self):
        def f(x):
            a = x.cos()
            b = torch.var_mean(a, dim=0)
            c = b * 2
            return c

        out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5))
        for n in out.graph.nodes:
            if n.op == 'output':
                continue
            self.assertTrue('val' in n.meta)

    def test_fake_tensor_mode(self):
        def f(a):
            d = a.cos()
            return d

        from torch._guards import detect_fake_mode

        existing_fake_mode = FakeTensorMode()
        with existing_fake_mode:
            out = make_fx(f, tracing_mode="real")(torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]))

        fake_mode = detect_fake_mode([node.meta.get('val', None) for node in out.graph.nodes])
        self.assertEqual(fake_mode, existing_fake_mode)


def _get_node(fx_g, cond):
    for n in fx_g.graph.nodes:
        if cond(n):
            return n
    raise AssertionError


def _get_free_symbols(shape_env):
    vars = tuple(shape_env.var_to_val.keys())
    return len([var for var in vars if var not in shape_env.replacements])


def _trace(f, *args):
    inps = [torch.randn(arg) for arg in args]
    return make_fx(f, tracing_mode="symbolic")(*inps)
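

# Reading on the helpers above (illustrative): _trace materializes each shape
# spec into a fresh randn tensor and traces symbolically, e.g. _trace(f, (5, 6))
# traces f on one 5x6 tensor whose sizes become symbols, while _get_free_symbols
# counts symbols the ShapeEnv did not unify away via replacements.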
# TODO: Need to test the guards themselves specifically as well
class TestSymbolicTracing(TestCase):
    def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
        """
        Tests fn traced with trace_inputs against test_inputs.
        Returns the traced GraphModule (callers typically inspect its shape_env).
        """
        trace_inputs = [torch.randn(shape) for shape in trace_inputs]
        traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs)
        for input in test_inputs:
            input = [torch.randn(shape) for shape in input]
            rx, ry = traced_f(*input), fn(*input)
            if assert_eq:
                self.assertEqual(rx, ry)
        return traced_f

    def test_debug_interpreter(self):
        import torch.library
        from torch.library import Library

        foo = Library("foo", "DEF")  # noqa: TOR901
        foo.define("foo(Tensor self) -> Tensor")

        # Operator where meta and cpu disagree on strides
        @torch.library.impl(foo, "foo", "CPU")
        def foo_cpu(x):
            return x.clone().T

        @torch.library.impl(foo, "foo", "Meta")
        def foo_meta(x):
            return x.clone()

        def f(x):
            return torch.ops.foo.foo.default(x)

        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2))
        from torch._functorch.compilers import DebugInterpreter

        interp = DebugInterpreter(gm)

        # input mismatch is caught (indicates guard problem)
        self.assertRaisesRegex(
            AssertionError, r"3 != 1",
            lambda: interp.run(torch.randn(3, 3).T),
        )

        # Catch the incorrect meta
        self.assertRaisesRegex(
            AssertionError, r"\(3, 1\) != \(1, 3\)",
            lambda: interp.run(torch.randn(3, 3))
        )

    def test_int_input(self):
        def f(x, y):
            return x.view(y)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 4), 12).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    view = torch.ops.aten.view.default(x_1, [y_1]);  x_1 = y_1 = None
    return view""")

    def test_resize_from_zero(self):
        def f(x, y):
            x.resize_(y.size(0))

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]);  x_1 = sym_size_int = resize_ = None
    return None""")

    def test_broadcast_shapes(self):
        def f(x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0])

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 1), torch.empty(5)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0);  x_1 = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
    return (sym_size_int, sym_size_int_1)""")

    def test_deduped_shape(self):
        def f(s0, s1, x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])

        x = torch.empty(3, 1)
        y = torch.empty(5)
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        shape_env = ShapeEnv()

        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
            x = fake_mode.from_tensor(x)
            y = fake_mode.from_tensor(y)
            r = str(make_fx(f, tracing_mode="real")(x.shape[0], y.shape[0], x, y).code).strip()
            self.assertExpectedInline(r, """\
def forward(self, s0_1, s1_1, x_1, y_1):
    empty = torch.ops.aten.empty.memory_format([s0_1], device = device(type='cpu'), pin_memory = False)
    return ((s0_1, s1_1), empty)""")

    def test_non_deduped_shape(self):
        def f(x, y):
            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])

        x = torch.empty(3, 1)
        y = torch.empty(5)
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        shape_env = ShapeEnv()

        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
            x = fake_mode.from_tensor(x)
            y = fake_mode.from_tensor(y)
            r = str(make_fx(f, tracing_mode="real")(x, y).code).strip()
            self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0);  x_1 = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
    empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
    return ((sym_size_int, sym_size_int_1), empty)""")
    def test_unary(self):
        def f(x):
            assert x.shape[0] < 20
            return x.cos()

        test_inputs = []
        test_inputs.append([(2, 5)])
        test_inputs.append([(6, 8)])
        gm = self._test_dynamic(f, [(3, 4)], test_inputs)
        self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
        self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s75: 4, s96: 5}")
        self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
        self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""")
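
    # The symbol names checked above (e.g. s75, s96) are an implementation
    # detail of the ShapeEnv's symbol allocator; the substantive checks are
    # that eval_guards accepts shapes satisfying the traced guards and that
    # bind_symbols recovers a symbol -> concrete size mapping for new inputs.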
    def test_repeat_interleave(self):
        def f(src_tokens, beam_size_src):
            return src_tokens.repeat_interleave(beam_size_src.size(0), 0)

        prompt_size = 64
        vocab_size = 64
        batch_size = 4
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens, torch.randn(5))
        self.assertEqual(len(gm.shape_env.guards), 0)

    def test_non_symint_size_spec(self):
        # this isn't really a proxy tensor test, but it's the most convenient
        # way to get a fake tensor with symbolic sizes
        def f(x):
            torch._C._non_sym_sizes(x)
            return x + 1

        x = torch.randn(2, 3)
        make_fx(f, tracing_mode="symbolic")(x)

    # https://github.com/pytorch/pytorch/issues/108195
    def test_symbolic_repeat_interleave(self):
        def f(y, x):
            return y.repeat_interleave(x, dim=1)

        y = torch.tensor([[1, 2], [3, 4]])
        x = torch.tensor([2, 3])
        r = str(make_fx(f, tracing_mode="symbolic")(y, x).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, y_1, x_1):
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1);  x_1 = None
    index_select = torch.ops.aten.index_select.default(y_1, 1, repeat_interleave);  y_1 = repeat_interleave = None
    return index_select""")

    def test_mod_gcd_unbacked(self):
        def f(_a, _b, _stride):
            a = _a.item()
            b = _b.item()
            stride = _stride.item()
            ta = torch.randn(a * stride)
            tb = torch.randn(b * stride)
            r = torch.cat([ta, tb])
            return r.view(a + b, stride)

        _a = torch.tensor(30)
        _b = torch.tensor(20)
        _stride = torch.tensor(10)
        r = str(make_fx(f, tracing_mode="symbolic")(_a, _b, _stride).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, _a_1, _b_1, _stride_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(_a_1);  _a_1 = None
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(_b_1);  _b_1 = None
    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(_stride_1);  _stride_1 = None
    mul = _local_scalar_dense * _local_scalar_dense_2
    randn = torch.ops.aten.randn.default([mul], device = device(type='cpu'), pin_memory = False);  mul = None
    mul_1 = _local_scalar_dense_1 * _local_scalar_dense_2
    randn_1 = torch.ops.aten.randn.default([mul_1], device = device(type='cpu'), pin_memory = False);  mul_1 = None
    cat = torch.ops.aten.cat.default([randn, randn_1]);  randn = randn_1 = None
    add = _local_scalar_dense + _local_scalar_dense_1;  _local_scalar_dense = _local_scalar_dense_1 = None
    view = torch.ops.aten.view.default(cat, [add, _local_scalar_dense_2]);  cat = add = _local_scalar_dense_2 = None
    return view""")

    def test_cumsum_unbacked(self):
        def f(x):
            y = x.item()
            z = torch.randn((3, y, 3))
            return z.cumsum(0)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([5])).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    randn = torch.ops.aten.randn.default([3, _local_scalar_dense, 3], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    cumsum = torch.ops.aten.cumsum.default(randn, 0);  randn = None
    return cumsum"""  # noqa: B950
        )
    def test_repeat_interleave_unbacked_output_size(self):
        def f(x, y):
            s = x.sum().item()
            return y.repeat_interleave(x, dim=0, output_size=s)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([2, 3]), torch.randn(2)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1, y_1):
    sum_1 = torch.ops.aten.sum.default(x_1)
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(sum_1);  sum_1 = None
    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1, output_size = _local_scalar_dense);  x_1 = _local_scalar_dense = None
    index_select = torch.ops.aten.index_select.default(y_1, 0, repeat_interleave);  y_1 = repeat_interleave = None
    return index_select"""  # noqa: B950
        )

    def test_arange_unbacked_output_size(self):
        def f(x):
            return torch.arange(0, x)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    arange = torch.ops.aten.arange.start(0, _local_scalar_dense, device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    return arange"""  # noqa: B950
        )

    def test_adv_index_batch(self):
        def f(src_tokens):
            bsz, src_len = src_tokens.size()[:2]
            start_step = src_tokens.shape[1]
            beam_size = 1
            generate_size = 64
            max_len = src_len + generate_size
            tokens = torch.zeros(bsz * beam_size, max_len).to(src_tokens).long().fill_(0)
            tokens[:, :start_step] = src_tokens.repeat_interleave(beam_size, 0)
            return tokens

        prompt_size = 64
        vocab_size = 64
        batch_size = 4
        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
        gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
        # Guards to rule out batch_size == sys.maxsize (wobbling between 2 and
        # 1 ok)
        self.assertEqual(len(gm.shape_env.guards), 0)

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_cpu_scalar_cuda(self):
        # Extracted from wave2vec2
        def f(a, b):
            return (a * b) @ b

        r = str(
            make_fx(f, tracing_mode="symbolic")(
                torch.tensor(1.0), torch.randn(2, 2, device='cuda')
            ).code
        ).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1, b_1):
    mul = torch.ops.aten.mul.Tensor(a_1, b_1);  a_1 = None
    mm = torch.ops.aten.mm.default(mul, b_1);  mul = b_1 = None
    return mm""")

    def test_binary_broadcast(self):
        def f(a, b):
            c = a * b
            return c

        test_inputs = []
        test_inputs.append([(1, 5), (3, 1)])
        test_inputs.append([(1, 4), (4, 1)])
        shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env
        assert len(shape_env.guards) == 0

    def test_multiply_shape(self):
        def f(a):
            return torch.empty(a.shape[0] * 2)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0);  a_1 = None
    mul = sym_size_int * 2;  sym_size_int = None
    empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False);  mul = None
    return empty""")

    def test_item(self):
        def f(a):
            r = a.item()
            return r * a

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1)
    mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense);  a_1 = _local_scalar_dense = None
    return mul""")

    def test_tensor_symfloat(self):
        def f(a):
            r = torch.tensor(a.size(0) ** 2.0)
            assert r.dtype is torch.float
            return r

        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2))
        r = str(gm.code).strip()
        # NB: this specializes, which is fine, the point is to make sure the
        # dtype inference is correct
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
    return lift_fresh_copy""")
        self.assertEqual(gm._tensor_constant0, torch.tensor(4.0))

    def test_item_to_constructor(self):
        def f(a):
            r = a.item()
            return torch.empty(r)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, a_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1);  a_1 = None
    empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    return empty"""  # noqa: B950
        )

    def test_setitem_symint(self):
        # from moco
        # https://github.com/pytorch/pytorch/issues/101939
        def f(x):
            x[0] = x.size(0)
            return x

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(10)).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  sym_size_int = None
    select = torch.ops.aten.select.int(x_1, 0, 0)
    copy_ = torch.ops.aten.copy_.default(select, scalar_tensor);  select = scalar_tensor = copy_ = None
    return x_1"""  # noqa: B950
        )

    def test_dynamic_pointwise_scalar(self):
        def f(gravity, mask):
            gravity[mask, 0] = gravity[mask, 0] * -1

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 4)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, gravity_1, mask_1):
    select = torch.ops.aten.select.int(gravity_1, 1, 0)
    index = torch.ops.aten.index.Tensor(select, [mask_1]);  select = None
    mul = torch.ops.aten.mul.Tensor(index, -1);  index = None
    select_1 = torch.ops.aten.select.int(gravity_1, 1, 0);  gravity_1 = None
    index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul);  select_1 = mask_1 = mul = index_put_ = None
    return None""")
    def test_reflect_r_over_x(self):
        def reflect_R_over_x(R):
            reflect = torch.eye(3, device=R.device)
            reflect[0, 0] = -1
            return reflect @ R @ reflect

        def f(crop_camera, mask):
            crop_camera[mask] = reflect_R_over_x(crop_camera[mask])

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, crop_camera_1, mask_1):
    index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
    eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
    select = torch.ops.aten.select.int(eye, 0, 0)
    select_1 = torch.ops.aten.select.int(select, 0, 0);  select = None
    copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy);  select_1 = lift_fresh_copy = copy_ = None
    sym_size_int = torch.ops.aten.sym_size.int(index, 0)
    expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3])
    view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]);  expand = None
    sym_size_int_1 = torch.ops.aten.sym_size.int(crop_camera_1, 1)
    sym_size_int_2 = torch.ops.aten.sym_size.int(crop_camera_1, 2)
    expand_1 = torch.ops.aten.expand.default(index, [sym_size_int, sym_size_int_1, sym_size_int_2]);  index = None
    view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  expand_1 = sym_size_int_1 = sym_size_int_2 = None
    bmm = torch.ops.aten.bmm.default(view, view_1);  view = view_1 = None
    view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]);  bmm = None
    mul_9 = sym_size_int * 3
    view_3 = torch.ops.aten.view.default(view_2, [mul_9, 3]);  view_2 = mul_9 = None
    mm = torch.ops.aten.mm.default(view_3, eye);  view_3 = eye = None
    _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]);  mm = sym_size_int = None
    index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view);  crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None
    return None""")  # noqa: B950

    def test_unbacked_slice(self):
        def f(x, m):
            x = x[m]
            return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]

        make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        )

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_unbacked_batch_resnet(self):
        mod = torchvision.models.resnet18()

        def f(x, mask, params, buffers):
            for p in itertools.chain([x, mask], params.values(), buffers.values()):
                for s in p.shape:
                    guard_int(s)
            x = x[mask]
            torch._check(x.shape[0] >= 1)
            for p in params.values():
                p.grad = None
            return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()

        make_fx(f, tracing_mode="symbolic")(
            torch.randn(3, 3, 250, 250),
            torch.randint(0, 2, (3,), dtype=torch.bool),
            dict(mod.named_parameters()),
            dict(mod.named_buffers()),
        )

    def test_boolean_index(self):
        def f(images, handedness, valid):
            images = images[valid]
            handedness = handedness[valid]
            right_hand_mask = handedness == 1
            images[right_hand_mask] = images[right_hand_mask].flip(-1)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randint(0, 256, (512, 1, 96, 96)),
            torch.randint(0, 1, (512,)),
            torch.randint(0, 2, (512,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, images_1, handedness_1, valid_1):
    index = torch.ops.aten.index.Tensor(images_1, [valid_1]);  images_1 = None
    index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]);  handedness_1 = valid_1 = None
    eq = torch.ops.aten.eq.Scalar(index_1, 1);  index_1 = None
    index_2 = torch.ops.aten.index.Tensor(index, [eq])
    flip = torch.ops.aten.flip.default(index_2, [-1]);  index_2 = None
    index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip);  index = eq = flip = index_put_ = None
    return None""")

    def test_neg_shape(self):
        def f(a):
            return torch.empty(-a.shape[0] + 10)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0);  a_1 = None
    neg = -sym_size_int;  sym_size_int = None
    add = neg + 10;  neg = None
    empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False);  add = None
    return empty""")
    def test_unbacked_unification(self):
        def f(x, y):
            z = torch.zeros(x.item())
            return z + y

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    add = torch.ops.aten.add.Tensor(zeros, y_1);  zeros = y_1 = None
    return add""")  # noqa: B950

    def test_reshape_divisibility_unbacked(self):
        def f(x):
            i0 = x.item()
            r = torch.zeros(i0, 4, 20)
            r = r.transpose(2, 1)
            return r.reshape(-1, 80)
        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))

    def test_view_divisibility_unbacked(self):
        def f(x):
            i0 = x.item()
            r = torch.zeros(i0, 192)
            return r.view(12, -1, 192)
        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_view_divisibility_unbacked_relatively_prime(self):
        # See https://github.com/pytorch/pytorch/issues/123651
        def f(x):
            i0 = x.item()
            # To trigger the original issue, the max bound has to
            # be chosen such that 448 / 447 < 2 (which it is.)
            torch._check(i0 > 0)
            torch._check(i0 <= 448)
            return torch.zeros(256 * i0).view(-1, 447)
        make_fx(f, tracing_mode="symbolic")(torch.tensor(256 * 447, device="cuda"))

    def test_unbacked_unify_guard(self):
        def f(x, y):
            z = torch.zeros(x.item())
            torch._check(z.size(0) == y.size(0))  # refines i0 = s0
            if z.size(0) == 4:
                return y * 2
            else:
                return y + 2

        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = zeros = None
    add = torch.ops.aten.add.Tensor(y_1, 2);  y_1 = None
    return add""")  # noqa: B950

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    @unittest.expectedFailure
    def test_unbacked_unify_guard_transitivity(self):
        def f(x1, x2, y):
            z1 = torch.zeros(x1.item())
            z2 = torch.zeros(x2.item())
            torch._check(z1.size(0) == z2.size(0))  # refines i0 = i1
            torch._check(z2.size(0) == y.size(0))  # refines i0 = s0
            if z1.size(0) == 4:
                return y * 2
            else:
                return y + 2

        gm = make_fx(f, tracing_mode="symbolic")(
            torch.tensor(10, device="cuda"),
            torch.tensor(10, device="cuda"),
            torch.randn(10, device="cuda")
        )
        insert_deferred_runtime_asserts(gm, gm.shape_env, "test")
        gm.recompile()
        r = str(gm.code).strip()
        # self.assertExpectedInline(
        #     r, """"""  # noqa: B950
        # )

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_unbacked_unify_dependency_violation(self):
        def f(x1, x2, x3, y):
            z1 = x1.item()
            torch._check(z1 // 9 == 1)
            z2 = x2.item()
            z3 = x3.item()
            torch._check(z1 == z2 + z3)
            return y * 2

        # NB: inputs are done as CUDA to ensure they aren't queried to be
        # backed
        gm = make_fx(f, tracing_mode="symbolic")(
            torch.tensor(10, device="cuda"), torch.tensor(5, device="cuda"),
            torch.tensor(5, device="cuda"), torch.randn(1, device="cuda")
        )
        insert_deferred_runtime_asserts(gm, gm.shape_env, "test")
        gm.recompile()
        self.assertEqual(gm(
            torch.tensor(12, device="cuda"), torch.tensor(6, device="cuda"),
            torch.tensor(6, device="cuda"), torch.tensor([1.0], device="cuda")),
            torch.tensor([2.0], device="cuda")
        )
        with self.assertRaises(RuntimeError):
            gm(
                torch.tensor(20, device="cuda"), torch.tensor(10, device="cuda"),
                torch.tensor(10, device="cuda"), torch.tensor([1.0], device="cuda")
            )
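
    # Context for the unbacked-symbol tests above and below: sizes produced by
    # .item() are "unbacked" SymInts with no concrete hint, so they cannot be
    # guarded on directly. torch._check(...) records runtime asserts that can
    # also refine or unify them (e.g. equate an unbacked symbol with a backed
    # input size), which is what these tests probe.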
    def test_split_unbacked_sizes(self):
        def f(lengths, values):
            # tolist not directly supported atm
            sizes = [lengths[i].item() for i in range(lengths.size(0))]
            return torch.split(values, sizes)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.tensor([2, 3, 4]),
            torch.randn(9)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, lengths_1, values_1):
    select = torch.ops.aten.select.int(lengths_1, 0, 0)
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select);  select = None
    select_1 = torch.ops.aten.select.int(lengths_1, 0, 1)
    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1);  select_1 = None
    select_2 = torch.ops.aten.select.int(lengths_1, 0, 2);  lengths_1 = None
    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2);  select_2 = None
    split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]);  values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
    getitem = split_with_sizes[0]
    getitem_1 = split_with_sizes[1]
    getitem_2 = split_with_sizes[2];  split_with_sizes = None
    return (getitem, getitem_1, getitem_2)""")  # noqa: B950

    def test_invalidate_nonzero(self):
        ok = False

        def f(a):
            nonlocal ok
            b = a.clone()
            x = b.nonzero()
            x1 = b.nonzero()
            x2 = b.nonzero()
            assert x1.shape[0] == x2.shape[0]
            ok = True
            b.normal_()
            y = b.nonzero()
            try:
                bool(x1.shape[0] == y.shape[0])
                self.fail("didn't raise exception")
            except GuardOnDataDependentSymNode:
                pass

        make_fx(f, tracing_mode="symbolic")(torch.randn(4))

    @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
    def test_invalidate_nonzero_propagate_real_tensors(self):
        def f(a):
            b = a.clone()
            x = b.nonzero()
            x1 = b.nonzero()
            x2 = b.nonzero()
            assert x1.shape[0] == x2.shape[0]
            b.normal_()
            y = b.nonzero()
            # Because you're not actually going to generate exactly zero with
            # normal_ lol
            assert x1.shape[0] == y.shape[0]

        make_fx(f, tracing_mode="symbolic")(torch.randn(4))

    def test_sqrt_size(self):
        def f(a):
            return a / a.size(-1) ** 0.5

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    sym_float = torch.sym_float(sym_size_int);  sym_size_int = None
    pow_1 = sym_float ** 0.5;  sym_float = None
    div = torch.ops.aten.div.Tensor(a_1, pow_1);  a_1 = pow_1 = None
    return div""")

    def test_make_fx_with_custom_tracer_preserving_nn_module_stack(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + 1

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = Bar()

            def forward(self, x):
                return x + self.bar(x)

        gm = make_fx(Foo())(torch.randn(4, 4))
        for node in gm.graph.nodes:
            self.assertTrue("nn_module_stack" not in node.meta)

        foo = Foo()

        def functional_call(*args, **kwargs):
            with stateless._reparametrize_module(foo, {}):
                return foo(*args, **kwargs)

        functional_call._orig_mod = foo

        gm_with_stack = make_fx(functional_call, record_module_stack=True)(torch.randn(4, 4))
        found = False
        for node in gm_with_stack.graph.nodes:
            if "nn_module_stack" in node.meta:
                if len(node.meta["nn_module_stack"]) == 1:
                    self.assertTrue("custom_tracer_preserving_nn_module_stack.<locals>.Foo" in str(node.meta["nn_module_stack"]))
                    found = True
                elif len(node.meta["nn_module_stack"]) == 2:
                    self.assertTrue("preserving_nn_module_stack.<locals>.Bar" in str(node.meta["nn_module_stack"]))
                    found = True
                else:
                    # there can be at most 2 level
                    self.assertTrue(False)
        self.assertTrue(found)

        gm_without_stack = make_fx(functional_call)(torch.randn(4, 4))
        for node in gm_without_stack.graph.nodes:
            self.assertTrue("nn_module_stack" not in node.meta)
    def test_symint_to_tensor(self):
        def f(a):
            return a / a.shape[0]

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    div = torch.ops.aten.div.Tensor(a_1, sym_size_int);  a_1 = sym_size_int = None
    return div""")

        r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
    sym_float = torch.sym_float(sym_size_int);  sym_size_int = None
    div = torch.ops.prims.div.default(a_1, sym_float);  a_1 = sym_float = None
    return div""")

    def test_cat(self):
        def f(a, b):
            val = torch.mul(a, b)
            out = torch.cat([val, val])
            if out.shape[0] * out.shape[1] > 20:
                out = out.cos()
            return out

        test_inputs = []
        test_inputs.append([(1, 5), (6, 1)])
        test_inputs.append([(1, 4), (3, 1)])
        gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
        self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
        self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
        self.assertExpectedInline(show_guards(gm), """2*L['b'].size()[0]*L['a'].size()[1] > 20""")

    def test_new_empty(self):
        def f(a, b):
            return a.new_empty(b.shape[0], b.shape[1] * 2)

        self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False).shape_env

    def test_size_with_tensor(self):
        # I think I messed up writing this test case originally, I think
        # I'm supposed to hit an error case, but the code here works in both
        # eager and tracing
        def f(tensor):
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
            batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
            return tensor.new_empty(batch_shape)

        a = torch.randn(3, 800, 1199)
        f(a)
        make_fx(f, tracing_mode="symbolic")(a)

    def test_fake_tensor_as_size(self):
        def f(x):
            r = torch.zeros([x])
            return r

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.tensor(4))
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, x_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    return zeros""")  # noqa: B950

    def test_expand(self):
        def f(a):
            b = torch.mul(a, a)
            c = b.expand(a.shape)
            return c

        self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
        self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])

    def test_metadata(self):
        def f(a, b):
            d = a.new_empty(a.shape[0] + b.shape[0])
            return d

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
        meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
        meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
        self.assertTrue(meta_c.meta['val'].shape[0].node.expr == meta_d.meta['val'].node.expr)

    def test_metadata_fresh(self):
        def f(x):
            assert x.shape[0] == 3
            return x.cos()

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
        self.assertTrue(meta_cos.meta['val'].shape[0] == 3)
        # Checks if the input expr has been updated even though the constraint
        # happened afterwards
        self.assertTrue(meta_inp.meta['val'].shape[0] == 3)
    def test_elementwise_meta_with_sym_numbers(self):
        def f(x, offset, as_sym_float=False):
            x0 = x.size()[0]
            if as_sym_float:
                x0 = torch.sym_float(x0)
            return torch.add(x0, offset)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.int64)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

    def test_return_symint(self):
        def f(x):
            return x.shape[0], x.cos(), x.shape[0] / 5
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

        def f(x):
            return x.shape
        self._test_dynamic(f, [(5, 3)], [[(4, 6)]])

    def test_rmethod(self):
        def f(x):
            return x.size(0) + x
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

    def test_mega_guard(self):
        def f(a, b):
            assert a.shape[0] == b.shape[0] * 2
            return a.cos()

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
        from torch._dynamo.source import LocalSource
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)),  # noqa: B950
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]"""  # noqa: B950
        )
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)),  # noqa: B950
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]"""  # noqa: B950
        )

    def test_guard_upperbound_range_refinement(self):
        def f(a):
            assert a.shape[0] > 5 and a.shape[0] > 12
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
        self.assertExpectedInline(show_guards(tensor), """13 <= L['a'].size()[0]""")

    def test_guard_lowerbound_range_refinement(self):
        def f(a):
            assert a.shape[0] < 20 and a.shape[0] < 30
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
        self.assertExpectedInline(show_guards(tensor), """L['a'].size()[0] <= 19""")

    def test_guard_upperbound_range_refinement_multivariate(self):
        def f(a):
            assert a.shape[0] > 5 and a.shape[0] > 12
            assert a.shape[1] > 5 and a.shape[1] > a.shape[0]
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 20)))
        self.assertExpectedInline(show_guards(tensor), """\
L['a'].size()[1] > L['a'].size()[0]
13 <= L['a'].size()[0]
14 <= L['a'].size()[1]""")

    def test_guard_lowerbound_range_refinement_multivariate(self):
        def f(a):
            assert a.shape[0] < 20 and a.shape[0] < 30
            assert a.shape[1] < 30 and a.shape[1] < a.shape[0]
            return a.cos()
        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 5)))
        self.assertExpectedInline(
            show_guards(tensor),
            """\
L['a'].size()[1] < L['a'].size()[0]
3 <= L['a'].size()[0] and L['a'].size()[0] <= 19
L['a'].size()[1] <= 18""")

    def test_sym_storage_offset(self):
        def f(x, y):
            return x + y

        inp = (torch.randn(8)[3:], torch.randn(5))
        fx_g = make_fx(f, tracing_mode="symbolic")(*inp)
        inp = (torch.randn(8)[3:], torch.randn(5))
        self.assertEqual(fx_g(*inp), f(*inp))
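
    # Background for the guard-counting tests below: with duck shaping, input
    # dimensions that share the same concrete size at trace time are assigned
    # the same symbol, so these tests pick sizes carefully when they want
    # genuinely independent symbols. _assert_no_guards then checks both the
    # surviving free-symbol count and the absence of nontrivial guards.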
    def _assert_no_guards(self, fx_g, free_symbols):
        assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
        assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()

    def test_guards_equal(self):
        def f(a, b):
            return a * b

        # NB: Numbers are carefully chosen to avoid duck shaping from applying
        fx_g = _trace(f, (5, 6), (5, 6))
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (5, 6, 7), (5, 6, 7))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (5, 1), (1, 6))
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d):
            a = a + b
            cat = torch.cat([c, d])
            return a + cat

        fx_g = _trace(f, 7, 7, 4, 3)
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            x = a
            for idx in range(len(vals) - 1):
                x = torch.cat([x, vals[idx]]) + vals[idx + 1]
            return x

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self._assert_no_guards(fx_g, 1)

        def f(a, b):
            a = a.view(b.shape[0])
            return a + b.sum()

        fx_g = _trace(f, (4, 2), 8)
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (4, 2), (8, 5))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (2, 3, 4), 24)
        self._assert_no_guards(fx_g, 3)

    def test_nonidentity_transitive_guards(self):
        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            cat_vals = []
            for idx in range(len(vals) - 1):
                cat_vals.append(torch.cat([vals[idx], vals[idx]]))
            final_vals = []
            for a, b in reversed(list(zip(cat_vals, vals[1:]))):
                final_vals.append(a + b)
            return final_vals

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self.assertExpectedInline(show_guards(fx_g), """""")

    @torch.fx.experimental._config.patch(translation_validation=True)
    def test_constant_specialization(self):
        def f(t):
            assert t.shape[0] == 10
            return t

        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(10))
        self.assertExpectedInline(show_guards(tensor), """""")


make_fx_failures = {
    # unknown
    xfail('allclose'),
    xfail('equal'),

    # empty
    skip('new_empty'),
    skip('empty_like'),
    skip('empty'),
    skip('empty_permuted'),

    # flaky
    skip('linalg.lstsq', 'grad_oriented'),
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),
    skip('linalg.lstsq'),  # flaky, probably just a precision issue

    # data-dependent control flow
    skip('item'),
    xfail('cov'),
    xfail('nn.functional.gaussian_nll_loss'),
    xfail('corrcoef'),
    xfail('quantile'),
    xfail('nanquantile'),

    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
    xfail('sparse.sampled_addmm'),
    xfail('sparse.mm', 'reduce'),

    # proxy tensor doesn't support sparse correctly right now
    skip('to_sparse'),

    # segfaults
    skip('block_diag'),

    # AssertionError: Tensor-likes are not close!
    skip('empty_strided', '', device_type='cpu'),
}

only_real_tensor_failures = {
    xfail('narrow'),
    xfail('tensor_split'),
}

only_fake_tensor_failures = {
    xfail('narrow'),
    xfail('tensor_split'),
}

fake_tensor_failures = set()

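# Convention for the failure tables below (these helpers come from
# common_methods_invocations): xfail(op, variant) runs the test and expects
# it to fail, while skip(op, variant) does not run the test at all (used for
# flaky tests and segfaults).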
symbolic_tensor_failures = {
    xfail('combinations', ''),
    xfail('geqrf', ''),  # aten.geqrf.default - couldn't find symbolic meta function/decomposition
    xfail('histogram', ''),  # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
    xfail('histogramdd', ''),  # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
    xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
    xfail('nn.functional.binary_cross_entropy', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decom...
    xfail('nn.functional.cross_entropy', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.ctc_loss'),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
    xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
    xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...
}

symbolic_tensor_segfaults = {
    skip('nn.functional.batch_norm')  # Segfault??
}

symbolic_tensor_failures.update(symbolic_tensor_segfaults)

inplace_symbolic_tensor_failures = {
    # bugs
    xfail('float_power', ''),  # base given to float_power_ has dtype Float but the operation's result requires dtype Double
}

out_symbolic_tensor_failures = {
    # Cast error details: Unable to cast (...) to Tensor
    #
    # This happens because the test is set up to call the out variant using the `out` kwarg:
    #   torch._some_op(arg1, arg2, out=(out1, out2, out3))
    #
    # However, this only works on torch ops, not aten ops. For `_batch_norm_with_update`,
    # this fails because the op has no python bindings, so it doesn't support the `out` kwarg
    # way of calling its out variant.
    xfail('_batch_norm_with_update', ''),
    xfail('_native_batch_norm_legit', ''),
    xfail('angle', ''),
    xfail('argmax', ''),
    xfail('argmin', ''),
    xfail('gather', ''),
    xfail('linalg.pinv', ''),
    xfail('linalg.pinv', 'hermitian'),
    xfail('scatter_add', ''),
    xfail('scatter', ''),
    xfail('take_along_dim', ''),

    # SymIntArrayRef expected to contain only concrete
    xfail('randn', ''),

    # RuntimeError: Cannot call numel() on tensor with symbolic sizes/strides
    xfail('index_reduce', 'prod'),
    xfail('index_reduce', 'mean'),
    xfail('index_reduce', 'amax'),
    xfail('index_reduce', 'amin'),
}

out_symbolic_tensor_segfaults = {
    skip('nanmean', ''),
}

out_symbolic_tensor_failures.update(out_symbolic_tensor_segfaults)


# Copies inputs to inplace operations to avoid inplace modifications
# to leaves requiring gradient
def _get_safe_inplace(inplace_variant):
    @functools.wraps(inplace_variant)
    def _fn(t, *args, **kwargs):
        return inplace_variant(t.clone(), *args, **kwargs)

    return _fn


def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False, out=False):
    fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op
    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)

    # Limit ourselves to first 100 inputs so symbolic tracing tests don't take too long
    count = 100
    if out:
        count = 5
    for sample_input in itertools.islice(sample_inputs_itr, count):
        if inplace and sample_input.broadcasts_input:
            continue
        args = [sample_input.input] + list(sample_input.args)
        kwargs = sample_input.kwargs
        if out:
            expected = fn(*args, **kwargs)
            kwargs['out'] = expected

        try:
            optests.make_fx_check(fn, args, kwargs, tracing_mode, self.assertEqual,
                                  randomize_data=True)
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")


def skipIfNameMatches(pattern):
    """
    Decorator to skip a test if its name matches the given pattern.
    """
    def decorator(test_func):
        def wrapper(*args, **kwargs):
            if re.match(pattern, test_func.__name__):
                raise unittest.SkipTest(f"Test '{test_func.__name__}' skipped because its name matches the pattern '{pattern}'")
            return test_func(*args, **kwargs)
        return wrapper
    return decorator

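# Example usage of skipIfNameMatches (illustrative only; the decorator is not
# applied anywhere in this file):
#
#   @skipIfNameMatches(r"test_flaky_.*")
#   def test_flaky_thing(self):
#       ...
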
""" def decorator(test_func): def wrapper(*args, **kwargs): if re.match(pattern, test_func.__name__): raise unittest.SkipTest(f"Test '{test_func.__name__}' skipped because its name matches the pattern '{pattern}'") return test_func(*args, **kwargs) return wrapper return decorator # Auto functionalize shouldn't work with make_fx directly filtered_hop_db = [op for op in hop_db if op.name != "auto_functionalize"] @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "Cond requires dynamo") class TestProxyTensorOpInfo(TestCase): @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures.union(only_real_tensor_failures)) def test_make_fx_exhaustive(self, device, dtype, op): _test_make_fx_helper(self, device, dtype, op, "real") @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures, only_fake_tensor_failures)) def test_make_fx_fake_exhaustive(self, device, dtype, op): _test_make_fx_helper(self, device, dtype, op, "fake") @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', make_fx_failures | fake_tensor_failures | symbolic_tensor_failures) def test_make_fx_symbolic_exhaustive(self, device, dtype, op): _test_make_fx_helper(self, device, dtype, op, "symbolic") @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace', make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures) def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op): if not op.get_inplace(): self.skipTest("No inplace variable for this op") _test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True) @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_out', make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | out_symbolic_tensor_failures) def test_make_fx_symbolic_exhaustive_out(self, device, dtype, op): if not op.supports_out: self.skipTest("Op doesn't support out") _test_make_fx_helper(self, device, dtype, op, "symbolic", out=True) only_for = ("cpu") instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for) if __name__ == '__main__': run_tests()