mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix torch.isin
decomposition for scalar inputs (#153216)
This patch fixes a corner case of `torch.isin` decompisition when both inputs are scalars. This pattern showed up from #141196. Fixes #141196. Error stack befor this patch: ``` File "/home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py", line 12503, in test_scalar_isin_decomposition res = opt_f() ^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 691, in _fn raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_dynamo/output_graph.py", line 1618, in _call_user_compiler raise BackendCompilerFailed( File "/home/ryanguo99/repos/pytorch/torch/_dynamo/output_graph.py", line 1593, in _call_user_compiler compiled_fn = compiler_fn(gm, self.example_inputs()) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__ compiled_gm = compiler_fn(gm, example_inputs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/__init__.py", line 2365, in __call__ return compile_fx(model_, inputs_, config_patches=self.config) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_inductor/compile_fx.py", line 2317, in compile_fx return aot_autograd( ^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_dynamo/backends/common.py", line 106, in __call__ cg = aot_module_simplified(gm, example_inputs, **self.kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 1179, in aot_module_simplified compiled_fn = AOTAutogradCache.load( ^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 923, in load compiled_fn = dispatch_and_compile() ^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 1164, in dispatch_and_compile compiled_fn, _ = create_aot_dispatcher_function( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function return _create_aot_dispatcher_function( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 826, in _create_aot_dispatcher_function compiled_fn, fw_metadata = compiler_fn( ^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 180, in aot_dispatch_base fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( # type: ignore[misc] ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 2199, in _trace_inner t = dispatch_trace( ^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_compile.py", line 51, in inner return disable_fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 872, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1223, in dispatch_trace graph = tracer.trace(root, concrete_args) # type: ignore[arg-type] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 872, in _fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/_symbolic_trace.py", line 850, in trace (self.create_arg(fn(*args)),), ^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1278, in wrapped out = f(*tensors) # type:ignore[call-arg] ^^^^^^^^^^^ File "<string>", line 1, in <lambda> File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 720, in inner_fn outs = fn(*args) ^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 419, in _functionalized_f_helper f_outs = fn(*f_args) ^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 81, in inner_fn outs = fn(*args) ^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 902, in functional_call out = PropagateUnbackedSymInts(mod).run( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/interpreter.py", line 171, in run self.env[node] = self.run_node(node) ^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/symbolic_shapes.py", line 7387, in run_node result = super().run_node(n) ^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/interpreter.py", line 240, in run_node return getattr(self, n.op)(n.target, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/interpreter.py", line 320, in call_function return target(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1326, in __torch_function__ return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_subclasses/functional_tensor.py", line 511, in __torch_dispatch__ outs_unwrapped = func._op_dk( ^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/utils/_stats.py", line 27, in wrapper return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1428, in __torch_dispatch__ return proxy_call(self, func, self.pre_dispatch, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 797, in proxy_call r = maybe_handle_decomp(proxy_mode, func, args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 2358, in maybe_handle_decomp out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_prims_common/wrappers.py", line 309, in _fn result = fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_decomp/decompositions.py", line 5108, in isin return isin_default(elements, test_elements, invert=invert) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ryanguo99/repos/pytorch/torch/_decomp/decompositions.py", line 5137, in isin_default x = elements.view(*elements.shape, *((1,) * test_elements.ndim)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: TypeError: view() received an invalid combination of arguments - got (), but expected one of: * (torch.dtype dtype) * (tuple of ints size) While executing %isin : [num_users=1] = call_function[target=torch.isin](args = (%x, %x), kwargs = {}) GraphModule: class GraphModule(torch.nn.Module): def forward(self): # File: /home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py:12498 in f, code: x = torch.tensor(0) x: "i64[][]" = torch.tensor(0) # File: /home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py:12499 in f, code: return torch.isin(x, x) isin: "b8[][]" = torch.isin(x, x); x = None return (isin,) Original traceback: File "/home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py", line 12499, in f return torch.isin(x, x) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153216 Approved by: https://github.com/williamwen42, https://github.com/peterbell10
This commit is contained in:
committed by
PyTorch MergeBot
parent
180cbf46f2
commit
3976e52264
@ -12502,6 +12502,16 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
|
||||
|
||||
f(torch.tensor([30, 30], device=device), torch.tensor([68, 32], device=device))
|
||||
|
||||
def test_scalar_isin_decomposition(self):
|
||||
def f():
|
||||
x = torch.tensor(0)
|
||||
return torch.isin(x, x)
|
||||
|
||||
opt_f = torch.compile(f, backend="inductor", fullgraph=True)
|
||||
ref = f()
|
||||
res = opt_f()
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices)
|
||||
|
@ -5134,7 +5134,8 @@ def bernoulli(
|
||||
def isin_default(elements, test_elements, *, invert=False):
|
||||
if elements.numel() == 0:
|
||||
return torch.empty_like(elements, dtype=torch.bool)
|
||||
x = elements.view(*elements.shape, *((1,) * test_elements.ndim))
|
||||
expanded_elem_shape = elements.shape + (1,) * test_elements.ndim
|
||||
x = elements.view(expanded_elem_shape)
|
||||
dim = tuple(range(-1, -test_elements.ndim - 1, -1))
|
||||
res = (x == test_elements).any(dim=dim)
|
||||
return ~res if invert else res
|
||||
|
Reference in New Issue
Block a user