# Owner(s): ["module: inductor"]
import sys
import unittest
import unittest.mock as mock

import torch
import torch._inductor
from torch._higher_order_ops import foreach_map
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_fw_bw_and_get_code
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_FBCODE,
    parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.utils._pytree import tree_flatten


aten = torch.ops.aten

try:
    try:
        from .test_torchinductor import check_model, check_model_cuda
    except ImportError:
        from test_torchinductor import (  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
            check_model,
            check_model_cuda,
        )
except (unittest.SkipTest, ImportError) as e:
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise


def foreach_map_wrapper(op):
    def wrapper(*args, **kwargs):
        return foreach_map(op, *args, **kwargs)

    wrapper.__name__ = "foreach_map_" + op.__name__
    wrapper.original_op = op
    return wrapper


def add_op(x, y):
    return torch.add(x, y)


def add_inplace_op(x, y):
    x.add_(y)
    return x.sin()


def addrecip_op(x, y):
    return torch.reciprocal(torch.add(x, y))


def addcmul_op(x, y, z):
    return torch.mul(torch.add(x, y), z)


def recipaddmul_op(x, y, z):
    return torch.mul(torch.add(torch.reciprocal(x), y), z)


# Foreach map bin op defs which support a scalar arg
foreach_map_add = foreach_map_wrapper(torch.add)
foreach_map_mul = foreach_map_wrapper(torch.mul)
foreach_map_sub = foreach_map_wrapper(torch.sub)
foreach_map_div = foreach_map_wrapper(torch.div)
foreach_map_addrecip = foreach_map_wrapper(addrecip_op)
foreach_map_clamp_max = foreach_map_wrapper(torch.clamp_max)
foreach_map_clamp_min = foreach_map_wrapper(torch.clamp_min)
# No scalar args (due to limitations on the op itself)
foreach_map_max = foreach_map_wrapper(torch.maximum)
foreach_map_min = foreach_map_wrapper(torch.minimum)
foreach_map_copy = foreach_map_wrapper(aten.copy)
# More general functions
foreach_map_add_fn = foreach_map_wrapper(add_op)
foreach_map_add_inplace = foreach_map_wrapper(add_inplace_op)
foreach_map_recipaddmul = foreach_map_wrapper(addrecip_op)
foreach_map_addcmul = foreach_map_wrapper(addcmul_op)
foreach_map_recipaddmul = foreach_map_wrapper(recipaddmul_op)

# Foreach map unary op defs
foreach_map_recip = foreach_map_wrapper(torch.reciprocal)
foreach_map_neg = foreach_map_wrapper(torch.neg)
foreach_map_sign = foreach_map_wrapper(torch.sign)
foreach_map_abs = foreach_map_wrapper(torch.abs)

inplace_bin_ops_under_test = [
    torch._foreach_add_,
    torch._foreach_mul_,
    torch._foreach_sub_,
    torch._foreach_div_,
]

ternary_ops_under_test = [
    foreach_map_addcmul,
    foreach_map_recipaddmul,
]

foreach_map_bin_ops_under_test = [
    foreach_map_add,
    foreach_map_mul,
    foreach_map_sub,
    foreach_map_div,
    foreach_map_addrecip,
    foreach_map_clamp_max,
    foreach_map_clamp_min,
    foreach_map_add_fn,
    foreach_map_max,
    foreach_map_min,
]

foreach_map_un_ops_under_test = [
    foreach_map_recip,
    foreach_map_neg,
    foreach_map_sign,
    foreach_map_abs,
]

bin_ops_under_test = [
    torch._foreach_add,
    torch._foreach_mul,
    torch._foreach_sub,
    torch._foreach_div,
    torch._foreach_maximum,
    torch._foreach_minimum,
    torch._foreach_clamp_max,
    torch._foreach_clamp_min,
    aten._foreach_copy,
    foreach_map_copy,  # aten.copy doesn't support backward
    *foreach_map_bin_ops_under_test,
]

scalar_bin_ops_under_test = [
    op
    for op in bin_ops_under_test
    if op not in (foreach_map_max, foreach_map_min, foreach_map_copy, aten._foreach_copy)
]

un_ops_under_test = [
    torch._foreach_reciprocal,
    torch._foreach_neg,
    torch._foreach_sign,
    torch._foreach_abs,
    torch._foreach_sqrt,
    torch._foreach_rsqrt,
    *foreach_map_un_ops_under_test,
]

