Fix torch.isin decomposition for scalar inputs (#153216)

This patch fixes a corner case of `torch.isin` decompisition when both
inputs are scalars. This pattern showed up from #141196.

Fixes #141196.

Error stack befor this patch:
```
  File "/home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py", line 12503, in test_scalar_isin_decomposition
    res = opt_f()
          ^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 691, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/output_graph.py", line 1618, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/output_graph.py", line 1593, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/__init__.py", line 2365, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_inductor/compile_fx.py", line 2317, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/backends/common.py", line 106, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 1179, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 923, in load
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 1164, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 826, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 180, in aot_dispatch_base
    fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph(  # type: ignore[misc]

           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 2199, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 872, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1223, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 872, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1278, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 720, in inner_fn
    outs = fn(*args)
           ^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 419, in _functionalized_f_helper
    f_outs = fn(*f_args)
             ^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 81, in inner_fn
    outs = fn(*args)
           ^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 902, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/symbolic_shapes.py", line 7387, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/interpreter.py", line 240, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/interpreter.py", line 320, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1326, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_subclasses/functional_tensor.py", line 511, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
                     ^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1428, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 797, in proxy_call
    r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 2358, in maybe_handle_decomp
    out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_prims_common/wrappers.py", line 309, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_decomp/decompositions.py", line 5108, in isin
    return isin_default(elements, test_elements, invert=invert)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ryanguo99/repos/pytorch/torch/_decomp/decompositions.py", line 5137, in isin_default
    x = elements.view(*elements.shape, *((1,) * test_elements.ndim))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: view() received an invalid combination of arguments - got (), but expected one of:
 * (torch.dtype dtype)
 * (tuple of ints size)

While executing %isin : [num_users=1] = call_function[target=torch.isin](args = (%x, %x), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self):
         # File: /home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py:12498 in f, code: x = torch.tensor(0)
        x: "i64[][]" = torch.tensor(0)

         # File: /home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py:12499 in f, code: return torch.isin(x, x)
        isin: "b8[][]" = torch.isin(x, x);  x = None
        return (isin,)

Original traceback:
  File "/home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py", line 12499, in f
    return torch.isin(x, x)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153216
Approved by: https://github.com/williamwen42, https://github.com/peterbell10
This commit is contained in:
Ryan Guo
2025-05-08 16:47:03 -07:00
committed by PyTorch MergeBot
parent 180cbf46f2
commit 3976e52264
2 changed files with 12 additions and 1 deletions

View File

@ -12502,6 +12502,16 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
f(torch.tensor([30, 30], device=device), torch.tensor([68, 32], device=device))
def test_scalar_isin_decomposition(self):
def f():
x = torch.tensor(0)
return torch.isin(x, x)
opt_f = torch.compile(f, backend="inductor", fullgraph=True)
ref = f()
res = opt_f()
self.assertEqual(ref, res)
devices = ("cuda", "hpu")
instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices)

View File

@ -5134,7 +5134,8 @@ def bernoulli(
def isin_default(elements, test_elements, *, invert=False):
if elements.numel() == 0:
return torch.empty_like(elements, dtype=torch.bool)
x = elements.view(*elements.shape, *((1,) * test_elements.ndim))
expanded_elem_shape = elements.shape + (1,) * test_elements.ndim
x = elements.view(expanded_elem_shape)
dim = tuple(range(-1, -test_elements.ndim - 1, -1))
res = (x == test_elements).any(dim=dim)
return ~res if invert else res