diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py index 15037d70a0d1..aadc664b93f3 100644 --- a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py +++ b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py @@ -198,12 +198,22 @@ class OperatorInputsMode(TorchDispatchMode): def map_to_device(e, device): - return e.to(device) if isinstance(e, torch.Tensor) else e + if isinstance(e, torch.Tensor): + return e.to(device) + elif isinstance(e, torch.device): + return device + elif isinstance(e, str): + if e == "cuda" or e == "cpu": + return device.type + else: + return e def map_to_dtype(e, dtype): if isinstance(e, torch.Tensor) and e.is_floating_point(): return e.to(dtype) + elif isinstance(e, torch.dtype): + return dtype else: return e diff --git a/benchmarks/dynamo/microbenchmarks/operatorbench.py b/benchmarks/dynamo/microbenchmarks/operatorbench.py index fcc15bf5d932..1b806ebdd8f5 100644 --- a/benchmarks/dynamo/microbenchmarks/operatorbench.py +++ b/benchmarks/dynamo/microbenchmarks/operatorbench.py @@ -2,7 +2,6 @@ import click import numpy as np import torch -import triton from operator_inp_utils import OperatorInputsLoader from torch._dynamo.optimizations.backends import cudagraphs_inner @@ -17,7 +16,7 @@ aten = torch.ops.aten def compute_speedups( - operator, models, example_inputs, repeats, accuracy_checking=False + operator, models, example_inputs, repeats, accuracy_checking=False, device="cuda" ): expected = models[0](*example_inputs) if accuracy_checking: @@ -35,10 +34,19 @@ def compute_speedups( for rep in range(repeats): # interleave the runs to handle frequency scaling and load changes for m, model in enumerate(models): - # do_bench() clears L2 cache to hide the latency of CPU launch time - # along with cuda synchronization - median_ms, _, _ = triton.testing.do_bench(lambda: model(*example_inputs)) - timings[rep, m] = median_ms + if device == "cuda": + import triton + + # do_bench() clears L2 cache to hide the latency of CPU launch time + # along with cuda synchronization + median_ms, _, _ = triton.testing.do_bench( + lambda: model(*example_inputs) + ) + timings[rep, m] = median_ms + else: + from torch._inductor.utils import timed + + timings[rep, m] = timed(model, example_inputs) return np.median(timings, axis=0) @@ -64,15 +72,19 @@ def convert_to_jit(gm, gm_args): def microbenchmark( - operator, args, kwargs, dtype, accuracy_checking, repeats, measure_nvfuser + operator, args, kwargs, dtype, accuracy_checking, repeats, measure_nvfuser, device ): gm, gm_args = gen_gm_and_inputs(operator, args, kwargs) torch.jit._builtins._register_builtin( torch.ops.aten.convolution_backward.default, "aten::convolution_backward" ) - cudagraphs_eager = cudagraphs_inner(gm, gm_args, copy_outputs=False) - compiled_fn = compile_fx(gm, gm_args) - compiled = [cudagraphs_eager, compiled_fn] + if device == "cuda": + cudagraphs_eager = cudagraphs_inner(gm, gm_args, copy_outputs=False) + compiled_fn = compile_fx(gm, gm_args) + compiled = [cudagraphs_eager, compiled_fn] + else: + compiled_fn = compile_fx(gm, gm_args) + compiled = [gm, compiled_fn] if measure_nvfuser: g = convert_to_jit(gm, gm_args) cudagraphs_jit = cudagraphs_inner(g, gm_args, copy_outputs=False) @@ -80,7 +92,9 @@ def microbenchmark( if accuracy_checking: repeats = 1 - medians = compute_speedups(operator, compiled, gm_args, repeats, accuracy_checking) + medians = compute_speedups( + operator, compiled, gm_args, repeats, accuracy_checking, device + ) return 
medians @@ -147,8 +161,9 @@ def skip_operator(operator): @click.option( "--measure-nvfuser", help="default we only measure inductor", default=False ) +@click.option("--device", help="cpu or cuda", default="cuda") def benchmark( - suite, op, dtype, max_samples, accuracy_checking, repeats, measure_nvfuser + suite, op, dtype, max_samples, accuracy_checking, repeats, measure_nvfuser, device ): assert suite in ("timm", "huggingface", "torchbench"), f"got {suite}" if suite == "timm": @@ -176,7 +191,7 @@ def benchmark( continue print(f"Running {operator}") - inp_gen = loader.get_inputs_for_operator(operator, dtype=dtype) + inp_gen = loader.get_inputs_for_operator(operator, dtype=dtype, device=device) timings = [] for i in range(min(max_samples, 1000000)): @@ -199,6 +214,7 @@ def benchmark( accuracy_checking, repeats, measure_nvfuser, + device, ) ) except Exception as e: diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index 9b1297a129ae..b872fee1af6b 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -237,6 +237,8 @@ class TorchBenchmarkRunner(BenchmarkRunner): if batch_size is None and is_training and model_name in USE_SMALL_BATCH_SIZE: batch_size = USE_SMALL_BATCH_SIZE[model_name] + # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags" + torch.backends.__allow_nonbracketed_mutation_flag = True if is_training: benchmark = benchmark_cls( test="train", device=device, jit=False, batch_size=batch_size diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index b185313f8b14..9e468a76f433 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -4,6 +4,7 @@ import functools import torch import torch._dynamo +import torch._dynamo.test_case from torch._dynamo.optimizations.training import is_aot_autograd_safe_to_run from torch._dynamo.testing import rand_strided @@ -13,7 +14,7 @@ def compiler_safe_fn(gm, example_inputs, is_safe): return gm.forward -class AotAutogradFallbackTests(torch._dynamo.testing.TestCase): +class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase): def test_LSTM(self): # https://github.com/pytorch/torchdynamo/issues/1147 class Repro(torch.nn.Module): @@ -60,6 +61,65 @@ class AotAutogradFallbackTests(torch._dynamo.testing.TestCase): aot_fn(x, y) self.assertTrue(not is_safe[0]) + def test_mutation1(self): + def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor): + getitem = diagonal_chunked_attention_scores[ + ( + slice(None, None, None), + slice(None, None, None), + slice(None, 256, None), + slice(None, 257, None), + ) + ] + _stack0[ + ( + slice(None, None, None), + slice(None, -1, None), + slice(None, None, None), + slice(256, None, None), + ) + ] = getitem + view = _stack0.view(1, 12, 1024, 513) + return (view,) + + x = torch.randn(torch.Size([12, 4, 256, 513])) + y = torch.randn(torch.Size([12, 3, 512, 513])) + is_safe = [True] + compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe) + aot_fn = torch._dynamo.optimize(compiler_fn)(fn) + aot_fn(x, y) + self.assertTrue(not is_safe[0]) + + def test_negative_testing_mutation(self): + def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor): + getitem = diagonal_chunked_attention_scores[ + ( + slice(None, None, None), + slice(None, None, None), + slice(None, 256, None), + slice(None, 257, None), + ) + ] + _stack0 = torch.sin(_stack0) + _stack0[ + ( + slice(None, None, None), + slice(None, -1, None), + slice(None, None, None), + 
slice(256, None, None), + ) + ] = getitem + view = _stack0.view(1, 12, 1024, 513) + return (view,) + + x = torch.randn(torch.Size([12, 4, 256, 513])) + y = torch.randn(torch.Size([12, 3, 512, 513])) + is_safe = [True] + compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe) + aot_fn = torch._dynamo.optimize(compiler_fn)(fn) + aot_fn(x, y) + self.assertTrue(is_safe[0]) + def test_negative_testing(self): def fn(x, y): return torch.sin(x).add_(y) @@ -74,6 +134,6 @@ class AotAutogradFallbackTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_aot_cudagraphs.py b/test/dynamo/test_aot_cudagraphs.py index 37eeb6af3b30..fdb7c88762b8 100644 --- a/test/dynamo/test_aot_cudagraphs.py +++ b/test/dynamo/test_aot_cudagraphs.py @@ -7,8 +7,10 @@ from unittest.mock import patch import torch import torch._dynamo +import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import same +from torch.testing._internal.common_utils import TEST_WITH_ROCM def composed(*decs): @@ -43,6 +45,7 @@ def assert_aot_autograd_counter(ok=True): def patch_all(ok=True): return composed( + unittest.skipIf(TEST_WITH_ROCM, "ROCm not supported"), patch("torch._dynamo.config.verify_correctness", True), assert_aot_autograd_counter(ok), ) @@ -52,7 +55,7 @@ N_ITERS = 5 @unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda") -class TestAotCudagraphs(torch._dynamo.testing.TestCase): +class TestAotCudagraphs(torch._dynamo.test_case.TestCase): @patch_all() def test_basic(self): def model(x, y): @@ -201,6 +204,6 @@ class TestAotCudagraphs(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_distributed.py b/test/dynamo/test_distributed.py index c1684a013d71..695e34817f37 100644 --- a/test/dynamo/test_distributed.py +++ b/test/dynamo/test_distributed.py @@ -7,6 +7,7 @@ import pytest import torch import torch._dynamo +import torch._dynamo.test_case import torch.distributed as dist from torch import nn from torch._dynamo import config @@ -43,7 +44,7 @@ def skip_if_no_active_ddp(): @pytest.mark.skip("Module hangs in PyTorch CI") -class TestDistributed(torch._dynamo.testing.TestCase): +class TestDistributed(torch._dynamo.test_case.TestCase): """ Test harness initializes dist process group """ @@ -209,6 +210,63 @@ class TestDistributed(torch._dynamo.testing.TestCase): opt_outputs.sum().backward() self.assertTrue(same(correct_outputs, opt_outputs)) + @patch.object(config, "optimize_ddp", True) + def test_custom_layer(self): + """ + Just ensures that the appropriate number of splits happen (based on + bucket size and model parameters) - verifies the number of times + the user-provided compiler is called by the DDPOptimizer which is + doing the graph splitting + """ + from torch.nn.parallel import DistributedDataParallel as DDP + + skip_if_no_active_ddp() + + class MyCustomLinear(torch.nn.Module): + def __init__(self): + super(MyCustomLinear, self).__init__() + self.weight = nn.Parameter(torch.randn(512, 512)) + + def forward(self, x): + return torch.mm(x, self.weight.t()) + + class MyLinear(torch.nn.Module): + def __init__(self): + super(MyLinear, self).__init__() + self.linear = torch.nn.Linear(512, 512) + + def forward(self, x): + return self.linear(x) + + class MyModule(torch.nn.Module): + def 
__init__(self): + super(MyModule, self).__init__() + mods = [ + (MyLinear(), torch.nn.ReLU()), + # sandwich the custom in the middle so it comes before and after + (MyCustomLinear(), torch.nn.ReLU()), + (MyLinear(), torch.nn.ReLU()), + ] + self.seq = torch.nn.Sequential(*[x for items in mods for x in items]) + + def forward(self, x): + return self.seq(x) + + m = MyModule().to(self.device) + inputs = torch.randn((512, 512)).to(self.device) + correct_outputs = m(inputs) + ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1) + + check_splits_compiler = CheckSplitsCompiler() + + @torch._dynamo.optimize(check_splits_compiler.compile_fn) + def opt_fn(inputs): + return ddp_m(inputs) + + opt_outputs = opt_fn(inputs) + self.assertTrue(same(correct_outputs, opt_outputs)) + self.assertEqual(check_splits_compiler.compiler_called, 3) + def test_empty_graph(self): def fn(): get_world_size = torch.distributed.distributed_c10d.get_world_size() diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 2c9c90df19e0..a2a94fce1e55 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -25,6 +25,6 @@ DynamicShapesNNModuleTests = make_dynamic_cls(test_modules.NNModuleTests) DynamicShapesUnspecTests = make_dynamic_cls(test_unspec.UnspecTests) if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 9365535c73bc..5347805cb5ed 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -3,12 +3,13 @@ from unittest.mock import patch import torch +import torch._dynamo.test_case import torch._dynamo.testing import torch.utils._pytree as pytree from torch.fx.experimental.proxy_tensor import make_fx -class ExportTests(torch._dynamo.testing.TestCase): +class ExportTests(torch._dynamo.test_case.TestCase): # TODO(voz): Refactor to a shared test function. 
# The tests in this file are a little redundant, # They all take a func, run it with eager, then export it, then compare @@ -1423,6 +1424,6 @@ class ExportTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index e2004430f418..d18ef7e1173f 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -9,6 +9,7 @@ from typing import Any import torch +import torch._dynamo.test_case import torch._dynamo.testing from torch import sub from torch._dynamo.testing import requires_static_shapes @@ -53,7 +54,7 @@ def inline_unused(x): return x + 5.6 -class FunctionTests(torch._dynamo.testing.TestCase): +class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_inline_jit_annotations(x): x = inline_script_if_tracing(x) @@ -670,6 +671,6 @@ class FunctionTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_global.py b/test/dynamo/test_global.py index 5e3d975d7bc8..445a6cf103d4 100644 --- a/test/dynamo/test_global.py +++ b/test/dynamo/test_global.py @@ -1,6 +1,7 @@ # Owner(s): ["module: dynamo"] import torch +import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import same @@ -43,7 +44,7 @@ def reset_name(): _name = 0 -class TestGlobals(torch._dynamo.testing.TestCase): +class TestGlobals(torch._dynamo.test_case.TestCase): def test_store_global_1(self): def fn(x): global g_counter @@ -227,6 +228,6 @@ class TestGlobals(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 030b9f73ecf3..4570d15b2d14 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -6,6 +6,7 @@ from unittest.mock import patch import torch import torch._dynamo +import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.optimizations.backends import create_backend @@ -23,7 +24,7 @@ class MockModule(torch.nn.Module): return x -class MinfierTests(torch._dynamo.testing.TestCase): +class MinfierTests(torch._dynamo.test_case.TestCase): def test_after_dynamo(self): @create_backend def bad_dynamo_backend(subgraph): @@ -92,6 +93,6 @@ class MinfierTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index e3e05059230f..2d3db456c8e0 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -16,6 +16,7 @@ from unittest.mock import patch import numpy as np import torch +import torch._dynamo.test_case import torch._dynamo.testing import torch.onnx.operators from torch._dynamo import bytecode_transformation @@ -34,7 +35,7 @@ def my_custom_function(x): return x + 1 -class MiscTests(torch._dynamo.testing.TestCase): +class MiscTests(torch._dynamo.test_case.TestCase): def test_boolarg(self): def boolarg(aa, bb, flag): if flag: @@ -2719,6 +2720,6 @@ class TestTracer(JitTestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git 
a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index 28fdbbb8e596..7a35728b6116 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -4,6 +4,7 @@ import unittest.mock import torch +import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import same @@ -22,7 +23,7 @@ def maybe_skip(fn): return fn -class TestHFPretrained(torch._dynamo.testing.TestCase): +class TestHFPretrained(torch._dynamo.test_case.TestCase): @maybe_skip def test_pretrained(self): def fn(a, tmp): @@ -38,7 +39,7 @@ class TestHFPretrained(torch._dynamo.testing.TestCase): self.assertTrue(same(ref, res)) -class TestModelOutput(torch._dynamo.testing.TestCase): +class TestModelOutput(torch._dynamo.test_case.TestCase): @maybe_skip def test_mo_create(self): def fn(a, b): @@ -160,6 +161,6 @@ class TestModelOutput(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 6d05026499a7..bfba78b5fc10 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -5,6 +5,7 @@ from unittest.mock import patch import torch +import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.eval_frame import unsupported from torch._dynamo.mutation_guard import GenerationTracker @@ -604,7 +605,7 @@ def make_test(fn, expected_ops=None): return test_fn -class NNModuleTests(torch._dynamo.testing.TestCase): +class NNModuleTests(torch._dynamo.test_case.TestCase): test_seq = make_test(Seq()) test_basicmodule1 = make_test(BasicModule()) test_basicmodule2 = make_test(BasicModule()) @@ -884,6 +885,6 @@ class NNModuleTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_no_fake_tensors.py b/test/dynamo/test_no_fake_tensors.py index 6b2faec3d1d5..d65166f5762c 100644 --- a/test/dynamo/test_no_fake_tensors.py +++ b/test/dynamo/test_no_fake_tensors.py @@ -24,6 +24,6 @@ NoFakeTensorsNNModuleTests = make_no_fake_cls(test_modules.NNModuleTests) NoFakeTensorsUnspecTests = make_no_fake_cls(test_unspec.UnspecTests) if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_nops.py b/test/dynamo/test_nops.py index de52315e12ef..44e102699d09 100644 --- a/test/dynamo/test_nops.py +++ b/test/dynamo/test_nops.py @@ -1,6 +1,7 @@ # Owner(s): ["module: dynamo"] import torch +import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo import eval_frame @@ -35,7 +36,7 @@ with_debug_nops = eval_frame._optimize_catch_errors( ) -class NopTests(torch._dynamo.testing.TestCase): +class NopTests(torch._dynamo.test_case.TestCase): @with_debug_nops def test1(self): self.assertEqual(fn1(1, 2), -7) @@ -66,6 +67,6 @@ class NopTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_optimizations.py b/test/dynamo/test_optimizations.py index b58d7a44e599..d9f25c595499 100644 --- a/test/dynamo/test_optimizations.py +++ b/test/dynamo/test_optimizations.py @@ -8,6 +8,7 @@ from unittest.mock import patch import torch import torch._dynamo +import torch._dynamo.test_case from 
torch._dynamo.optimizations import backends from torch._dynamo.optimizations.analysis import has_mutation from torch._dynamo.optimizations.log_args import conv_args_analysis @@ -64,7 +65,7 @@ class Conv_Bn_Relu(torch.nn.Module): return self.relu(self.bn(self.conv(x))) -class TestOptimizations(torch._dynamo.testing.TestCase): +class TestOptimizations(torch._dynamo.test_case.TestCase): def test_inplacifier(self): gm = torch.fx.symbolic_trace(Seq()) normalize(gm) @@ -183,7 +184,7 @@ class TestOptimizations(torch._dynamo.testing.TestCase): self.assertEqual(r2.dtype, torch.bfloat16) -class NormalizeIRTests(torch._dynamo.testing.TestCase): +class NormalizeIRTests(torch._dynamo.test_case.TestCase): @unittest.skipIf(not has_functorch(), "requires functorch") def test_inplace_normalize(self): def fn(a, b): @@ -202,6 +203,6 @@ class NormalizeIRTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index 122c5c06b069..13668a730c60 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -6,6 +6,7 @@ import unittest import torch import torch._dynamo +import torch._dynamo.test_case import torch._dynamo.testing input = torch.ones([10, 10]) @@ -37,7 +38,7 @@ def make_test(optim_cls, exp_frame_cnt=1, closure=None, **kwargs): return test_fn -class OptimizerTests(torch._dynamo.testing.TestCase): +class OptimizerTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): super().setUpClass() @@ -97,6 +98,6 @@ for opt in optimizers: setattr(OptimizerTests, "test_" + opt.__name__.lower(), make_test(opt)) if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py index fe2f2819f20d..73a6f8f6d330 100644 --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -4,7 +4,8 @@ from typing import Callable, Dict, List, NamedTuple, Optional import torch import torch._dynamo -from torch._dynamo.testing import CompileCounter, same, TestCase +from torch._dynamo.test_case import run_tests, TestCase +from torch._dynamo.testing import CompileCounter, same """ This is an example of a pure-python version of autograd implemented by @@ -283,6 +284,4 @@ class TestPythonAutograd(TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests - run_tests() diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py index 00e99ab3f202..b39bea3ce932 100644 --- a/test/dynamo/test_recompile_ux.py +++ b/test/dynamo/test_recompile_ux.py @@ -6,10 +6,11 @@ import torch import torch._dynamo import torch._dynamo.config +import torch._dynamo.test_case import torch._dynamo.testing -class RecompileUxTests(torch._dynamo.testing.TestCase): +class RecompileUxTests(torch._dynamo.test_case.TestCase): # TODO(whc) dynamo actualy recompiles one more time than the cache limit cache_limit = 1 diff --git a/test/dynamo/test_replay_record.py b/test/dynamo/test_replay_record.py index f2586b7db37e..4404337d10e8 100644 --- a/test/dynamo/test_replay_record.py +++ b/test/dynamo/test_replay_record.py @@ -6,6 +6,7 @@ import unittest import torch +import torch._dynamo.test_case import torch._dynamo.testing try: @@ -16,7 +17,7 @@ except ImportError: requires_dill = unittest.skipIf(dill is None, "requires 
dill") -class ReplayRecordTests(torch._dynamo.testing.TestCase): +class ReplayRecordTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): super().setUpClass() @@ -181,6 +182,6 @@ class ReplayRecordTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index db44b20cfd31..2bd3130958eb 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -14,6 +14,7 @@ from unittest.mock import patch import numpy as np import torch +import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils from torch import nn @@ -749,7 +750,7 @@ class TestModule(torch.nn.Module): return self.inner_fn(tensor.shape, (1, 2, 3)) -class ReproTests(torch._dynamo.testing.TestCase): +class ReproTests(torch._dynamo.test_case.TestCase): def test_do_paste_mask(self): torch._dynamo.utils.counters.clear() opt__do_paste_mask = torch._dynamo.optimize( @@ -1712,6 +1713,6 @@ class ReproTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_skip_non_tensor.py b/test/dynamo/test_skip_non_tensor.py index a2338c60af8b..1d19762e73f8 100644 --- a/test/dynamo/test_skip_non_tensor.py +++ b/test/dynamo/test_skip_non_tensor.py @@ -4,10 +4,11 @@ from unittest.mock import patch import torch import torch._dynamo +import torch._dynamo.test_case from torch._dynamo.testing import CompileCounter -class SkipNonTensorTests(torch._dynamo.testing.TestCase): +class SkipNonTensorTests(torch._dynamo.test_case.TestCase): def test_add_tensor1(self): def fn(a, b): return a + b @@ -107,6 +108,6 @@ class SkipNonTensorTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py index f7d601c82b70..3a38561f16d2 100644 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -4,6 +4,7 @@ from unittest.mock import patch import torch +import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo import config from torch._dynamo.testing import unsupported @@ -17,7 +18,7 @@ def indirectly_unsupported(a, b): return unsupported(a, c) -class SubGraphTests(torch._dynamo.testing.TestCase): +class SubGraphTests(torch._dynamo.test_case.TestCase): def _common(self, fn, frame_count, op_count): torch._dynamo.reset() v1 = torch.ones(10) @@ -528,6 +529,6 @@ class SubGraphTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 5f184834418d..fbf398366193 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -7,6 +7,7 @@ from unittest.mock import patch import numpy as np import torch +import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import same @@ -51,7 +52,7 @@ UnspecNNModuleTests = make_unspec_cls(test_modules.NNModuleTests) @patch.object(torch._dynamo.config, "specialize_int_float", False) -class UnspecTests(torch._dynamo.testing.TestCase): +class UnspecTests(torch._dynamo.test_case.TestCase): def test_numpy_correctness(self): 
def fn(x, y, z): xy = [x + y, y, False] @@ -221,6 +222,6 @@ class UnspecTests(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/dynamo/test_verify_correctness.py b/test/dynamo/test_verify_correctness.py index f9d820f44c29..8e3624bfd9e7 100644 --- a/test/dynamo/test_verify_correctness.py +++ b/test/dynamo/test_verify_correctness.py @@ -8,6 +8,7 @@ import torch import torch._dynamo import torch._dynamo.config as config +import torch._dynamo.test_case from torch._dynamo.optimizations import backends from torch._dynamo.testing import same @@ -77,7 +78,7 @@ def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: return gm -class TestVerifyCorrectness(torch._dynamo.testing.TestCase): +class TestVerifyCorrectness(torch._dynamo.test_case.TestCase): @patch.object(config, "verify_correctness", True) def test_example_inputs(self): def fn(a, bc, d): @@ -169,6 +170,6 @@ class TestVerifyCorrectness(torch._dynamo.testing.TestCase): if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests run_tests() diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 47e7e4c41722..be4808f91083 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -6,6 +6,7 @@ import importlib import random import sys import unittest +import weakref from unittest.mock import patch import torch @@ -19,6 +20,7 @@ from torch.testing._internal.common_utils import ( TEST_WITH_ASAN, TestCase as TorchTestCase, ) +from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_flatten, tree_unflatten try: @@ -74,7 +76,6 @@ if torch.cuda.is_available(): pass requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda") - torch._inductor.config.triton.autotune = False # too slow @@ -242,7 +243,6 @@ def check_model( # for graph in exp[2]: # print("Graph", graph) assert called, "Ran graph without calling compile_fx" - assert type(actual) == type(correct) correct_flat, correct_spec = tree_flatten(correct) @@ -2004,7 +2004,11 @@ class CommonTemplate: def test_cat(self): def fn(a): tmp = a * 2 - return torch.cat((a, a[:, :4] + 1, a + 2), -1), torch.cat((tmp, tmp), 0) + return ( + torch.cat((a, a[:, :4] + 1, a + 2), -1), + torch.cat((tmp, tmp), 0), + torch.cat((tmp, tmp.double()), 0), + ) self.common( fn, @@ -3171,6 +3175,15 @@ class CommonTemplate: ], ) + # issue #1150 + def test_dense_mask_index(self): + def fn(x, y): + y = torch.ops.aten.select.int(y, 0, 2) + z = x * y + return z.sum() + + self.common(fn, [torch.randn(102400), torch.randn(3)]) + def test_new_empty_strided(self): def fn(a): return aten.new_empty_strided(a, [1, 128, 128], [16384, 128, 1]).fill_(123) @@ -3714,10 +3727,10 @@ class CommonTemplate: ) compiled = compile_fx_inner(traced, [torch.randn(8, 4, device=self.device)]) - out = compiled(torch.randn(8, 4, device=self.device)) + out = compiled([torch.randn(8, 4, device=self.device)]) self.assertEqual(out[0].shape, (16, 2)) - out = compiled(torch.randn(12, 4, device=self.device)) + out = compiled([torch.randn(12, 4, device=self.device)]) self.assertEqual(out[0].shape, (24, 2)) @requires_cuda() @@ -3744,6 +3757,63 @@ class CommonTemplate: ) self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1])) + @patch.object(config.triton, "mm", "aten") + def test_list_clearing(self): + + if self.device == "cpu": + 
contexts = [contextlib.nullcontext] + else: + contexts = [ + contextlib.nullcontext, + lambda: patch.object(config.triton, "cudagraphs", True), + ] + + for context in contexts: + with context(): + inps = [ + torch.rand([5, 5]).to(self.device), + torch.rand([5, 5]).to(self.device), + ] + inp_refs = [weakref.ref(inp) for inp in inps] + + def fn(x, y): + a = x + y + return (a @ a,) + + fn_fx = make_fx(fn)(inps[0], inps[1]) + fn_compiled = compile_fx_inner(fn_fx, inps) + + test_self = self + matmul_seen = False + + class TestRefMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + + nonlocal inps + nonlocal inp_refs + nonlocal test_self + nonlocal matmul_seen + + # by matmul, inputs should be deallocated + if func is aten.mm.out: + matmul_seen = True + test_self.assertEqual(len(inps), 0) + test_self.assertIsNone(inp_refs[0]()) + test_self.assertIsNone(inp_refs[1]()) + + return func(*args, **kwargs) + + with TestRefMode(): + fn_compiled(inps) + + # for some reason, TorchDispatch doesn't capture the + # cuda mm call (even without cudagraphs) + if self.device == "cpu": + self.assertTrue(matmul_seen) + else: + self.assertEqual(len(inps), 0) + if HAS_CPU: @@ -3781,7 +3851,7 @@ if HAS_CPU: fn_fx = make_fx(fn)(x1, y) fn_compiled = compile_fx_inner(fn_fx, [x1, y]) fn(x2, y) - fn_compiled(x3, y) + fn_compiled([x3, y]) assert same(x2, x3) def test_no_op_squeeze(self): @@ -3872,7 +3942,7 @@ if HAS_CUDA: ] mod = make_fx(forward)(*inps) compiled = compile_fx_inner(mod, inps) - compiled(*inps) + compiled(inps) @patch.object(config, "fallback_random", True) def test_dtype_factory_issue(self): @@ -3888,7 +3958,7 @@ if HAS_CUDA: mod = make_fx(forward)() compiled = compile_fx_inner(mod, ()) - assert compiled()[0].device.type == "cuda" + assert compiled([])[0].device.type == "cuda" @patch.object(config.triton, "cudagraphs", True) def test_expanded_inputs_cudagraphs(self): @@ -3952,6 +4022,7 @@ if HAS_CUDA: if __name__ == "__main__": - from torch._dynamo.testing import run_tests + from torch._dynamo.test_case import run_tests - run_tests(needs="filelock") + if HAS_CPU or HAS_CUDA: + run_tests(needs="filelock") diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 2b8b166a35d5..17bdc552a844 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -33,7 +33,7 @@ try: from .test_torchinductor import check_model, check_model_cuda except ImportError: from test_torchinductor import check_model, check_model_cuda -except (unittest.SkipTest, ImportError) as e: +except (unittest.SkipTest, ImportError, AssertionError) as e: sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": sys.exit(0) @@ -154,6 +154,9 @@ inductor_skips["cuda"] = { "jiterator_binary": {b8, f16, f32, f64, i32, i64}, "jiterator_binary_return_by_ref": {b8, f16, f32, f64, i32, i64}, "jiterator_unary": {b8, f16, f32, f64, i32, i64}, + # Triton bug leads to segfault + "nn.functional.softplus": {f64}, + "nn.functional.mish": {f64}, } inductor_expected_failures_single_sample = defaultdict(dict) @@ -372,24 +375,13 @@ inductor_expected_failures_single_sample["cuda"] = { inductor_gradient_expected_failures_single_sample = defaultdict(dict) inductor_gradient_expected_failures_single_sample["cuda"] = { - "amax": {f16, f32, f64}, - "amin": {f16, f32, f64}, "asin": {f16}, "cumprod": {f16}, "linalg.vector_norm": {f64, f64}, "linalg.householder_product": {f32}, "linalg.lu": {f32, 
f64}, "kron": {f16}, - "masked.amax": {f16, f32, f64}, - "masked.amin": {f16, f32, f64}, - "max.reduction_no_dim": {f16, f32, f64}, - "median": {f16, f32, f64}, - "min.reduction_no_dim": {f16, f32, f64}, - "nan_to_num": {f16, f32, f64}, - "nanmean": {f16, f32, f64}, - "nanmedian": {f16, f32, f64}, "nanquantile": {f32, f64}, - "nansum": {f16, f32, f64}, "native_batch_norm": {f16, f32, f64}, "native_layer_norm": {f16, f32, f64}, "nn.functional._scaled_dot_product_attention": {f16}, @@ -446,6 +438,7 @@ inductor_override_kwargs = { "new_empty_strided": {"assert_equal": False}, "randn": {"assert_equal": False}, ("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001}, + ("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002}, "gradient": {"check_gradient": False}, # segfault on check_gradient # Following tests failed, and causing subsequent tests failing with unrecoverable CUDA error "linalg.solve_triangular": {"check_gradient": False}, @@ -461,6 +454,8 @@ inductor_all_samples = { "index_copy", "scatter_reduce.sum", "select_scatter", + "squeeze", + "unsqueeze", } diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 40933d7f120d..7a68b890e715 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -54,8 +54,6 @@ constant_functions = { torch.onnx.is_in_onnx_export: False, } -# root folder of the project -base_dir = dirname(dirname(dirname(abspath(__file__)))) # don't specialize on shapes and strides and put shape ops in graph dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1" @@ -152,6 +150,12 @@ dynamo_import = __name__.replace(".config", "") # How to import torchinductor, either torchinductor or torch.inductor inductor_import = dynamo_import.replace("dynamo", "inductor") +# root folder of the project +if "torch." in dynamo_import: + base_dir = dirname(dirname(dirname(abspath(__file__)))) +else: + base_dir = dirname(dirname(abspath(__file__))) + class _AccessLimitingConfig(ModuleType): def __setattr__(self, name, value): diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index ac56c0e26204..91e8e152e436 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -228,7 +228,7 @@ def save_graph_repro(fd, gm, args, compiler_name): textwrap.dedent( f""" compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args) - compiled(*args) + compiled(args) """ ) ) @@ -293,7 +293,7 @@ def inductor_fails(fx_g, args, check_str=None): try: compile_mod = compile_fx_inner(fx_g, args) - compile_mod(*args) + compile_mod(args) except Exception as e: if check_str is not None and check_str not in repr(e): return False @@ -385,7 +385,7 @@ def wrap_compiler_debug(compiler_fn, compiler_name: str): orig_graph = copy.deepcopy(gm.graph) assert config.repro_after in ("dynamo", "aot", None) - def deferred_for_real_inputs(*real_inputs): + def deferred_for_real_inputs(real_inputs): """ Aot Autograd fw_compiler and bw_compiler can have fake tensors. So, example_inputs can be fake tensors. 
We can call compiler_fn (which is @@ -420,14 +420,14 @@ def wrap_compiler_debug(compiler_fn, compiler_name: str): raise ValueError("Bad accuracy detected") else: # Call the compiled function with real inputs - return compiled_fn(*real_inputs) + return compiled_fn(real_inputs) else: try: # Call the compiler_fn - which is either aot_autograd or inductor # with fake inputs compiled_fn = compiler_fn(gm, example_inputs, **kwargs) # Call the compiled function with real inputs - return compiled_fn(*real_inputs) + return compiled_fn(real_inputs) except Exception as e: if config.repro_level == 1: dump_compiler_graph_state( @@ -441,6 +441,7 @@ def wrap_compiler_debug(compiler_fn, compiler_name: str): if config.repro_after == "aot": compiled_fn = deferred_for_real_inputs + compiled_fn._boxed_call = True else: compiled_fn = compiler_fn(gm, example_inputs, **kwargs) @@ -453,6 +454,8 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False): """ Runs a forward and possibly backward iteration for a given mod and args. """ + from functorch._src.aot_autograd import make_boxed_func + from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass gm = copy.deepcopy(gm) @@ -465,7 +468,14 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False): if hasattr(gm, "zero_grad"): gm.zero_grad(True) - out = gm(*args) + + # TorchInductor returned callable expects lists. So, boxing the call. + if not hasattr(gm, "_boxed_call"): + orig_named_parameters = gm.named_parameters + gm = make_boxed_func(gm) + gm.named_parameters = orig_named_parameters + + out = gm(args) if only_fwd: return out if requires_bwd_pass(out): @@ -775,7 +785,7 @@ def wrap_backend_debug(compiler_fn, compiler_name: str): else: try: compiled_gm = compiler_fn(gm, example_inputs, **kwargs) - run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs)) + run_fwd_maybe_bwd(compiled_gm, example_inputs) except Exception as exc: log.warning( "Compiled Fx GraphModule failed with following error. Setting up minifier." @@ -815,7 +825,7 @@ def dynamo_minifier_backend(gm, example_inputs, compiler_name): try: compiled_gm = compiler_fn(gm, example_inputs) - run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs)) + run_fwd_maybe_bwd(compiled_gm, example_inputs) raise ValueError("No issue was detected") except Exception as exc: orig_failure = str(exc) diff --git a/torch/_dynamo/optimizations/analysis.py b/torch/_dynamo/optimizations/analysis.py index ccd175bfdae3..b7557a82d744 100644 --- a/torch/_dynamo/optimizations/analysis.py +++ b/torch/_dynamo/optimizations/analysis.py @@ -10,7 +10,7 @@ from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._pytree import tree_map from .. import config -from ..utils import fake_tensors_available +from ..utils import clone_inputs, fake_tensors_available if fake_tensors_available: from torch._subclasses import FakeTensorMode # noqa: F401 @@ -52,6 +52,14 @@ class ShapeAliasingAndMutationProp(ShapeProp): n.meta["alias_groups"] = { self.tensor_alias_group(obj) for obj in self.extract_tensors(result) } + + if ( + not n.meta["alias_groups"] + and n.op == "call_function" + and n.target == operator.setitem + ): + n.meta["alias_groups"] = {self.tensor_alias_group(tensor_args[0])} + n.meta["mutates_alias_groups"] = { self.tensor_alias_group(tensor) for tensor, v1, v2 in zip(tensor_args, input_versions1, input_versions2) @@ -113,6 +121,10 @@ def has_mutation(gm, example_inputs, inputs_only=False): true, we only check for mutation of inputs""" # TODO - moco gives bad accuracy with Aliasing. 
gm is getting mutated in a bad way. + # Clone the inputs such that intermediate tensors (not leaf tensors) with + # requires_grad to True are now converted to False to avoid Runtime Error + # like "leaf variable that requires grad is inplace modified" + example_inputs = clone_inputs(example_inputs) if fake_tensors_available and config.fake_tensor_propagation: with FakeTensorMode() as fake_mode: pass diff --git a/torch/_dynamo/optimizations/backends.py b/torch/_dynamo/optimizations/backends.py index 1ec5c774de11..abcb4290e782 100644 --- a/torch/_dynamo/optimizations/backends.py +++ b/torch/_dynamo/optimizations/backends.py @@ -681,9 +681,7 @@ def tvm_compile_inner( elif tuning_option == "meta_schedule": from os import path as osp - from tvm.meta_schedule import TuneConfig - from tvm.meta_schedule.database import JSONDatabase - from tvm.meta_schedule.tune import tune_relay + from tvm.contrib.torch import optimize_torch with tempfile.TemporaryDirectory() as work_dir: if log_file is not None: @@ -691,22 +689,16 @@ def tvm_compile_inner( log_file ), "TVM's meta_schedule requires a directory for storing log files." work_dir = log_file - lib: tvm.runtime.Module = tune_relay( - mod=mod, - params=params, - target=target, - config=TuneConfig( - strategy="evolutionary", - num_trials_per_iter=64, - max_trials_per_task=trials, - max_trials_global=trials, - ), + + lib = optimize_torch( + jit_mod, + example_inputs, + max_trials_global=20000, work_dir=work_dir, - database=JSONDatabase( - osp.join(work_dir, "workload.json"), - osp.join(work_dir, "records.json"), - ), + target=target, + max_trials_per_task=64, ) + elif tuning_option is None: # no autotuning (for debugging) with tvm.transform.PassContext(opt_level=10): @@ -716,34 +708,41 @@ def tvm_compile_inner( "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. " "There are three available options including None, auto_scheduler and meta_schedule." ) + if tune_option != "meta_schedule": + m = graph_executor.GraphModule(lib["default"](dev)) - m = graph_executor.GraphModule(lib["default"](dev)) + def to_torch_tensor(nd_tensor): + """A helper function to transfer a NDArray to torch.tensor.""" + if nd_tensor.dtype == "bool": + # DLPack does not support boolean so it can't be handled by + # torch.utils.dlpack.from_pack. Workaround by going through + # numpy, although this brings additional data copy overhead. + return torch.from_numpy(nd_tensor.numpy()) + return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack()) - def to_torch_tensor(nd_tensor): - """A helper function to transfer a NDArray to torch.tensor.""" - if nd_tensor.dtype == "bool": - # DLPack does not support boolean so it can't be handled by - # torch.utils.dlpack.from_pack. Workaround by going through - # numpy, although this brings additional data copy overhead. 
- return torch.from_numpy(nd_tensor.numpy()) - return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack()) + def exec_tvm(*args): + args = [a.contiguous() for a in args] + for idx, arg in enumerate(args, 0): + if arg.dim() != 0: + if arg.requires_grad: + arg = arg.detach() + m.set_input( + f"inp_{idx}", + tvm.nd.array(arg.numpy(), dev), + ) + m.run() + return [ + to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs()) + ] - def exec_tvm(*args): - args = [a.contiguous() for a in args] - for idx, arg in enumerate(args, 0): - if arg.dim() != 0: - if arg.requires_grad: - arg = arg.detach() - m.set_input( - f"inp_{idx}", - tvm.nd.array(arg.numpy(), dev), - ) - m.run() - return [ - to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs()) - ] + else: + + def exec_tvm(*args): + args = [a.contiguous() for a in args] + return lib(*args) return exec_tvm + except Exception: log.exception("tvm error") return jit_mod # explicit fall back to eager diff --git a/torch/_dynamo/optimizations/distributed.py b/torch/_dynamo/optimizations/distributed.py index 5948f9f03b79..f65c16483aec 100644 --- a/torch/_dynamo/optimizations/distributed.py +++ b/torch/_dynamo/optimizations/distributed.py @@ -43,14 +43,14 @@ class DDPOptimizer: bucket_actual_sizes = [] node_splits = [[]] for node in reversed(gm.graph.nodes): + if node.op == "output" or node.op == "placeholder": + continue + if bucket_bytes >= self.bucket_bytes_cap: bucket_actual_sizes.insert(0, bucket_bytes) bucket_bytes = 0 node_splits.insert(0, []) - if node.op == "output" or node.op == "placeholder": - continue - elif node.op == "call_module": target = gm.get_submodule(node.target) params_size_b = sum( @@ -62,6 +62,10 @@ class DDPOptimizer: ) bucket_bytes += params_size_b # print(f"accumulated {params_size_b} b from {node}") + elif node.op == "get_attr": + maybe_param = getattr(gm, node.target) + if maybe_param.requires_grad: + bucket_bytes += maybe_param.storage().nbytes() else: # TODO(whc) confirm this: # (e.g. call_method, call_function aren't expected to 'have' parameters) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 7a739b741465..b730ddea1d88 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -229,7 +229,12 @@ class OutputGraph(fx.Tracer): return wrap_name(k) # create a new unique name - name = re.sub(r"[^a-zA-Z0-9]", "_", "_".join(map(str, names))) + name = "_".join(map(str, names)) + # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv + name = re.sub(r"\[(\d+)\]", r"_\g<1>", name) + # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv + name = re.sub(r"[^a-zA-Z0-9]", "_", name) + if not name or not name[0].isalpha(): name = "sub" + name base = name diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py new file mode 100644 index 000000000000..089e5053d062 --- /dev/null +++ b/torch/_dynamo/test_case.py @@ -0,0 +1,71 @@ +import contextlib +import importlib +import sys +from unittest.mock import patch + +import torch +import torch.testing +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + TEST_WITH_CROSSREF, + TEST_WITH_ROCM, + TEST_WITH_TORCHDYNAMO, + TestCase as TorchTestCase, +) + +from . 
import config, reset, utils + + +def run_tests(needs=()): + from torch.testing._internal.common_utils import run_tests + + if ( + TEST_WITH_TORCHDYNAMO + or IS_WINDOWS + or TEST_WITH_CROSSREF + or TEST_WITH_ROCM + or sys.version_info >= (3, 11) + ): + return # skip testing + + if isinstance(needs, str): + needs = (needs,) + for need in needs: + if need == "cuda" and not torch.cuda.is_available(): + return + else: + try: + importlib.import_module(need) + except ImportError: + return + run_tests() + + +class TestCase(TorchTestCase): + @classmethod + def tearDownClass(cls): + cls._exit_stack.close() + super().tearDownClass() + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack = contextlib.ExitStack() + cls._exit_stack.enter_context( + patch.object(config, "raise_on_backend_error", True) + ) + cls._exit_stack.enter_context( + patch.object(config, "raise_on_ctx_manager_usage", True) + ) + + def setUp(self): + super().setUp() + reset() + utils.counters.clear() + + def tearDown(self): + for k, v in utils.counters.items(): + print(k, v.most_common()) + reset() + utils.counters.clear() + super().tearDown() diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 790de24e20e5..af3d28f46aba 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -1,19 +1,16 @@ import contextlib import dis import functools -import importlib import logging import os.path -import sys import types import unittest from unittest.mock import patch import torch -import torch.testing._internal.common_utils from torch import fx -from . import config, eval_frame, optimize_assert, reset, utils +from . import config, eval_frame, optimize_assert, reset from .bytecode_transformation import ( create_instruction, debug_checks, @@ -29,37 +26,6 @@ three = 3 log = logging.getLogger(__name__) -def run_tests(needs=()): - return # TEMPORARY: disable all tests - - from torch.testing._internal.common_utils import ( - IS_WINDOWS, - run_tests, - TEST_WITH_CROSSREF, - TEST_WITH_TORCHDYNAMO, - ) - - if ( - TEST_WITH_TORCHDYNAMO - or IS_WINDOWS - or TEST_WITH_CROSSREF - or sys.version_info >= (3, 11) - ): - return # skip testing - - if isinstance(needs, str): - needs = (needs,) - for need in needs: - if need == "cuda" and not torch.cuda.is_available(): - return - else: - try: - importlib.import_module(need) - except ImportError: - return - run_tests() - - def clone_me(x): if x is None: return None @@ -229,36 +195,6 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None) self.assertEqual(actual.op_count, expected_ops) -class TestCase(torch.testing._internal.common_utils.TestCase): - @classmethod - def tearDownClass(cls): - cls._exit_stack.close() - super().tearDownClass() - - @classmethod - def setUpClass(cls): - super().setUpClass() - cls._exit_stack = contextlib.ExitStack() - cls._exit_stack.enter_context( - patch.object(config, "raise_on_backend_error", True) - ) - cls._exit_stack.enter_context( - patch.object(config, "raise_on_ctx_manager_usage", True) - ) - - def setUp(self): - super().setUp() - reset() - utils.counters.clear() - - def tearDown(self): - for k, v in utils.counters.items(): - print(k, v.most_common()) - reset() - utils.counters.clear() - super().tearDown() - - def dummy_fx_compile(gm: fx.GraphModule, example_inputs): return gm.forward diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index bc05980e2657..fbd3648a8d59 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -42,6 
+42,10 @@ from .constant import ConstantVariable from .lists import ShapeVariable, SizeVariable +class _missing: + pass + + class TensorVariable(VariableTracker): """A torch.Tensor input or an intermediate value in the FX graph""" @@ -189,8 +193,9 @@ class TensorVariable(VariableTracker): elif istype(example_value, int) and proxy.node.target in ( torch.seed, operator.mod, - torch.distributed.get_rank, - torch.distributed.get_world_size, + # some mac builds are missing torch.distributed.get_rank() + getattr(torch.distributed, "get_rank", _missing), + getattr(torch.distributed, "get_world_size", _missing), ): proxy.node.meta["example_value"] = example_value return DynamicShapeVariable(proxy, type(example_value), **options) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 062ca366a289..74f07f7ea578 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -654,6 +654,7 @@ class TritonKernel(Kernel): def indexing( self, index: sympy.Expr, + *, copy_shape=None, dense_indexing=False, ): @@ -686,9 +687,11 @@ class TritonKernel(Kernel): mask.append(f"{tree.prefix}mask") dense_mask.append(f"{tree.prefix}mask") - if (need_dense and not have_dense) or index == 0: + if (need_dense and not have_dense) or isinstance( + index, sympy.core.numbers.Integer + ): index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)" - if index == 0: + if isinstance(index, sympy.core.numbers.Integer): return index_str, "None" else: mask = dense_mask @@ -779,7 +782,7 @@ class TritonKernel(Kernel): def store(self, name, index, value, mode=None): var = self.args.output(name) - index, mask = self.indexing(index, value, dense_indexing=True) + index, mask = self.indexing(index, dense_indexing=True) if mode is None: line = f"tl.store({var} + ({index}), {value}, {mask})" elif mode == "atomic_add": @@ -861,7 +864,7 @@ class TritonKernel(Kernel): var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)] self.suffix.writeline(f"{result_var} = {var_name}") self.inside_reduction = False - index, mask = self.indexing(index, result_var) + index, mask = self.indexing(index) assert "rmask" not in index self.inside_reduction = True self.outside_loop_vars.add(result_var) diff --git a/torch/_inductor/codegen/triton_template.py b/torch/_inductor/codegen/triton_template.py index 308b1c1f45d9..4d86feeccec8 100644 --- a/torch/_inductor/codegen/triton_template.py +++ b/torch/_inductor/codegen/triton_template.py @@ -77,7 +77,9 @@ class TritonTemplateKernel(TritonKernel): def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=True): # use dense_indexing for TritonTemplateKernel to avoid map::at error - return super().indexing(index, copy_shape, dense_indexing) + return super().indexing( + index, copy_shape=copy_shape, dense_indexing=dense_indexing + ) def codegen_body( self, name, fuse, could_remove_kernel_buf, kernel_buf_replace_name diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 996ed9c64bb1..3cc67841c7ab 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -217,15 +217,20 @@ class WrapperCodeGen(CodeGen): ) self.prefix.splice( - f""" + """ async_compile.wait(globals()) del async_compile - def call({', '.join(V.graph.graph_inputs.keys())}): + def call(args): """ ) with self.prefix.indent(): + inp_len = len(V.graph.graph_inputs.keys()) + if inp_len != 0: + lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}" + 
self.prefix.writeline(f"{lhs} = args") + self.prefix.writeline("args.clear()") for name in V.graph.randomness_seeds: self.prefix.writeline( f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})" @@ -275,6 +280,12 @@ class WrapperCodeGen(CodeGen): def codegen_free(self, buffer): name = buffer.get_name() + + # can be freed but not reused + if isinstance(buffer, ir.InputBuffer): + self.writeline(f"del {name}") + return + if not self.can_reuse(buffer): return self.freed.add(name) @@ -380,7 +391,7 @@ class WrapperCodeGen(CodeGen): ) output.writeline( - f"print_performance(lambda: call({', '.join(V.graph.graph_inputs.keys())}))" + f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))" ) def define_kernel(self, name: str, kernel: str): diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 0f7fcbbf96ac..a4929637f64d 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -5,7 +5,8 @@ import logging from typing import List import functorch -from functorch.compile import make_boxed_compiler, min_cut_rematerialization_partition +from functorch._src.aot_autograd import make_boxed_func +from functorch.compile import min_cut_rematerialization_partition import torch.fx from torch._subclasses.fake_tensor import FakeTensor @@ -86,7 +87,7 @@ def compile_fx_inner( graph_id=None, ): if dynamo_utils.count_calls(gm.graph) == 0: - return gm + return make_boxed_func(gm.forward) _step_logger()( logging.INFO, @@ -137,6 +138,9 @@ def compile_fx_inner( f"{'BACKWARDS' if is_backward else 'FORWARDS'} " f"graph {graph_id}", ) + + # aot autograd needs to know to pass in inputs as a list + result._boxed_call = True return result @@ -159,14 +163,15 @@ def align_inputs(model, inputs, static_input_idxs=()): if len(check_inputs) == 0: return model - def run(*new_inputs): + def run(new_inputs): for i in check_inputs: if new_inputs[i].data_ptr() % ALIGNMENT: - if isinstance(new_inputs, tuple): - new_inputs = list(new_inputs) new_inputs[i] = clone_preserve_strides(new_inputs[i]) - new_inputs = [x.to("cuda") if is_unspec_input(x) else x for x in new_inputs] - return model(*new_inputs) + new_inputs_to_cuda = [ + x.to("cuda") if is_unspec_input(x) else x for x in new_inputs + ] + new_inputs.clear() + return model(new_inputs_to_cuda) return run @@ -179,13 +184,13 @@ def cudagraphify(model, inputs, static_input_idxs=()): compiled_fn = None - def run(*new_inputs): + def run(new_inputs): nonlocal compiled_fn if compiled_fn is None: with dynamo_utils.preserve_rng_state(): compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs) - return compiled_fn(*new_inputs) + return compiled_fn(new_inputs) return run @@ -239,8 +244,9 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()): torch.cuda.synchronize() stream = torch.cuda.Stream() stream.wait_stream(torch.cuda.current_stream()) + # copy static_inputs because it will be cleared in model with torch.cuda.stream(stream): - model(*static_inputs) + model(list(static_inputs)) stream.synchronize() torch.cuda.current_stream().wait_stream(stream) torch.cuda.synchronize() @@ -248,13 +254,13 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()): # record graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - static_outputs = model(*static_inputs) + static_outputs = model(list(static_inputs)) if not isinstance(static_outputs, (list, tuple)): static_outputs = (static_outputs,) if config.size_asserts: - def run(*new_inputs): + def run(new_inputs): assert 
len(static_inputs) == len(new_inputs) for idx, (dst, src, expanded_dims) in enumerate( zip(static_inputs, new_inputs, inps_expanded_dims) @@ -268,6 +274,7 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()): dst = index_expanded_dims(dst, expanded_dims) src = index_expanded_dims(src, expanded_dims) dst.copy_(src) + new_inputs.clear() graph.replay() return static_outputs @@ -276,11 +283,12 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()): idx for idx in range(len(static_inputs)) if idx not in static_input_idxs ] - def run(*new_inputs): + def run(new_inputs): for idx in copy_indices: src = index_expanded_dims(static_inputs[idx], inps_expanded_dims[idx]) dst = index_expanded_dims(new_inputs[idx], inps_expanded_dims[idx]) dst.copy_(src) + new_inputs.clear() graph.replay() return static_outputs @@ -359,8 +367,8 @@ def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor] return aot_autograd( model_, example_inputs_, - fw_compiler=make_boxed_compiler(fw_compiler), - bw_compiler=make_boxed_compiler(bw_compiler), + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, decompositions=select_decomp_table(), partition_fn=functools.partial( min_cut_rematerialization_partition, compiler="inductor" diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index ede2aca75bef..8007725403e9 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -3,8 +3,6 @@ import logging import math import numbers -from functorch._src.aot_autograd import aot_autograd_decompositions - import torch import torch._decomp as decomp from torch import Tensor @@ -98,9 +96,10 @@ decompositions = get_decompositions( aten.tril.default, aten.upsample_bilinear2d.vec, aten.upsample_nearest2d_backward, + aten.softplus, + aten.softplus_backward, ] ) -decompositions.update(aot_autograd_decompositions) def register_decomposition(ops): diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 90657b7db1d8..f9982b9e813f 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -154,6 +154,11 @@ def _register_lowering( @functools.wraps(decomp_fn) def wrapped(*args, **kwargs): args = list(args) + unpacked = False + # TODO maybe we need to use pytrees here + if len(args) == 1 and isinstance(args[0], (list, tuple)): + unpacked = True + args = args[0] # Only look at args that are Tensors indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] # kwargs tensors not supported yet @@ -170,14 +175,20 @@ def _register_lowering( dtype = get_promoted_dtype( *promoting_args, type_promotion_kind=type_promotion_kind ) - for i in indices: - args[i] = to_dtype(args[i], dtype) + # sometimes args are an immutable list so we can't mutate them + new_args = [] for i in range(len(args)): - if isinstance(args[i], ir.Constant): - args[i] = ir.Constant( - args[i].value, dtype, args[indices[0]].get_device() + if i in indices: + new_args.append(to_dtype(args[i], dtype)) + elif isinstance(args[i], ir.Constant): + new_args.append( + ir.Constant(args[i].value, dtype, args[indices[0]].get_device()) ) - + else: + new_args.append(args[i]) + args = new_args + if unpacked: + args = [args] if broadcast and indices: for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): args[i] = x @@ -475,12 +486,13 @@ def squeeze(x, dim=None): assert isinstance(x, TensorBox) if dim is None: return TensorBox(SqueezeView.create(x.data)) - - dim = _validate_dim(x, dim, 0) + offset = len(x.get_size()) == 0 + dim = 
_validate_dim(x, dim, offset) new_shape = list(x.get_size()) - removed = new_shape.pop(dim) - if V.graph.sizevars.maybe_guard_equals(removed, 1): - return view(x, new_shape) + if len(new_shape) > 0: + removed = new_shape.pop(dim) + if V.graph.sizevars.maybe_guard_equals(removed, 1): + return view(x, new_shape) # squeeze does nothing if the size isn't 1 return x diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 88181fb0ce7f..a6349131fe12 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -988,6 +988,11 @@ class Scheduler: node = self.name_to_node[name] if node.can_free(): V.graph.wrapper_code.codegen_free(node.node) + elif name in V.graph.graph_inputs: + storage = V.graph.graph_inputs[name].data + assert storage.is_input_buffer() + V.graph.wrapper_code.codegen_free(storage.data) + self.buffer_names_to_free.clear() def remove_kernel_local_buffers(self): diff --git a/torch/_inductor/triton_ops/autotune.py b/torch/_inductor/triton_ops/autotune.py index f6d05cf2f8cf..29e1013eb55f 100644 --- a/torch/_inductor/triton_ops/autotune.py +++ b/torch/_inductor/triton_ops/autotune.py @@ -5,6 +5,7 @@ import json import logging import multiprocessing import os.path +import re import threading from typing import List @@ -110,11 +111,12 @@ class CachingAutotuner(KernelInterface): # set_device(current_device()) # TODO(jansel): is this needed? grid_0, grid_1, grid_2 = grid(grid_meta) bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, - stream, bin.cu_function, None, None, None, - {', '.join(call_args)}) + stream, bin.cu_function, None, None, None, + {', '.join(call_args)}) """.lstrip(), scope, ) + launcher = scope["launcher"] launcher.config = cfg return launcher @@ -160,11 +162,22 @@ class CachingAutotuner(KernelInterface): launcher.config.pre_hook( {**zip(self.arg_names, args), **launcher.config.kwargs} ) - return launcher( - *args, - grid=grid, - stream=stream, - ) + try: + result = launcher( + *args, + grid=grid, + stream=stream, + ) + except TypeError as e: + if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)): + raise RuntimeError( + """Consider updating Triton with +`pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`""" + ) + else: + raise e + + return result def hash_configs(configs: List[Config]):