compose_ops = [torch._foreach_addcdiv, torch._foreach_addcmul]

all_ops = parametrize(
    "op",
    ternary_ops_under_test + bin_ops_under_test + un_ops_under_test,
    name_fn=lambda f: f.__name__,
)
bin_ops = parametrize("op", bin_ops_under_test, name_fn=lambda f: f.__name__)
inplace_bin_ops = parametrize(
    "op", inplace_bin_ops_under_test, name_fn=lambda f: f.__name__
)
scalar_bin_ops = parametrize(
    "op", scalar_bin_ops_under_test, name_fn=lambda f: f.__name__
)
scalar_tensor_bin_ops = parametrize(
    "op", scalar_bin_ops_under_test, name_fn=lambda f: f.__name__
)
foreach_map_bin_ops = parametrize(
    "op", foreach_map_bin_ops_under_test, name_fn=lambda f: f.__name__
)
foreach_map_un_ops = parametrize(
    "op", foreach_map_un_ops_under_test, name_fn=lambda f: f.__name__
)
decomp_ops = parametrize("op", compose_ops, name_fn=lambda f: f.__name__)


def gen_args(op):
    if op in un_ops_under_test:
        return (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )
    elif op in bin_ops_under_test:
        return (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )
    else:
        return (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )


@instantiate_parametrized_tests
class ForeachTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    def _test_single_list(self, op):
        if op in un_ops_under_test:

            def fn(a0, a1):
                return op([a0, a1])

        elif op in bin_ops_under_test:

            def fn(a0, a1, b0, b1):
                return op([a0, a1], [b0, b1])

        else:

            def fn(a0, a1, b0, b1, c0, c1):
                return op([a0, a1], [b0, b1], [c0, c1])

        self.check_model_cuda(
            fn,
            gen_args(op),
        )

    def _test_single_scalar(self, op):
        def fn(a0, a1):
            return op([a0, a1], 3.3)

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

    def _test_single_scalar_tensor(self, op):
        def fn(a0, a1):
            return op([a0, a1], torch.tensor(3.3, device="cuda:0"))

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

    # called in test_cuda_cpp_wrapper.py
    @requires_cuda_and_triton
    def test_foreach_cpp_wrapper_cuda(self):
        self._test_single_list(op=torch._foreach_add)

    @requires_cuda_and_triton
    @all_ops
    def test_single_list(self, op):
        self._test_single_list(op)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @scalar_bin_ops
    def test_single_scalar(self, op):
        self._test_single_scalar(op)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @scalar_tensor_bin_ops
    def test_single_scalar_tensor(self, op):
        self._test_single_scalar_tensor(op)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @all_ops
    def test_scheduler_fusion_list(self, op):
        if op in un_ops_under_test:

            def fn(a0, a1):
                c = op([a0, a1])
                return torch._foreach_sqrt(c)

        elif op in bin_ops_under_test:

            def fn(a0, a1, b0, b1):
                c = op([a0, a1], [b0, b1])
                return c, torch._foreach_add([a0, a1], c)

        else:

            def fn(a0, a1, b0, b1, c0, c1):
                c = op([a0, a1], [b0, b1], [c0, c1])
                return c, torch._foreach_add([a0, a1], c)

        self.check_model_cuda(
            fn,
            gen_args(op),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @scalar_bin_ops
    def test_scheduler_fusion_scalar(self, op):
        def fn(a0, a1):
            c = op([a0, a1], 3.4)
            return c, torch._foreach_add([a0, a1], c)

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @scalar_bin_ops
    def test_broadcasting(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        fn_opt = torch.compile(fn)

        inputs = (
            torch.rand(10, 1, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(1, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )
        actual = fn_opt(*inputs)
        expected = fn(*inputs)
        self.assertEqual(actual, expected)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @all_ops
    def test_singleton_lists(self, op):
        if op in un_ops_under_test:

            def fn(a0):
                return op([a0])

            args = (torch.rand(10, 10, device="cuda:0"),)
        elif op in bin_ops_under_test:

            def fn(a0, b0):
                return op([a0], [b0])

            args = (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
            )
        else:

            def fn(a0, b0, c0):
                return op([a0], [b0], [c0])

            args = (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
            )

        self.check_model_cuda(
            fn,
            args,
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @bin_ops
    def test_type_promotion(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        fn_opt = torch.compile(fn)

        max32 = torch.iinfo(torch.int32).max
        max64 = torch.iinfo(torch.int64).max
        inputs = (
            torch.randint(max32, (10, 10), device="cuda:0", dtype=torch.int32),
            torch.randint(max32, (20, 20), device="cuda:0", dtype=torch.int32),
            torch.randint(max32, (10, 10), device="cuda:0", dtype=torch.int32),
            torch.randint(max64, (20, 20), device="cuda:0", dtype=torch.int64),
        )
        actual = fn_opt(*inputs)
        expected = fn(*inputs)
        self.assertEqual(actual, expected)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @scalar_bin_ops
    def test_kernel_split_arg_limit_list(self, op):
        # NB: foreach_copy won't pass this test because it will dce one set of buffers
        def fn(a, b):
            return op(a, b)

        fn_opt = torch.compile(fn)

        max_args = 370
        max_list_len = (max_args // 3) + 1
        inputs = (
            [torch.rand(10, 10, device="cuda:0") for _ in range(max_list_len)],
            [torch.rand(10, 10, device="cuda:0") for _ in range(max_list_len)],
        )

        actual = fn_opt(*inputs)
        expected = fn(*inputs)
        self.assertEqual(actual, expected)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda_and_triton
    @scalar_bin_ops
    @unittest.skip(
        "Triton recursion depth exceeded: https://github.com/triton-lang/triton/issues/1763"
    )
    def test_kernel_split_arg_limit_scalar(self, op):
        def fn(a):
            return op(a, 3.3)

        fn_opt = torch.compile(fn)

        max_args = 370
        max_list_len = (max_args // 2) + 1
        inputs = ([torch.rand(10, 10, device="cuda:0") for _ in range(max_list_len)],)

        actual = fn_opt(*inputs)
        expected = fn(*inputs)
        self.assertEqual(actual, expected)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda_and_triton
    @bin_ops
    def test_fusion_duplicate_buffer_list(self, op):
        def fn(a0, a1, b0, b1):
            c = op([a0, a1], [b0, b1])
            return op([a0, b0], [c[0], c[0]])

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
            reference_in_float=False,
            check_lowp=False,
        )

        kernel_count = 1
        if "foreach_map" in op.__name__:
            kernel_count = 2
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, kernel_count)

    @requires_cuda_and_triton
    @all_ops
    def test_non_foreach_consumer_list(self, op):
        if op in un_ops_under_test:

            def fn(a0, a1):
                c = op([a0, a1])
                return torch.mul(c[0], a0)

        elif op in bin_ops_under_test:

            def fn(a0, a1, b0, b1):
                c = op([a0, a1], [b0, b1])
                return torch.mul(c[0], a0)

        else:

            def fn(a0, a1, b0, b1, c0, c1):
                c = op([a0, a1], [b0, b1], [c0, c1])
                return torch.mul(c[0], a0)

        self.check_model_cuda(
            fn,
            gen_args(op),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @scalar_bin_ops
    def test_non_foreach_consumer_scalar(self, op):
        def fn(a0, a1):
            c = op([a0, a1], 4.7)
            return torch.mul(c[0], a0)

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @all_ops
    def test_non_foreach_producer_list(self, op):
        if op in un_ops_under_test:

            def fn(a0, a1):
                c0 = torch.add(a0, a0)
                c1 = torch.add(a1, a1)
                return op([c0, c1])

        elif op in bin_ops_under_test:

            def fn(a0, a1, b0, b1):
                c0 = torch.add(a0, b0)
                c1 = torch.add(a1, b1)
                return op([a0, a1], [c0, c1])

        else:

            def fn(a0, a1, b0, b1, c0, c1):
                c0 = torch.add(a0, b0)
                c1 = torch.add(a1, b1)
                return op([a0, a1], [b0, b1], [c0, c1])

        self.check_model_cuda(
            fn, gen_args(op), reference_in_float=False, check_lowp=False
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @scalar_bin_ops
    def test_non_foreach_producer_scalar(self, op):
        def fn(a0, a1, b0, b1):
            c0 = torch.mul(a0, b0)
            c1 = torch.mul(a1, b1)
            return op([c0, c1], 5.6)

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @all_ops
    def test_non_foreach_consumer_producer_list(self, op):
        if op in un_ops_under_test:

            def fn(a0, a1):
                c0 = torch.add(a0, a0)
                c1 = torch.mul(a1, a1)
                d = op([c0, c1])
                e0 = torch.mul(d[0], a0)
                e1 = torch.mul(d[1], a1)
                return [e0, e1]

        elif op in bin_ops_under_test:

            def fn(a0, a1, b0, b1):
                c0 = torch.add(a0, b0)
                c1 = torch.add(a1, b1)
                d = op([a0, a1], [c0, c1])
                e0 = torch.mul(d[0], a0)
                e1 = torch.mul(d[1], a1)
                return [e0, e1]

        else:

            def fn(a0, a1, b0, b1, c0, c1):
                c0 = torch.add(a0, b0)
                c1 = torch.add(a1, b1)
                d = op([a0, a1], [b0, b1], [c0, c1])
                e0 = torch.mul(d[0], a0)
                e1 = torch.mul(d[1], a1)
                return [e0, e1]

        self.check_model_cuda(
            fn,
            gen_args(op),
            reference_in_float=False,
            check_lowp=False,
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @scalar_bin_ops
    def test_non_foreach_consumer_producer_scalar(self, op):
        def fn(a0, a1, b0, b1):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            d = op([c0, c1], 5.8)
            e0 = torch.mul(d[0], a0)
            e1 = torch.mul(d[1], a1)
            return [e0, e1]

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
            reference_in_float=False,
            check_lowp=False,
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @bin_ops
    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    @torch._dynamo.config.patch("assume_static_by_default", False)
    @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", False)
    def test_dynamic_shapes_fallback(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda_and_triton
    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    @torch._dynamo.config.patch("assume_static_by_default", False)
    @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True)
    def test_enable_dynamic_shapes_python_wrapper(self, op=torch._foreach_add):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    @torch._dynamo.config.patch("assume_static_by_default", False)
    @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True)
    @torch._inductor.config.patch("cpp_wrapper", True)
    def test_enable_dynamic_shapes_cpp_wrapper_cuda(self, op=torch._foreach_add):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs)

    @unittest.skipIf(IS_FBCODE, "cpp compile not supported in fbcode")
    @bin_ops
    def test_cpu_cpp_fallback(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        inputs = (
            torch.rand(10, 10, device="cpu"),
            torch.rand(20, 20, device="cpu"),
            torch.rand(10, 10, device="cpu"),
            torch.rand(20, 20, device="cpu"),
        )

        self.check_model_cpu(fn, inputs)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda_and_triton
    @decomp_ops
    def test_decomp(self, op):
        def fn(a0, a1, b0, b1, c0, c1):
            return op([a0, a1], [b0, b1], [c0, c1], value=0.5)

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(20, 20, device="cuda:0"),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    def test_fuse_concat(self):
        def fn(x1, x2, x3, w1, w2, w3):
            x = torch.stack([x1, x2, x3])
            w = torch.stack([w1, w2, w3])

            y = torch.bmm(x, w)

            return y

        x1 = torch.randn(5, 4).cuda()
        x2 = x1 + 1
        x3 = x1 + 2
        w1 = torch.randn(4, 3).cuda()
        w2 = w1 + 1
        w3 = w1 + 2

        args = (x1, x2, x3, w1, w2, w3)

        self.check_model_cuda(fn, args)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda_and_triton
    def test_zero_elems(self):
        def fn(a0, a1, b0, b1):
            return torch._foreach_add([a0, a1], [b0, b1])

        self.check_model_cuda(
            fn,
            (
                torch.rand(0, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
                torch.rand(0, device="cuda:0"),
                torch.rand(10, 10, device="cuda:0"),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @bin_ops
    def test_2d_blocking(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 40, device="cuda:0"),
                torch.rand(10, 30, device="cuda:0"),
                torch.rand(40, 10, device="cuda:0").t(),
                torch.rand(30, 10, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @bin_ops
    def test_2d_blocking_partitioning(self, op):
        def fn(a0, a1, b0, b1):
            return op([a0, a1], [b0, b1])

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda:0"),
                torch.rand(40, 30, device="cuda:0"),
                torch.rand(30, 20, device="cuda:0"),
                torch.rand(30, 40, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda_and_triton
    @bin_ops
    def test_2d_blocking_partitioning_elems(self, op):
        """2D blocking should be grouped by number of yelems"""

        def fn(a0, a1, a2, b0, b1, b2):
            return op([a0, a1, a2], [b0, b1, b2])

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 20, device="cuda:0"),
                torch.rand(30, 20, device="cuda:0"),
                torch.rand(10, 30, device="cuda:0"),
                torch.rand(20, 10, device="cuda:0").t(),
                torch.rand(20, 30, device="cuda:0").t(),
                torch.rand(30, 10, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda_and_triton
    @bin_ops
    @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2)
    def test_2d_blocking_partitioning_mixed_sizes(self, op):
        """2D blocking with mixed sizes should group together"""

        def fn(a0, a1, a2, b0, b1, b2):
            return op([a0, a1, a2], [b0, b1, b2])

        self.check_model_cuda(
            fn,
            (
                torch.rand(10, 20, device="cuda:0"),
                torch.rand(30, 20, device="cuda:0"),
                torch.rand(10, 30, device="cuda:0"),
                torch.rand(20, 10, device="cuda:0").t(),
                torch.rand(20, 30, device="cuda:0").t(),
                torch.rand(30, 10, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @inplace_bin_ops
    def test_reinplacing(self, op):
        def fn(a0, a1, b0, b1):
            op([a0, a1], [b0, b1])
            return [a0, a1]

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs, check_lowp=False)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @inplace_bin_ops
    def test_reinplacing_mut_before(self, op):
        def fn(a0, a1, b0, b1):
            a0.add_(torch.ones(10, 10, device="cuda:0"))
            op([a0, a1], [b0, b1])
            return [a0, a1]

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs, check_lowp=False)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @inplace_bin_ops
    def test_reinplacing_mut_after(self, op):
        def fn(a0, a1, b0, b1):
            op([a0, a1], [b0, b1])
            a0.add_(torch.ones(10, 10, device="cuda:0"))
            return [a0, a1]

        inputs = (
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
            torch.rand(10, 10, device="cuda:0"),
            torch.rand(20, 20, device="cuda:0"),
        )

        self.check_model_cuda(fn, inputs, check_lowp=False)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    def test_multi_device(self):
        def test_foreach_add(a0, a1, b0, b1):
            return torch._foreach_add([a0, a1], [b0, b1])

        inps = [
            torch.ones(10, 10, device="cuda"),
            torch.ones(20, 20, device="cpu"),
            torch.zeros(10, 10, device="cuda"),
            torch.zeros(20, 20, device="cpu"),
        ]

        out_eager = test_foreach_add(*inps)
        out_compiled = torch.compile(test_foreach_add)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda_and_triton
    def test_aliasing(self):
        def test_foreach_add(a0, a1, a2, b0, b1, b2):
            return torch._foreach_add_([a0, a1, a2], [b0, b1, b2])

        input = torch.ones(10, 10, device="cuda")
        input2 = torch.ones(10, 10, device="cuda")
        inps = [
            input,
            input.view(10, 10),
            input.view(10, 10),
            input2,
            input2.view(10, 10),
            input2.view(10, 10),
        ]

        out_eager = test_foreach_add(*inps)
        out_compiled = torch.compile(test_foreach_add)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)

    @requires_cuda_and_triton
    @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 1)
    def test_2d_block_no_mixed_sizes_no_mask(self):
        """2D blocking with no mixed sizes constant mask"""

        def fn(a0, a1, a2, b0, b1, b2):
            return torch._foreach_add([a0, a1, a2], [b0, b1, b2])

        self.check_model_cuda(
            fn,
            (
                torch.rand(1024, 2048, device="cuda:0"),
                torch.rand(2048, 2048, device="cuda:0"),
                torch.rand(1024, 2048, device="cuda:0"),
                torch.rand(2048, 1024, device="cuda:0").t(),
                torch.rand(2048, 2048, device="cuda:0").t(),
                torch.rand(2048, 1024, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    @requires_cuda_and_triton
    @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2)
    def test_2d_block_mixed_sizes_with_mask(self):
        """2D blocking with mixed sizes should have mask"""

        def fn(a0, a1, a2, b0, b1, b2):
            return torch._foreach_add([a0, a1, a2], [b0, b1, b2])

        self.check_model_cuda(
            fn,
            (
                torch.rand(1024, 2048, device="cuda:0"),
                torch.rand(2048, 2048, device="cuda:0"),
                torch.rand(1024, 2048, device="cuda:0"),
                torch.rand(2048, 1024, device="cuda:0").t(),
                torch.rand(2048, 2048, device="cuda:0").t(),
                torch.rand(2048, 1024, device="cuda:0").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda_and_triton
    @foreach_map_bin_ops
    def test_foreach_map_backward_binary(self, op):
        from torch._dynamo.polyfills import foreach_map_fn

        def fn(xs, ys):
            outs = op(xs, ys)
            return outs[0].sum() + outs[1].sum() + outs[2].sum()

        def ref_fn(xs, ys):
            outs = foreach_map_fn(torch.add, xs, ys)
            return outs[0].sum() + outs[1].sum() + outs[2].sum()

        ref_inps = (
            [
                torch.rand(10, 20, device="cuda:0", requires_grad=True),
                torch.rand(10, 30, device="cuda:0", requires_grad=True),
                torch.rand(30, 30, device="cuda:0", requires_grad=True),
            ],
            [
                torch.rand(10, 20, device="cuda:0", requires_grad=True),
                torch.rand(10, 30, device="cuda:0", requires_grad=True),
                torch.rand(30, 30, device="cuda:0", requires_grad=True),
            ],
        )

        inps = (
            [x.clone().detach().requires_grad_(True) for x in ref_inps[0]],
            [y.clone().detach().requires_grad_(True) for y in ref_inps[1]],
        )

        out_ref = ref_fn(*ref_inps)
        out_ref.backward()

        # unpacking result, (fw_code, bw_code)
        _, (_, _) = run_fw_bw_and_get_code(lambda: torch.compile(fn)(*inps))

        for ref, act in zip(tree_flatten(ref_inps)[0], tree_flatten(inps)[0]):
            torch.allclose(ref.grad, act.grad)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

    @requires_cuda_and_triton
    def test_foreach_map_input_mutation(self):
        def fn(xs, ys):
            outs = foreach_map_add_inplace(xs, ys)
            return outs[0].sum() + outs[1].sum() + outs[2].sum()

        ref_inps = (
            [
                torch.rand(10, 20, device="cuda:0", requires_grad=True),
                torch.rand(10, 30, device="cuda:0", requires_grad=True),
                torch.rand(30, 30, device="cuda:0", requires_grad=True),
            ],
            [
                torch.rand(10, 20, device="cuda:0", requires_grad=True),
                torch.rand(10, 30, device="cuda:0", requires_grad=True),
                torch.rand(30, 30, device="cuda:0", requires_grad=True),
            ],
        )

        # Set requires_grad to be False to avoid mutating a leaf variable
        inps = (
            [x.clone().detach().requires_grad_(False) for x in ref_inps[0]],
            [y.clone().detach().requires_grad_(False) for y in ref_inps[1]],
        )

        # TODO: after decomposing auto_functionalized, we're getting
        # a functional subgraph with an inlined epilogue.
        with self.assertRaisesRegex(
            torch._inductor.exc.InductorError,
            "Buffer mutation detected during lowering of aten.copy_.default",
        ):
            with mock.patch(
                "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
                True,
            ):
                _ = run_fw_bw_and_get_code(lambda: torch.compile(fn)(*inps))

    @requires_cuda_and_triton
    @foreach_map_un_ops
    def test_foreach_map_backward_unary(self, op):
        from torch._dynamo.polyfills import foreach_map_fn

        def fn(xs):
            outs = op(xs)
            return outs[0].sum() + outs[1].sum() + outs[2].sum()

        def ref_fn(xs):
            outs = foreach_map_fn(op.original_op, xs)
            return outs[0].sum() + outs[1].sum() + outs[2].sum()

        ref_inp = [
            torch.rand(10, 20, device="cuda:0", requires_grad=True),
            torch.rand(10, 30, device="cuda:0", requires_grad=True),
            torch.rand(30, 30, device="cuda:0", requires_grad=True),
        ]

        inp = [x.clone().detach().requires_grad_(True) for x in ref_inp]

        out_ref = ref_fn(ref_inp)
        out_ref.backward()

        # unpacking result, (fw_code, bw_code)
        _, (_, _) = run_fw_bw_and_get_code(lambda: torch.compile(fn)(inp))

        for ref, act in zip(ref_inp, inp):
            torch.allclose(ref.grad, act.grad)

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_CUDA_AND_TRITON:
        run_tests(needs="filelock")