diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py
index aa2d38067f54..5e9f70345ae3 100644
--- a/test/distributed/test_c10d_functional_native.py
+++ b/test/distributed/test_c10d_functional_native.py
@@ -1121,10 +1121,6 @@ class CompileTest(TestCase):
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @fresh_cache()
     def test_inductor_all_to_all_single(self):
-        def _tolist(tensor):
-            lst = tensor.tolist()
-            return lst
-
         def func(
             input: torch.Tensor,
             output_split_sizes: torch.Tensor,
@@ -1132,8 +1128,8 @@ class CompileTest(TestCase):
         ) -> torch.Tensor:
             output = funcol.all_to_all_single(
                 input,
-                _tolist(output_split_sizes),
-                _tolist(input_split_sizes),
+                output_split_sizes.tolist(),
+                input_split_sizes.tolist(),
                 "0",
             )
             return funcol.wait_tensor(output)
diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py
index 12f201c6010b..94ff7a05df74 100644
--- a/test/distributed/test_inductor_collectives.py
+++ b/test/distributed/test_inductor_collectives.py
@@ -53,11 +53,6 @@ from torch.testing._internal.inductor_utils import HAS_GPU
 from torch.utils._python_dispatch import TorchDispatchMode
 
 
-def _tolist(tensor):
-    lst = tensor.tolist()
-    return lst
-
-
 @requires_accelerator_dist_backend(["nccl", "xccl"])
 @instantiate_parametrized_tests
 class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
@@ -535,8 +530,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
             ranks,
             group_size,
         ):
-            input_split_sizes = _tolist(input_split_sizes_tensor)
-            output_split_sizes = _tolist(output_split_sizes_tensor)
+            input_split_sizes = input_split_sizes_tensor.tolist()
+            output_split_sizes = output_split_sizes_tensor.tolist()
             a2a = torch.ops.c10d_functional.all_to_all_single(
                 inp,
                 output_split_sizes,
@@ -696,8 +691,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
             ranks,
             group_size,
         ):
-            input_split_sizes = _tolist(input_split_sizes_tensor)
-            output_split_sizes = _tolist(output_split_sizes_tensor)
+            input_split_sizes = input_split_sizes_tensor.tolist()
+            output_split_sizes = output_split_sizes_tensor.tolist()
             a2a = torch.ops.custom_ns.alltoall_autograd.default(
                 inp,
                 output_split_sizes,
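The two test files above replace the local `_tolist` wrapper with a direct `.tolist()` call. Eagerly this produces plain Python ints; under `torch.compile`, each list element becomes an unbacked SymInt that flows into the collective as a split size. A minimal single-process sketch of the same pattern, using `torch.split` in place of a collective (the `split_by_sizes` helper and the shapes are illustrative, not part of the patch):

    import torch

    def split_by_sizes(values: torch.Tensor, sizes: torch.Tensor):
        # .tolist() yields plain ints eagerly; under torch.compile each
        # element is captured as an unbacked SymInt used as a split length.
        return torch.split(values, sizes.tolist())

    chunks = split_by_sizes(torch.arange(10), torch.tensor([4, 6]))
    print([tuple(c.shape) for c in chunks])  # [(4,), (6,)]
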
"torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True - ).run(ep.graph_module.code) decompose_ep = ep.run_decompositions() FileCheck().check_count( "torch.ops.aten._assert_scalar.default", 2, exactly=True ).run(ep.graph_module.code) - FileCheck().check_count( - "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True - ).run(ep.graph_module.code) def test_mixed_input(self): class Module(torch.nn.Module): @@ -14360,7 +14352,6 @@ def forward(self, x, y): sin = torch.ops.aten.sin.default(y) sum_1 = torch.ops.aten.sum.dim_IntList(sin, []); sin = None _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x); x = None - sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense); sym_constrain_range_for_size_default = None ge_1 = _local_scalar_dense >= 3 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 3 on node 'ge_1'"); ge_1 = _assert_scalar_default = None le_1 = _local_scalar_dense <= 5 @@ -16781,15 +16772,7 @@ class TestOneOffModelExportResult(TestCase): with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]): ep = torch.export.export(ScaledDotProductAttention(), (q, k, v)) - print(ep.graph) ep.run_decompositions() - print(ep.graph) - - # self.assertExpectedInline(ep.graph_module.code.strip(), """\ - # def forward(self, arg0_1, arg1_1, arg2_1): - # _scaled_dot_product_flash_attention_for_cpu = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default(arg0_1, arg1_1, arg2_1, 0.0, True); arg0_1 = arg1_1 = arg2_1 = None - # getitem = _scaled_dot_product_flash_attention_for_cpu[0]; _scaled_dot_product_flash_attention_for_cpu = None - # return (getitem,)""") @skipIfCrossRef @unittest.skipIf( diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 1cedac3bdf38..e3c551213277 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -14489,7 +14489,6 @@ if RUN_GPU: return a[y.to(torch.int64)] def fn2(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - torch._check_is_size(b.shape[0]) torch._check(b.shape[0] >= 2) torch._check(b.shape[0] <= 100) return fn1(a, b) diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 9cd7f7ccb9ac..5eaa007a8a1c 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -285,7 +285,6 @@ class TestInductorDynamic(TestCase): def f(): full = torch.full((), 11) i0 = full.item() - torch._check_is_size(i0) return torch.full((i0,), 0) opt_f = torch.compile(f, fullgraph=True) @@ -450,8 +449,6 @@ class TestInductorDynamic(TestCase): def test_return_unbacked_view_split(self, device): def f(values, length_per_key): u0, u1 = length_per_key.tolist() - torch._check_is_size(u0) - torch._check_is_size(u1) v1, v2 = torch.functional.split(values, [u0, u1]) return v1, v2 @@ -483,7 +480,6 @@ class TestInductorDynamic(TestCase): @torch.library.register_fake("_test::_cat") def _cat_fake(t: torch.Tensor, ds: list[int]) -> torch.Tensor: - [torch._check_is_size(d) for d in ds] return t.new_empty([sum(ds)]) def _cat_setup_context(ctx, inputs, output): @@ -983,7 +979,6 @@ class TestInductorDynamic(TestCase): @torch.compile(fullgraph=True, dynamic=True) def f(x): a = x.item() - torch._check_is_size(a) torch._check(a >= 1) torch._check(a <= 10) return torch.ones(a, a) @@ -995,8 +990,6 @@ class TestInductorDynamic(TestCase): 
diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index 9cd7f7ccb9ac..5eaa007a8a1c 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -285,7 +285,6 @@ class TestInductorDynamic(TestCase):
         def f():
             full = torch.full((), 11)
             i0 = full.item()
-            torch._check_is_size(i0)
             return torch.full((i0,), 0)
 
         opt_f = torch.compile(f, fullgraph=True)
@@ -450,8 +449,6 @@ class TestInductorDynamic(TestCase):
     def test_return_unbacked_view_split(self, device):
         def f(values, length_per_key):
             u0, u1 = length_per_key.tolist()
-            torch._check_is_size(u0)
-            torch._check_is_size(u1)
             v1, v2 = torch.functional.split(values, [u0, u1])
             return v1, v2
 
@@ -483,7 +480,6 @@ class TestInductorDynamic(TestCase):
 
         @torch.library.register_fake("_test::_cat")
         def _cat_fake(t: torch.Tensor, ds: list[int]) -> torch.Tensor:
-            [torch._check_is_size(d) for d in ds]
             return t.new_empty([sum(ds)])
 
         def _cat_setup_context(ctx, inputs, output):
@@ -983,7 +979,6 @@ class TestInductorDynamic(TestCase):
         @torch.compile(fullgraph=True, dynamic=True)
         def f(x):
             a = x.item()
-            torch._check_is_size(a)
             torch._check(a >= 1)
             torch._check(a <= 10)
             return torch.ones(a, a)
@@ -995,8 +990,6 @@ class TestInductorDynamic(TestCase):
         @torch.compile()
         def f(xt):
             xs = xt.tolist()
-            for x in xs:
-                torch._check_is_size(x)
             y = sum(xs)
             return torch.zeros(y, device=device)
diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py
index 597905a8caa5..c8eb20bffb32 100644
--- a/test/inductor/test_torchinductor_strided_blocks.py
+++ b/test/inductor/test_torchinductor_strided_blocks.py
@@ -1075,7 +1075,6 @@ class CommonTemplate:
         def foo(x, length):
             unbacked = length.item()
-            torch._check_is_size(unbacked)
             repeated = x.repeat(1, unbacked, NUM_REPEAT)
 
             # permute creates split in middle with unbacked symint is the first range
diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py
index dff293410fd8..48242e4ed719 100644
--- a/test/inductor/test_unbacked_symints.py
+++ b/test/inductor/test_unbacked_symints.py
@@ -361,11 +361,11 @@ class TestUnbackedSymints(InductorTestCase):
             inp = args[0]
 
             start = inp.slice_bounds[0].item()
-            torch._check_is_size(start)
+            torch._check(start >= 0)
             torch._check(start <= inp.size(0))
 
             length = (args[0].slice_bounds[1] - args[0].slice_bounds[0]).item()
-            torch._check_is_size(length)
+            torch._check(length >= 0)
             torch._check(start + length <= inp.size(0))
 
             return CustomSliceSubclass(
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index fc2570999686..18c6ac5945e5 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1453,7 +1453,7 @@ def tensor_split_tensor_indices_or_sections_py_impl(
     # To avoid PendingUnbackedSymbolNotFound errors, we tell the compiler it's fine to not bind these.
     with ctx():
         indices = [i.item() for i in tensor_indices_or_sections]
-        # WARNING: Tempted to torch._check_is_size on the indices here? You
+        # WARNING: Tempted to torch._check(x>0) on the indices here? You
         # can't: tensor_split works with negative values in indices:
         #
         # >>> torch.tensor_split(torch.randn(10), torch.tensor([-5, 5]))
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index c572f900cfff..0e88b145d951 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -308,7 +308,7 @@ optimize_ddp: Union[
     ],
 ] = True
 
-# By default, Dynamo emits runtime asserts (e.g. torch._check, torch._check_is_size) in the graph.
+# By default, Dynamo emits runtime asserts (e.g. torch._check) in the graph.
 # In some cases those asserts could be performance costly
 # E.g. torch._check(tensor[0].item() > 2) for tensor on cuda will require cuda sync.
 # Setting this to True keeps them hinting to symbolic shapes engine,
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 925d6f468f77..8df7e9300d01 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -841,7 +841,7 @@ def sym_constrain_range_for_size(size, min=None, max=None):
     from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
 
     if min is None and max is None:
-        torch._check_is_size(size)
+        torch._check(size >= 0)
         return
 
     if isinstance(size, (SymFloat, SymBool)):
@@ -6835,7 +6835,7 @@ def topk_meta(self, k, dim=-1, largest=True, sorted=True):
     # From aten/src/ATen/native/Sorting.cpp
     dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
     sliceSize = 1 if self.dim() == 0 else self.size(dim)
-    torch._check_is_size(k)
+    torch._check(k >= 0)
     torch._check(k <= sliceSize, lambda: "k not in range for dimension")
 
     topKSize = list(self.shape)
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index f52c8cbfdfb6..2afb23310486 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -728,7 +728,7 @@ def validate_dim_length(length: int):
     """
 
     if isinstance(length, (int, torch.SymInt)):
-        torch._check_is_size(length)
+        torch._check(length >= 0)
     else:
         # sometimes called with sympy expression by inductor
         assert length >= 0
@@ -1083,13 +1083,7 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
     # PyTorch, which prints sequences in square brackets.
     shape = list(shape)
     shape[dim] = numel // newsize
-    # NB: This is pretty important when you have unbacked SymInts.
-    # Suppose you have (i0, 12) resizing into (2, -1, 12). The old
-    # range for i0 is typically [2, inf], which means if you divide
-    # by two the new range should be [1, inf]. But this is bad news
-    # if you have an unbacked SymInt: we need to reapply the unsound
-    # assumption that the size is >= 2.
-    torch._check_is_size(shape[dim])
+    torch._check(shape[dim] >= 0)
     return tuple(shape)
diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py
index c8a72924c1d8..0b000cfa1a9a 100644
--- a/torch/_subclasses/fake_impls.py
+++ b/torch/_subclasses/fake_impls.py
@@ -822,7 +822,8 @@ def slice_forward(
     # create unbacked if case unknown
     if new_size is None:
         new_size = shape_env.create_unbacked_symint()
-        torch._check_is_size(new_size, max=sizes[dim])
+        torch._check(new_size >= 0)
+        torch._check(new_size <= sizes[dim])
 
     # stride
     new_stride = strides[dim] * step
diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py
index c893794fc301..83a1b86b109b 100644
--- a/torch/distributed/_functional_collectives.py
+++ b/torch/distributed/_functional_collectives.py
@@ -944,7 +944,7 @@ def _all_to_all_single_meta(
         return input.new_empty(input.size())
     else:
         for s in output_split_sizes:
-            torch._check_is_size(s)
+            torch._check(s >= 0)
         out_size = list(input.size())
         out_size[0] = sum(output_split_sizes)
         return input.new_empty(out_size)
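Across the library-side files the rewrite is mechanical: `torch._check_is_size(x)` becomes `torch._check(x >= 0)`, and a call that passed `max=` (as in `fake_impls.py`) splits into a second `torch._check(x <= bound)`. A compact sketch of the loop form used in the `_all_to_all_single_meta` change (the function name is illustrative):

    import torch

    def validate_split_sizes(sizes: list) -> None:
        for s in sizes:
            # was: torch._check_is_size(s)
            torch._check(s >= 0)

    validate_split_sizes([0, 3, 7])  # passes; a negative entry raises at runtime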