Sync changes from pytorch/torchdynamo, enable tests (#86950)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86950
Approved by: https://github.com/Chillee
This commit is contained in:
Jason Ansel
2022-10-14 11:05:28 -07:00
committed by PyTorch MergeBot
parent 78ef40973c
commit 8f71e8de7e
45 changed files with 577 additions and 247 deletions

View File

@ -198,12 +198,22 @@ class OperatorInputsMode(TorchDispatchMode):
def map_to_device(e, device):
return e.to(device) if isinstance(e, torch.Tensor) else e
if isinstance(e, torch.Tensor):
return e.to(device)
elif isinstance(e, torch.device):
return device
elif isinstance(e, str):
if e == "cuda" or e == "cpu":
return device.type
else:
return e
def map_to_dtype(e, dtype):
if isinstance(e, torch.Tensor) and e.is_floating_point():
return e.to(dtype)
elif isinstance(e, torch.dtype):
return dtype
else:
return e
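A rough sketch (not part of this change) of how the two helpers above could be applied to a nested input structure, assuming map_to_device/map_to_dtype as defined here and a pytree-style traversal:
import torch
from torch.utils._pytree import tree_map
args = (torch.randn(4, 4), "cuda")          # hypothetical sampled inputs
device = torch.device("cpu")
dtype = torch.bfloat16
args = tree_map(lambda e: map_to_device(e, device), args)  # tensor moved, "cuda" -> "cpu"
args = tree_map(lambda e: map_to_dtype(e, dtype), args)    # floating tensors cast to bfloat16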

View File

@ -2,7 +2,6 @@
import click
import numpy as np
import torch
import triton
from operator_inp_utils import OperatorInputsLoader
from torch._dynamo.optimizations.backends import cudagraphs_inner
@ -17,7 +16,7 @@ aten = torch.ops.aten
def compute_speedups(
operator, models, example_inputs, repeats, accuracy_checking=False
operator, models, example_inputs, repeats, accuracy_checking=False, device="cuda"
):
expected = models[0](*example_inputs)
if accuracy_checking:
@ -35,10 +34,19 @@ def compute_speedups(
for rep in range(repeats):
# interleave the runs to handle frequency scaling and load changes
for m, model in enumerate(models):
# do_bench() clears L2 cache to hide the latency of CPU launch time
# along with cuda synchronization
median_ms, _, _ = triton.testing.do_bench(lambda: model(*example_inputs))
timings[rep, m] = median_ms
if device == "cuda":
import triton
# do_bench() clears L2 cache to hide the latency of CPU launch time
# along with cuda synchronization
median_ms, _, _ = triton.testing.do_bench(
lambda: model(*example_inputs)
)
timings[rep, m] = median_ms
else:
from torch._inductor.utils import timed
timings[rep, m] = timed(model, example_inputs)
return np.median(timings, axis=0)
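For the CPU branch, a minimal stand-in for the imported helper (illustrative only; the real torch._inductor.utils.timed may differ in signature and units):
import time
def timed_cpu(model, example_inputs):
    # wall-clock seconds for a single eager call
    t0 = time.perf_counter()
    model(*example_inputs)
    return time.perf_counter() - t0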
@ -64,15 +72,19 @@ def convert_to_jit(gm, gm_args):
def microbenchmark(
operator, args, kwargs, dtype, accuracy_checking, repeats, measure_nvfuser
operator, args, kwargs, dtype, accuracy_checking, repeats, measure_nvfuser, device
):
gm, gm_args = gen_gm_and_inputs(operator, args, kwargs)
torch.jit._builtins._register_builtin(
torch.ops.aten.convolution_backward.default, "aten::convolution_backward"
)
cudagraphs_eager = cudagraphs_inner(gm, gm_args, copy_outputs=False)
compiled_fn = compile_fx(gm, gm_args)
compiled = [cudagraphs_eager, compiled_fn]
if device == "cuda":
cudagraphs_eager = cudagraphs_inner(gm, gm_args, copy_outputs=False)
compiled_fn = compile_fx(gm, gm_args)
compiled = [cudagraphs_eager, compiled_fn]
else:
compiled_fn = compile_fx(gm, gm_args)
compiled = [gm, compiled_fn]
if measure_nvfuser:
g = convert_to_jit(gm, gm_args)
cudagraphs_jit = cudagraphs_inner(g, gm_args, copy_outputs=False)
@ -80,7 +92,9 @@ def microbenchmark(
if accuracy_checking:
repeats = 1
medians = compute_speedups(operator, compiled, gm_args, repeats, accuracy_checking)
medians = compute_speedups(
operator, compiled, gm_args, repeats, accuracy_checking, device
)
return medians
@ -147,8 +161,9 @@ def skip_operator(operator):
@click.option(
"--measure-nvfuser", help="default we only measure inductor", default=False
)
@click.option("--device", help="cpu or cuda", default="cuda")
def benchmark(
suite, op, dtype, max_samples, accuracy_checking, repeats, measure_nvfuser
suite, op, dtype, max_samples, accuracy_checking, repeats, measure_nvfuser, device
):
assert suite in ("timm", "huggingface", "torchbench"), f"got {suite}"
if suite == "timm":
@ -176,7 +191,7 @@ def benchmark(
continue
print(f"Running {operator}")
inp_gen = loader.get_inputs_for_operator(operator, dtype=dtype)
inp_gen = loader.get_inputs_for_operator(operator, dtype=dtype, device=device)
timings = []
for i in range(min(max_samples, 1000000)):
@ -199,6 +214,7 @@ def benchmark(
accuracy_checking,
repeats,
measure_nvfuser,
device,
)
)
except Exception as e:

View File

@ -237,6 +237,8 @@ class TorchBenchmarkRunner(BenchmarkRunner):
if batch_size is None and is_training and model_name in USE_SMALL_BATCH_SIZE:
batch_size = USE_SMALL_BATCH_SIZE[model_name]
# workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
torch.backends.__allow_nonbracketed_mutation_flag = True
if is_training:
benchmark = benchmark_cls(
test="train", device=device, jit=False, batch_size=batch_size

View File

@ -4,6 +4,7 @@ import functools
import torch
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.optimizations.training import is_aot_autograd_safe_to_run
from torch._dynamo.testing import rand_strided
@ -13,7 +14,7 @@ def compiler_safe_fn(gm, example_inputs, is_safe):
return gm.forward
class AotAutogradFallbackTests(torch._dynamo.testing.TestCase):
class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
def test_LSTM(self):
# https://github.com/pytorch/torchdynamo/issues/1147
class Repro(torch.nn.Module):
@ -60,6 +61,65 @@ class AotAutogradFallbackTests(torch._dynamo.testing.TestCase):
aot_fn(x, y)
self.assertTrue(not is_safe[0])
def test_mutation1(self):
def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
getitem = diagonal_chunked_attention_scores[
(
slice(None, None, None),
slice(None, None, None),
slice(None, 256, None),
slice(None, 257, None),
)
]
_stack0[
(
slice(None, None, None),
slice(None, -1, None),
slice(None, None, None),
slice(256, None, None),
)
] = getitem
view = _stack0.view(1, 12, 1024, 513)
return (view,)
x = torch.randn(torch.Size([12, 4, 256, 513]))
y = torch.randn(torch.Size([12, 3, 512, 513]))
is_safe = [True]
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
aot_fn = torch._dynamo.optimize(compiler_fn)(fn)
aot_fn(x, y)
self.assertTrue(not is_safe[0])
def test_negative_testing_mutation(self):
def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
getitem = diagonal_chunked_attention_scores[
(
slice(None, None, None),
slice(None, None, None),
slice(None, 256, None),
slice(None, 257, None),
)
]
_stack0 = torch.sin(_stack0)
_stack0[
(
slice(None, None, None),
slice(None, -1, None),
slice(None, None, None),
slice(256, None, None),
)
] = getitem
view = _stack0.view(1, 12, 1024, 513)
return (view,)
x = torch.randn(torch.Size([12, 4, 256, 513]))
y = torch.randn(torch.Size([12, 3, 512, 513]))
is_safe = [True]
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
aot_fn = torch._dynamo.optimize(compiler_fn)(fn)
aot_fn(x, y)
self.assertTrue(is_safe[0])
def test_negative_testing(self):
def fn(x, y):
return torch.sin(x).add_(y)
@ -74,6 +134,6 @@ class AotAutogradFallbackTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -7,8 +7,10 @@ from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
from torch.testing._internal.common_utils import TEST_WITH_ROCM
def composed(*decs):
@ -43,6 +45,7 @@ def assert_aot_autograd_counter(ok=True):
def patch_all(ok=True):
return composed(
unittest.skipIf(TEST_WITH_ROCM, "ROCm not supported"),
patch("torch._dynamo.config.verify_correctness", True),
assert_aot_autograd_counter(ok),
)
@ -52,7 +55,7 @@ N_ITERS = 5
@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda")
class TestAotCudagraphs(torch._dynamo.testing.TestCase):
class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
@patch_all()
def test_basic(self):
def model(x, y):
@ -201,6 +204,6 @@ class TestAotCudagraphs(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -7,6 +7,7 @@ import pytest
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch.distributed as dist
from torch import nn
from torch._dynamo import config
@ -43,7 +44,7 @@ def skip_if_no_active_ddp():
@pytest.mark.skip("Module hangs in PyTorch CI")
class TestDistributed(torch._dynamo.testing.TestCase):
class TestDistributed(torch._dynamo.test_case.TestCase):
"""
Test harness initializes dist process group
"""
@ -209,6 +210,63 @@ class TestDistributed(torch._dynamo.testing.TestCase):
opt_outputs.sum().backward()
self.assertTrue(same(correct_outputs, opt_outputs))
@patch.object(config, "optimize_ddp", True)
def test_custom_layer(self):
"""
Ensures that the appropriate number of splits happen (based on
bucket size and model parameters) by verifying the number of times
the user-provided compiler is called by the DDPOptimizer, which
performs the graph splitting
"""
from torch.nn.parallel import DistributedDataParallel as DDP
skip_if_no_active_ddp()
class MyCustomLinear(torch.nn.Module):
def __init__(self):
super(MyCustomLinear, self).__init__()
self.weight = nn.Parameter(torch.randn(512, 512))
def forward(self, x):
return torch.mm(x, self.weight.t())
class MyLinear(torch.nn.Module):
def __init__(self):
super(MyLinear, self).__init__()
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
return self.linear(x)
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
mods = [
(MyLinear(), torch.nn.ReLU()),
# sandwich the custom layer in the middle so it comes both before and after
(MyCustomLinear(), torch.nn.ReLU()),
(MyLinear(), torch.nn.ReLU()),
]
self.seq = torch.nn.Sequential(*[x for items in mods for x in items])
def forward(self, x):
return self.seq(x)
m = MyModule().to(self.device)
inputs = torch.randn((512, 512)).to(self.device)
correct_outputs = m(inputs)
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)
check_splits_compiler = CheckSplitsCompiler()
@torch._dynamo.optimize(check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 3)
def test_empty_graph(self):
def fn():
get_world_size = torch.distributed.distributed_c10d.get_world_size()

View File

@ -25,6 +25,6 @@ DynamicShapesNNModuleTests = make_dynamic_cls(test_modules.NNModuleTests)
DynamicShapesUnspecTests = make_dynamic_cls(test_unspec.UnspecTests)
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -3,12 +3,13 @@ from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import make_fx
class ExportTests(torch._dynamo.testing.TestCase):
class ExportTests(torch._dynamo.test_case.TestCase):
# TODO(voz): Refactor to a shared test function.
# The tests in this file are a little redundant,
# They all take a func, run it eagerly, then export it, and compare the results
@ -1423,6 +1424,6 @@ class ExportTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -9,6 +9,7 @@ from typing import Any
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch import sub
from torch._dynamo.testing import requires_static_shapes
@ -53,7 +54,7 @@ def inline_unused(x):
return x + 5.6
class FunctionTests(torch._dynamo.testing.TestCase):
class FunctionTests(torch._dynamo.test_case.TestCase):
@make_test
def test_inline_jit_annotations(x):
x = inline_script_if_tracing(x)
@ -670,6 +671,6 @@ class FunctionTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
@ -43,7 +44,7 @@ def reset_name():
_name = 0
class TestGlobals(torch._dynamo.testing.TestCase):
class TestGlobals(torch._dynamo.test_case.TestCase):
def test_store_global_1(self):
def fn(x):
global g_counter
@ -227,6 +228,6 @@ class TestGlobals(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -6,6 +6,7 @@ from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.optimizations.backends import create_backend
@ -23,7 +24,7 @@ class MockModule(torch.nn.Module):
return x
class MinfierTests(torch._dynamo.testing.TestCase):
class MinfierTests(torch._dynamo.test_case.TestCase):
def test_after_dynamo(self):
@create_backend
def bad_dynamo_backend(subgraph):
@ -92,6 +93,6 @@ class MinfierTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -16,6 +16,7 @@ from unittest.mock import patch
import numpy as np
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.onnx.operators
from torch._dynamo import bytecode_transformation
@ -34,7 +35,7 @@ def my_custom_function(x):
return x + 1
class MiscTests(torch._dynamo.testing.TestCase):
class MiscTests(torch._dynamo.test_case.TestCase):
def test_boolarg(self):
def boolarg(aa, bb, flag):
if flag:
@ -2719,6 +2720,6 @@ class TestTracer(JitTestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -4,6 +4,7 @@ import unittest.mock
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
@ -22,7 +23,7 @@ def maybe_skip(fn):
return fn
class TestHFPretrained(torch._dynamo.testing.TestCase):
class TestHFPretrained(torch._dynamo.test_case.TestCase):
@maybe_skip
def test_pretrained(self):
def fn(a, tmp):
@ -38,7 +39,7 @@ class TestHFPretrained(torch._dynamo.testing.TestCase):
self.assertTrue(same(ref, res))
class TestModelOutput(torch._dynamo.testing.TestCase):
class TestModelOutput(torch._dynamo.test_case.TestCase):
@maybe_skip
def test_mo_create(self):
def fn(a, b):
@ -160,6 +161,6 @@ class TestModelOutput(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -5,6 +5,7 @@ from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.eval_frame import unsupported
from torch._dynamo.mutation_guard import GenerationTracker
@ -604,7 +605,7 @@ def make_test(fn, expected_ops=None):
return test_fn
class NNModuleTests(torch._dynamo.testing.TestCase):
class NNModuleTests(torch._dynamo.test_case.TestCase):
test_seq = make_test(Seq())
test_basicmodule1 = make_test(BasicModule())
test_basicmodule2 = make_test(BasicModule())
@ -884,6 +885,6 @@ class NNModuleTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -24,6 +24,6 @@ NoFakeTensorsNNModuleTests = make_no_fake_cls(test_modules.NNModuleTests)
NoFakeTensorsUnspecTests = make_no_fake_cls(test_unspec.UnspecTests)
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import eval_frame
@ -35,7 +36,7 @@ with_debug_nops = eval_frame._optimize_catch_errors(
)
class NopTests(torch._dynamo.testing.TestCase):
class NopTests(torch._dynamo.test_case.TestCase):
@with_debug_nops
def test1(self):
self.assertEqual(fn1(1, 2), -7)
@ -66,6 +67,6 @@ class NopTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -8,6 +8,7 @@ from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.optimizations import backends
from torch._dynamo.optimizations.analysis import has_mutation
from torch._dynamo.optimizations.log_args import conv_args_analysis
@ -64,7 +65,7 @@ class Conv_Bn_Relu(torch.nn.Module):
return self.relu(self.bn(self.conv(x)))
class TestOptimizations(torch._dynamo.testing.TestCase):
class TestOptimizations(torch._dynamo.test_case.TestCase):
def test_inplacifier(self):
gm = torch.fx.symbolic_trace(Seq())
normalize(gm)
@ -183,7 +184,7 @@ class TestOptimizations(torch._dynamo.testing.TestCase):
self.assertEqual(r2.dtype, torch.bfloat16)
class NormalizeIRTests(torch._dynamo.testing.TestCase):
class NormalizeIRTests(torch._dynamo.test_case.TestCase):
@unittest.skipIf(not has_functorch(), "requires functorch")
def test_inplace_normalize(self):
def fn(a, b):
@ -202,6 +203,6 @@ class NormalizeIRTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -6,6 +6,7 @@ import unittest
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
input = torch.ones([10, 10])
@ -37,7 +38,7 @@ def make_test(optim_cls, exp_frame_cnt=1, closure=None, **kwargs):
return test_fn
class OptimizerTests(torch._dynamo.testing.TestCase):
class OptimizerTests(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
@ -97,6 +98,6 @@ for opt in optimizers:
setattr(OptimizerTests, "test_" + opt.__name__.lower(), make_test(opt))
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -4,7 +4,8 @@ from typing import Callable, Dict, List, NamedTuple, Optional
import torch
import torch._dynamo
from torch._dynamo.testing import CompileCounter, same, TestCase
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
"""
This is an example of a pure-python version of autograd implemented by
@ -283,6 +284,4 @@ class TestPythonAutograd(TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
run_tests()

View File

@ -6,10 +6,11 @@ import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
class RecompileUxTests(torch._dynamo.testing.TestCase):
class RecompileUxTests(torch._dynamo.test_case.TestCase):
# TODO(whc) dynamo actually recompiles one more time than the cache limit
cache_limit = 1

View File

@ -6,6 +6,7 @@ import unittest
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
try:
@ -16,7 +17,7 @@ except ImportError:
requires_dill = unittest.skipIf(dill is None, "requires dill")
class ReplayRecordTests(torch._dynamo.testing.TestCase):
class ReplayRecordTests(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
@ -181,6 +182,6 @@ class ReplayRecordTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -14,6 +14,7 @@ from unittest.mock import patch
import numpy as np
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch import nn
@ -749,7 +750,7 @@ class TestModule(torch.nn.Module):
return self.inner_fn(tensor.shape, (1, 2, 3))
class ReproTests(torch._dynamo.testing.TestCase):
class ReproTests(torch._dynamo.test_case.TestCase):
def test_do_paste_mask(self):
torch._dynamo.utils.counters.clear()
opt__do_paste_mask = torch._dynamo.optimize(
@ -1712,6 +1713,6 @@ class ReproTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -4,10 +4,11 @@ from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.testing import CompileCounter
class SkipNonTensorTests(torch._dynamo.testing.TestCase):
class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
def test_add_tensor1(self):
def fn(a, b):
return a + b
@ -107,6 +108,6 @@ class SkipNonTensorTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import unsupported
@ -17,7 +18,7 @@ def indirectly_unsupported(a, b):
return unsupported(a, c)
class SubGraphTests(torch._dynamo.testing.TestCase):
class SubGraphTests(torch._dynamo.test_case.TestCase):
def _common(self, fn, frame_count, op_count):
torch._dynamo.reset()
v1 = torch.ones(10)
@ -528,6 +529,6 @@ class SubGraphTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -7,6 +7,7 @@ from unittest.mock import patch
import numpy as np
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
@ -51,7 +52,7 @@ UnspecNNModuleTests = make_unspec_cls(test_modules.NNModuleTests)
@patch.object(torch._dynamo.config, "specialize_int_float", False)
class UnspecTests(torch._dynamo.testing.TestCase):
class UnspecTests(torch._dynamo.test_case.TestCase):
def test_numpy_correctness(self):
def fn(x, y, z):
xy = [x + y, y, False]
@ -221,6 +222,6 @@ class UnspecTests(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -8,6 +8,7 @@ import torch
import torch._dynamo
import torch._dynamo.config as config
import torch._dynamo.test_case
from torch._dynamo.optimizations import backends
from torch._dynamo.testing import same
@ -77,7 +78,7 @@ def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
return gm
class TestVerifyCorrectness(torch._dynamo.testing.TestCase):
class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
@patch.object(config, "verify_correctness", True)
def test_example_inputs(self):
def fn(a, bc, d):
@ -169,6 +170,6 @@ class TestVerifyCorrectness(torch._dynamo.testing.TestCase):
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -6,6 +6,7 @@ import importlib
import random
import sys
import unittest
import weakref
from unittest.mock import patch
import torch
@ -19,6 +20,7 @@ from torch.testing._internal.common_utils import (
TEST_WITH_ASAN,
TestCase as TorchTestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_unflatten
try:
@ -74,7 +76,6 @@ if torch.cuda.is_available():
pass
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
torch._inductor.config.triton.autotune = False # too slow
@ -242,7 +243,6 @@ def check_model(
# for graph in exp[2]:
# print("Graph", graph)
assert called, "Ran graph without calling compile_fx"
assert type(actual) == type(correct)
correct_flat, correct_spec = tree_flatten(correct)
@ -2004,7 +2004,11 @@ class CommonTemplate:
def test_cat(self):
def fn(a):
tmp = a * 2
return torch.cat((a, a[:, :4] + 1, a + 2), -1), torch.cat((tmp, tmp), 0)
return (
torch.cat((a, a[:, :4] + 1, a + 2), -1),
torch.cat((tmp, tmp), 0),
torch.cat((tmp, tmp.double()), 0),
)
self.common(
fn,
@ -3171,6 +3175,15 @@ class CommonTemplate:
],
)
# issue #1150
def test_dense_mask_index(self):
def fn(x, y):
y = torch.ops.aten.select.int(y, 0, 2)
z = x * y
return z.sum()
self.common(fn, [torch.randn(102400), torch.randn(3)])
def test_new_empty_strided(self):
def fn(a):
return aten.new_empty_strided(a, [1, 128, 128], [16384, 128, 1]).fill_(123)
@ -3714,10 +3727,10 @@ class CommonTemplate:
)
compiled = compile_fx_inner(traced, [torch.randn(8, 4, device=self.device)])
out = compiled(torch.randn(8, 4, device=self.device))
out = compiled([torch.randn(8, 4, device=self.device)])
self.assertEqual(out[0].shape, (16, 2))
out = compiled(torch.randn(12, 4, device=self.device))
out = compiled([torch.randn(12, 4, device=self.device)])
self.assertEqual(out[0].shape, (24, 2))
@requires_cuda()
@ -3744,6 +3757,63 @@ class CommonTemplate:
)
self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
@patch.object(config.triton, "mm", "aten")
def test_list_clearing(self):
if self.device == "cpu":
contexts = [contextlib.nullcontext]
else:
contexts = [
contextlib.nullcontext,
lambda: patch.object(config.triton, "cudagraphs", True),
]
for context in contexts:
with context():
inps = [
torch.rand([5, 5]).to(self.device),
torch.rand([5, 5]).to(self.device),
]
inp_refs = [weakref.ref(inp) for inp in inps]
def fn(x, y):
a = x + y
return (a @ a,)
fn_fx = make_fx(fn)(inps[0], inps[1])
fn_compiled = compile_fx_inner(fn_fx, inps)
test_self = self
matmul_seen = False
class TestRefMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}
nonlocal inps
nonlocal inp_refs
nonlocal test_self
nonlocal matmul_seen
# by the time the matmul runs, the inputs should already be deallocated
if func is aten.mm.out:
matmul_seen = True
test_self.assertEqual(len(inps), 0)
test_self.assertIsNone(inp_refs[0]())
test_self.assertIsNone(inp_refs[1]())
return func(*args, **kwargs)
with TestRefMode():
fn_compiled(inps)
# for some reason, TorchDispatch doesn't capture the
# cuda mm call (even without cudagraphs)
if self.device == "cpu":
self.assertTrue(matmul_seen)
else:
self.assertEqual(len(inps), 0)
if HAS_CPU:
@ -3781,7 +3851,7 @@ if HAS_CPU:
fn_fx = make_fx(fn)(x1, y)
fn_compiled = compile_fx_inner(fn_fx, [x1, y])
fn(x2, y)
fn_compiled(x3, y)
fn_compiled([x3, y])
assert same(x2, x3)
def test_no_op_squeeze(self):
@ -3872,7 +3942,7 @@ if HAS_CUDA:
]
mod = make_fx(forward)(*inps)
compiled = compile_fx_inner(mod, inps)
compiled(*inps)
compiled(inps)
@patch.object(config, "fallback_random", True)
def test_dtype_factory_issue(self):
@ -3888,7 +3958,7 @@ if HAS_CUDA:
mod = make_fx(forward)()
compiled = compile_fx_inner(mod, ())
assert compiled()[0].device.type == "cuda"
assert compiled([])[0].device.type == "cuda"
@patch.object(config.triton, "cudagraphs", True)
def test_expanded_inputs_cudagraphs(self):
@ -3952,6 +4022,7 @@ if HAS_CUDA:
if __name__ == "__main__":
from torch._dynamo.testing import run_tests
from torch._dynamo.test_case import run_tests
run_tests(needs="filelock")
if HAS_CPU or HAS_CUDA:
run_tests(needs="filelock")

View File

@ -33,7 +33,7 @@ try:
from .test_torchinductor import check_model, check_model_cuda
except ImportError:
from test_torchinductor import check_model, check_model_cuda
except (unittest.SkipTest, ImportError) as e:
except (unittest.SkipTest, ImportError, AssertionError) as e:
sys.stderr.write(f"{type(e)}: {e}\n")
if __name__ == "__main__":
sys.exit(0)
@ -154,6 +154,9 @@ inductor_skips["cuda"] = {
"jiterator_binary": {b8, f16, f32, f64, i32, i64},
"jiterator_binary_return_by_ref": {b8, f16, f32, f64, i32, i64},
"jiterator_unary": {b8, f16, f32, f64, i32, i64},
# Triton bug leads to segfault
"nn.functional.softplus": {f64},
"nn.functional.mish": {f64},
}
inductor_expected_failures_single_sample = defaultdict(dict)
@ -372,24 +375,13 @@ inductor_expected_failures_single_sample["cuda"] = {
inductor_gradient_expected_failures_single_sample = defaultdict(dict)
inductor_gradient_expected_failures_single_sample["cuda"] = {
"amax": {f16, f32, f64},
"amin": {f16, f32, f64},
"asin": {f16},
"cumprod": {f16},
"linalg.vector_norm": {f64, f64},
"linalg.householder_product": {f32},
"linalg.lu": {f32, f64},
"kron": {f16},
"masked.amax": {f16, f32, f64},
"masked.amin": {f16, f32, f64},
"max.reduction_no_dim": {f16, f32, f64},
"median": {f16, f32, f64},
"min.reduction_no_dim": {f16, f32, f64},
"nan_to_num": {f16, f32, f64},
"nanmean": {f16, f32, f64},
"nanmedian": {f16, f32, f64},
"nanquantile": {f32, f64},
"nansum": {f16, f32, f64},
"native_batch_norm": {f16, f32, f64},
"native_layer_norm": {f16, f32, f64},
"nn.functional._scaled_dot_product_attention": {f16},
@ -446,6 +438,7 @@ inductor_override_kwargs = {
"new_empty_strided": {"assert_equal": False},
"randn": {"assert_equal": False},
("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001},
("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002},
"gradient": {"check_gradient": False}, # segfault on check_gradient
# Following tests failed, and causing subsequent tests failing with unrecoverable CUDA error
"linalg.solve_triangular": {"check_gradient": False},
@ -461,6 +454,8 @@ inductor_all_samples = {
"index_copy",
"scatter_reduce.sum",
"select_scatter",
"squeeze",
"unsqueeze",
}

View File

@ -54,8 +54,6 @@ constant_functions = {
torch.onnx.is_in_onnx_export: False,
}
# root folder of the project
base_dir = dirname(dirname(dirname(abspath(__file__))))
# don't specialize on shapes and strides and put shape ops in graph
dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"
@ -152,6 +150,12 @@ dynamo_import = __name__.replace(".config", "")
# How to import torchinductor, either torchinductor or torch.inductor
inductor_import = dynamo_import.replace("dynamo", "inductor")
# root folder of the project
if "torch." in dynamo_import:
base_dir = dirname(dirname(dirname(abspath(__file__))))
else:
base_dir = dirname(dirname(abspath(__file__)))
class _AccessLimitingConfig(ModuleType):
def __setattr__(self, name, value):

View File

@ -228,7 +228,7 @@ def save_graph_repro(fd, gm, args, compiler_name):
textwrap.dedent(
f"""
compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args)
compiled(*args)
compiled(args)
"""
)
)
@ -293,7 +293,7 @@ def inductor_fails(fx_g, args, check_str=None):
try:
compile_mod = compile_fx_inner(fx_g, args)
compile_mod(*args)
compile_mod(args)
except Exception as e:
if check_str is not None and check_str not in repr(e):
return False
@ -385,7 +385,7 @@ def wrap_compiler_debug(compiler_fn, compiler_name: str):
orig_graph = copy.deepcopy(gm.graph)
assert config.repro_after in ("dynamo", "aot", None)
def deferred_for_real_inputs(*real_inputs):
def deferred_for_real_inputs(real_inputs):
"""
Aot Autograd fw_compiler and bw_compiler can have fake tensors. So,
example_inputs can be fake tensors. We can call compiler_fn (which is
@ -420,14 +420,14 @@ def wrap_compiler_debug(compiler_fn, compiler_name: str):
raise ValueError("Bad accuracy detected")
else:
# Call the compiled function with real inputs
return compiled_fn(*real_inputs)
return compiled_fn(real_inputs)
else:
try:
# Call the compiler_fn - which is either aot_autograd or inductor
# with fake inputs
compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
# Call the compiled function with real inputs
return compiled_fn(*real_inputs)
return compiled_fn(real_inputs)
except Exception as e:
if config.repro_level == 1:
dump_compiler_graph_state(
@ -441,6 +441,7 @@ def wrap_compiler_debug(compiler_fn, compiler_name: str):
if config.repro_after == "aot":
compiled_fn = deferred_for_real_inputs
compiled_fn._boxed_call = True
else:
compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
@ -453,6 +454,8 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False):
"""
Runs a forward and possibly backward iteration for a given mod and args.
"""
from functorch._src.aot_autograd import make_boxed_func
from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass
gm = copy.deepcopy(gm)
@ -465,7 +468,14 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False):
if hasattr(gm, "zero_grad"):
gm.zero_grad(True)
out = gm(*args)
# The callable returned by TorchInductor expects a list of inputs, so box the call.
if not hasattr(gm, "_boxed_call"):
orig_named_parameters = gm.named_parameters
gm = make_boxed_func(gm)
gm.named_parameters = orig_named_parameters
out = gm(args)
if only_fwd:
return out
if requires_bwd_pass(out):
@ -775,7 +785,7 @@ def wrap_backend_debug(compiler_fn, compiler_name: str):
else:
try:
compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs))
run_fwd_maybe_bwd(compiled_gm, example_inputs)
except Exception as exc:
log.warning(
"Compiled Fx GraphModule failed with following error. Setting up minifier."
@ -815,7 +825,7 @@ def dynamo_minifier_backend(gm, example_inputs, compiler_name):
try:
compiled_gm = compiler_fn(gm, example_inputs)
run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs))
run_fwd_maybe_bwd(compiled_gm, example_inputs)
raise ValueError("No issue was detected")
except Exception as exc:
orig_failure = str(exc)

View File

@ -10,7 +10,7 @@ from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._pytree import tree_map
from .. import config
from ..utils import fake_tensors_available
from ..utils import clone_inputs, fake_tensors_available
if fake_tensors_available:
from torch._subclasses import FakeTensorMode # noqa: F401
@ -52,6 +52,14 @@ class ShapeAliasingAndMutationProp(ShapeProp):
n.meta["alias_groups"] = {
self.tensor_alias_group(obj) for obj in self.extract_tensors(result)
}
if (
not n.meta["alias_groups"]
and n.op == "call_function"
and n.target == operator.setitem
):
n.meta["alias_groups"] = {self.tensor_alias_group(tensor_args[0])}
n.meta["mutates_alias_groups"] = {
self.tensor_alias_group(tensor)
for tensor, v1, v2 in zip(tensor_args, input_versions1, input_versions2)
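The new operator.setitem case matters because setitem mutates its first argument and returns None, so there is no result tensor to derive an alias group from; a small illustration (values assumed):
import operator
import torch
buf = torch.zeros(4)
out = operator.setitem(buf, 0, 1.0)   # in-place write: buf[0] is now 1.0
assert out is None                    # nothing to alias on the result, hence the
                                      # explicit alias group taken from the mutated arg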
@ -113,6 +121,10 @@ def has_mutation(gm, example_inputs, inputs_only=False):
true, we only check for mutation of inputs"""
# TODO - moco gives bad accuracy with Aliasing. gm is getting mutated in a bad way.
# Clone the inputs so that intermediate tensors (not leaf tensors) with
# requires_grad=True are converted to requires_grad=False, avoiding a RuntimeError
# like "leaf variable that requires grad is inplace modified"
example_inputs = clone_inputs(example_inputs)
if fake_tensors_available and config.fake_tensor_propagation:
with FakeTensorMode() as fake_mode:
pass
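The cloning guards against the autograd error quoted above; a quick illustration of the failure mode (not part of the change):
import torch
x = torch.randn(3, requires_grad=True)   # leaf tensor
# x.add_(1) would raise: "a leaf Variable that requires grad is being used in an in-place operation"
x.clone().add_(1)                        # the clone is not a leaf, so in-place mutation is allowed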

View File

@ -681,9 +681,7 @@ def tvm_compile_inner(
elif tuning_option == "meta_schedule":
from os import path as osp
from tvm.meta_schedule import TuneConfig
from tvm.meta_schedule.database import JSONDatabase
from tvm.meta_schedule.tune import tune_relay
from tvm.contrib.torch import optimize_torch
with tempfile.TemporaryDirectory() as work_dir:
if log_file is not None:
@ -691,22 +689,16 @@ def tvm_compile_inner(
log_file
), "TVM's meta_schedule requires a directory for storing log files."
work_dir = log_file
lib: tvm.runtime.Module = tune_relay(
mod=mod,
params=params,
target=target,
config=TuneConfig(
strategy="evolutionary",
num_trials_per_iter=64,
max_trials_per_task=trials,
max_trials_global=trials,
),
lib = optimize_torch(
jit_mod,
example_inputs,
max_trials_global=20000,
work_dir=work_dir,
database=JSONDatabase(
osp.join(work_dir, "workload.json"),
osp.join(work_dir, "records.json"),
),
target=target,
max_trials_per_task=64,
)
elif tuning_option is None:
# no autotuning (for debugging)
with tvm.transform.PassContext(opt_level=10):
@ -716,34 +708,41 @@ def tvm_compile_inner(
"This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. "
"There are three available options including None, auto_scheduler and meta_schedule."
)
if tune_option != "meta_schedule":
m = graph_executor.GraphModule(lib["default"](dev))
m = graph_executor.GraphModule(lib["default"](dev))
def to_torch_tensor(nd_tensor):
"""A helper function to transfer a NDArray to torch.tensor."""
if nd_tensor.dtype == "bool":
# DLPack does not support boolean so it can't be handled by
# torch.utils.dlpack.from_dlpack. Workaround by going through
# numpy, although this brings additional data copy overhead.
return torch.from_numpy(nd_tensor.numpy())
return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
def to_torch_tensor(nd_tensor):
"""A helper function to transfer a NDArray to torch.tensor."""
if nd_tensor.dtype == "bool":
# DLPack does not support boolean so it can't be handled by
# torch.utils.dlpack.from_dlpack. Workaround by going through
# numpy, although this brings additional data copy overhead.
return torch.from_numpy(nd_tensor.numpy())
return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
def exec_tvm(*args):
args = [a.contiguous() for a in args]
for idx, arg in enumerate(args, 0):
if arg.dim() != 0:
if arg.requires_grad:
arg = arg.detach()
m.set_input(
f"inp_{idx}",
tvm.nd.array(arg.numpy(), dev),
)
m.run()
return [
to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())
]
def exec_tvm(*args):
args = [a.contiguous() for a in args]
for idx, arg in enumerate(args, 0):
if arg.dim() != 0:
if arg.requires_grad:
arg = arg.detach()
m.set_input(
f"inp_{idx}",
tvm.nd.array(arg.numpy(), dev),
)
m.run()
return [
to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())
]
else:
def exec_tvm(*args):
args = [a.contiguous() for a in args]
return lib(*args)
return exec_tvm
except Exception:
log.exception("tvm error")
return jit_mod # explicit fall back to eager

View File

@ -43,14 +43,14 @@ class DDPOptimizer:
bucket_actual_sizes = []
node_splits = [[]]
for node in reversed(gm.graph.nodes):
if node.op == "output" or node.op == "placeholder":
continue
if bucket_bytes >= self.bucket_bytes_cap:
bucket_actual_sizes.insert(0, bucket_bytes)
bucket_bytes = 0
node_splits.insert(0, [])
if node.op == "output" or node.op == "placeholder":
continue
elif node.op == "call_module":
target = gm.get_submodule(node.target)
params_size_b = sum(
@ -62,6 +62,10 @@ class DDPOptimizer:
)
bucket_bytes += params_size_b
# print(f"accumulated {params_size_b} b from {node}")
elif node.op == "get_attr":
maybe_param = getattr(gm, node.target)
if maybe_param.requires_grad:
bucket_bytes += maybe_param.storage().nbytes()
else:
# TODO(whc) confirm this:
# (e.g. call_method, call_function aren't expected to 'have' parameters)

View File

@ -229,7 +229,12 @@ class OutputGraph(fx.Tracer):
return wrap_name(k)
# create a new unique name
name = re.sub(r"[^a-zA-Z0-9]", "_", "_".join(map(str, names)))
name = "_".join(map(str, names))
# e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
# e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
name = re.sub(r"[^a-zA-Z0-9]", "_", name)
if not name or not name[0].isalpha():
name = "sub" + name
base = name
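A worked example of the two substitutions (input string assumed):
import re
name = "abc.xyz[123].qkv"
name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)  # "abc.xyz_123.qkv"
name = re.sub(r"[^a-zA-Z0-9]", "_", name)     # "abc_xyz_123_qkv"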

View File

@ -0,0 +1,71 @@
import contextlib
import importlib
import sys
from unittest.mock import patch
import torch
import torch.testing
from torch.testing._internal.common_utils import (
IS_WINDOWS,
TEST_WITH_CROSSREF,
TEST_WITH_ROCM,
TEST_WITH_TORCHDYNAMO,
TestCase as TorchTestCase,
)
from . import config, reset, utils
def run_tests(needs=()):
from torch.testing._internal.common_utils import run_tests
if (
TEST_WITH_TORCHDYNAMO
or IS_WINDOWS
or TEST_WITH_CROSSREF
or TEST_WITH_ROCM
or sys.version_info >= (3, 11)
):
return # skip testing
if isinstance(needs, str):
needs = (needs,)
for need in needs:
if need == "cuda" and not torch.cuda.is_available():
return
else:
try:
importlib.import_module(need)
except ImportError:
return
run_tests()
class TestCase(TorchTestCase):
@classmethod
def tearDownClass(cls):
cls._exit_stack.close()
super().tearDownClass()
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._exit_stack = contextlib.ExitStack()
cls._exit_stack.enter_context(
patch.object(config, "raise_on_backend_error", True)
)
cls._exit_stack.enter_context(
patch.object(config, "raise_on_ctx_manager_usage", True)
)
def setUp(self):
super().setUp()
reset()
utils.counters.clear()
def tearDown(self):
for k, v in utils.counters.items():
print(k, v.most_common())
reset()
utils.counters.clear()
super().tearDown()

View File

@ -1,19 +1,16 @@
import contextlib
import dis
import functools
import importlib
import logging
import os.path
import sys
import types
import unittest
from unittest.mock import patch
import torch
import torch.testing._internal.common_utils
from torch import fx
from . import config, eval_frame, optimize_assert, reset, utils
from . import config, eval_frame, optimize_assert, reset
from .bytecode_transformation import (
create_instruction,
debug_checks,
@ -29,37 +26,6 @@ three = 3
log = logging.getLogger(__name__)
def run_tests(needs=()):
return # TEMPORARY: disable all tests
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
TEST_WITH_CROSSREF,
TEST_WITH_TORCHDYNAMO,
)
if (
TEST_WITH_TORCHDYNAMO
or IS_WINDOWS
or TEST_WITH_CROSSREF
or sys.version_info >= (3, 11)
):
return # skip testing
if isinstance(needs, str):
needs = (needs,)
for need in needs:
if need == "cuda" and not torch.cuda.is_available():
return
else:
try:
importlib.import_module(need)
except ImportError:
return
run_tests()
def clone_me(x):
if x is None:
return None
@ -229,36 +195,6 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None)
self.assertEqual(actual.op_count, expected_ops)
class TestCase(torch.testing._internal.common_utils.TestCase):
@classmethod
def tearDownClass(cls):
cls._exit_stack.close()
super().tearDownClass()
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._exit_stack = contextlib.ExitStack()
cls._exit_stack.enter_context(
patch.object(config, "raise_on_backend_error", True)
)
cls._exit_stack.enter_context(
patch.object(config, "raise_on_ctx_manager_usage", True)
)
def setUp(self):
super().setUp()
reset()
utils.counters.clear()
def tearDown(self):
for k, v in utils.counters.items():
print(k, v.most_common())
reset()
utils.counters.clear()
super().tearDown()
def dummy_fx_compile(gm: fx.GraphModule, example_inputs):
return gm.forward

View File

@ -42,6 +42,10 @@ from .constant import ConstantVariable
from .lists import ShapeVariable, SizeVariable
class _missing:
pass
class TensorVariable(VariableTracker):
"""A torch.Tensor input or an intermediate value in the FX graph"""
@ -189,8 +193,9 @@ class TensorVariable(VariableTracker):
elif istype(example_value, int) and proxy.node.target in (
torch.seed,
operator.mod,
torch.distributed.get_rank,
torch.distributed.get_world_size,
# some mac builds are missing torch.distributed.get_rank()
getattr(torch.distributed, "get_rank", _missing),
getattr(torch.distributed, "get_world_size", _missing),
):
proxy.node.meta["example_value"] = example_value
return DynamicShapeVariable(proxy, type(example_value), **options)
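The _missing sentinel keeps the membership test well defined on builds where torch.distributed lacks these functions; roughly:
import torch
get_rank = getattr(torch.distributed, "get_rank", _missing)  # _missing defined above
# On builds without the attribute, get_rank is the _missing class itself, which never
# compares equal to proxy.node.target, so the tuple membership check simply fails.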

View File

@ -654,6 +654,7 @@ class TritonKernel(Kernel):
def indexing(
self,
index: sympy.Expr,
*,
copy_shape=None,
dense_indexing=False,
):
@ -686,9 +687,11 @@ class TritonKernel(Kernel):
mask.append(f"{tree.prefix}mask")
dense_mask.append(f"{tree.prefix}mask")
if (need_dense and not have_dense) or index == 0:
if (need_dense and not have_dense) or isinstance(
index, sympy.core.numbers.Integer
):
index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
if index == 0:
if isinstance(index, sympy.core.numbers.Integer):
return index_str, "None"
else:
mask = dense_mask
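The switch from index == 0 to an isinstance check broadens the dense-indexing special case from exactly zero to any constant integer index; for example (illustrative):
import sympy
sympy.Integer(5) == 0                                     # False: missed by the old check
isinstance(sympy.Integer(5), sympy.core.numbers.Integer)  # True: caught by the new check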
@ -779,7 +782,7 @@ class TritonKernel(Kernel):
def store(self, name, index, value, mode=None):
var = self.args.output(name)
index, mask = self.indexing(index, value, dense_indexing=True)
index, mask = self.indexing(index, dense_indexing=True)
if mode is None:
line = f"tl.store({var} + ({index}), {value}, {mask})"
elif mode == "atomic_add":
@ -861,7 +864,7 @@ class TritonKernel(Kernel):
var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)]
self.suffix.writeline(f"{result_var} = {var_name}")
self.inside_reduction = False
index, mask = self.indexing(index, result_var)
index, mask = self.indexing(index)
assert "rmask" not in index
self.inside_reduction = True
self.outside_loop_vars.add(result_var)

View File

@ -77,7 +77,9 @@ class TritonTemplateKernel(TritonKernel):
def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=True):
# use dense_indexing for TritonTemplateKernel to avoid map::at error
return super().indexing(index, copy_shape, dense_indexing)
return super().indexing(
index, copy_shape=copy_shape, dense_indexing=dense_indexing
)
def codegen_body(
self, name, fuse, could_remove_kernel_buf, kernel_buf_replace_name

View File

@ -217,15 +217,20 @@ class WrapperCodeGen(CodeGen):
)
self.prefix.splice(
f"""
"""
async_compile.wait(globals())
del async_compile
def call({', '.join(V.graph.graph_inputs.keys())}):
def call(args):
"""
)
with self.prefix.indent():
inp_len = len(V.graph.graph_inputs.keys())
if inp_len != 0:
lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
self.prefix.writeline(f"{lhs} = args")
self.prefix.writeline("args.clear()")
for name in V.graph.randomness_seeds:
self.prefix.writeline(
f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})"
@ -275,6 +280,12 @@ class WrapperCodeGen(CodeGen):
def codegen_free(self, buffer):
name = buffer.get_name()
# can be freed but not reused
if isinstance(buffer, ir.InputBuffer):
self.writeline(f"del {name}")
return
if not self.can_reuse(buffer):
return
self.freed.add(name)
@ -380,7 +391,7 @@ class WrapperCodeGen(CodeGen):
)
output.writeline(
f"print_performance(lambda: call({', '.join(V.graph.graph_inputs.keys())}))"
f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
)
def define_kernel(self, name: str, kernel: str):

View File

@ -5,7 +5,8 @@ import logging
from typing import List
import functorch
from functorch.compile import make_boxed_compiler, min_cut_rematerialization_partition
from functorch._src.aot_autograd import make_boxed_func
from functorch.compile import min_cut_rematerialization_partition
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
@ -86,7 +87,7 @@ def compile_fx_inner(
graph_id=None,
):
if dynamo_utils.count_calls(gm.graph) == 0:
return gm
return make_boxed_func(gm.forward)
_step_logger()(
logging.INFO,
@ -137,6 +138,9 @@ def compile_fx_inner(
f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
f"graph {graph_id}",
)
# aot autograd needs to know to pass in inputs as a list
result._boxed_call = True
return result
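A sketch of the boxed-call convention being adopted here (the function is hypothetical): the compiled callable takes one list argument, may clear it to release input storage early, and is tagged with _boxed_call so callers pass inputs as fn(args) rather than fn(*args).
import torch
def boxed_fn(args):
    x, y = args
    args.clear()          # allow the caller's input tensors to be freed
    return (x + y,)
boxed_fn._boxed_call = True
out = boxed_fn([torch.randn(2), torch.randn(2)])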
@ -159,14 +163,15 @@ def align_inputs(model, inputs, static_input_idxs=()):
if len(check_inputs) == 0:
return model
def run(*new_inputs):
def run(new_inputs):
for i in check_inputs:
if new_inputs[i].data_ptr() % ALIGNMENT:
if isinstance(new_inputs, tuple):
new_inputs = list(new_inputs)
new_inputs[i] = clone_preserve_strides(new_inputs[i])
new_inputs = [x.to("cuda") if is_unspec_input(x) else x for x in new_inputs]
return model(*new_inputs)
new_inputs_to_cuda = [
x.to("cuda") if is_unspec_input(x) else x for x in new_inputs
]
new_inputs.clear()
return model(new_inputs_to_cuda)
return run
@ -179,13 +184,13 @@ def cudagraphify(model, inputs, static_input_idxs=()):
compiled_fn = None
def run(*new_inputs):
def run(new_inputs):
nonlocal compiled_fn
if compiled_fn is None:
with dynamo_utils.preserve_rng_state():
compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)
return compiled_fn(*new_inputs)
return compiled_fn(new_inputs)
return run
@ -239,8 +244,9 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()):
torch.cuda.synchronize()
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
# copy static_inputs because the list will be cleared inside model
with torch.cuda.stream(stream):
model(*static_inputs)
model(list(static_inputs))
stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()
@ -248,13 +254,13 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()):
# record
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
static_outputs = model(*static_inputs)
static_outputs = model(list(static_inputs))
if not isinstance(static_outputs, (list, tuple)):
static_outputs = (static_outputs,)
if config.size_asserts:
def run(*new_inputs):
def run(new_inputs):
assert len(static_inputs) == len(new_inputs)
for idx, (dst, src, expanded_dims) in enumerate(
zip(static_inputs, new_inputs, inps_expanded_dims)
@ -268,6 +274,7 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()):
dst = index_expanded_dims(dst, expanded_dims)
src = index_expanded_dims(src, expanded_dims)
dst.copy_(src)
new_inputs.clear()
graph.replay()
return static_outputs
@ -276,11 +283,12 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()):
idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
]
def run(*new_inputs):
def run(new_inputs):
for idx in copy_indices:
src = index_expanded_dims(static_inputs[idx], inps_expanded_dims[idx])
dst = index_expanded_dims(new_inputs[idx], inps_expanded_dims[idx])
dst.copy_(src)
new_inputs.clear()
graph.replay()
return static_outputs
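The code above follows the standard CUDA graph capture/replay pattern: warm up on a side stream, record once against static buffers, then on every call copy fresh data into those buffers and replay. A minimal sketch (assumes a CUDA device; tensor names are not from the diff):
import torch
static_inp = torch.randn(8, device="cuda")
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = static_inp * 2           # warmup outside the capture
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_inp * 2           # recorded, not run for real results
def run(new_inp):
    static_inp.copy_(new_inp)             # refresh the captured input buffer
    g.replay()                            # re-execute the recorded kernels
    return static_out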
@ -359,8 +367,8 @@ def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]
return aot_autograd(
model_,
example_inputs_,
fw_compiler=make_boxed_compiler(fw_compiler),
bw_compiler=make_boxed_compiler(bw_compiler),
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
decompositions=select_decomp_table(),
partition_fn=functools.partial(
min_cut_rematerialization_partition, compiler="inductor"

View File

@ -3,8 +3,6 @@ import logging
import math
import numbers
from functorch._src.aot_autograd import aot_autograd_decompositions
import torch
import torch._decomp as decomp
from torch import Tensor
@ -98,9 +96,10 @@ decompositions = get_decompositions(
aten.tril.default,
aten.upsample_bilinear2d.vec,
aten.upsample_nearest2d_backward,
aten.softplus,
aten.softplus_backward,
]
)
decompositions.update(aot_autograd_decompositions)
def register_decomposition(ops):

View File

@ -154,6 +154,11 @@ def _register_lowering(
@functools.wraps(decomp_fn)
def wrapped(*args, **kwargs):
args = list(args)
unpacked = False
# TODO maybe we need to use pytrees here
if len(args) == 1 and isinstance(args[0], (list, tuple)):
unpacked = True
args = args[0]
# Only look at args that are Tensors
indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
# kwargs tensors not supported yet
@ -170,14 +175,20 @@ def _register_lowering(
dtype = get_promoted_dtype(
*promoting_args, type_promotion_kind=type_promotion_kind
)
for i in indices:
args[i] = to_dtype(args[i], dtype)
# sometimes args are an immutable list so we can't mutate them
new_args = []
for i in range(len(args)):
if isinstance(args[i], ir.Constant):
args[i] = ir.Constant(
args[i].value, dtype, args[indices[0]].get_device()
if i in indices:
new_args.append(to_dtype(args[i], dtype))
elif isinstance(args[i], ir.Constant):
new_args.append(
ir.Constant(args[i].value, dtype, args[indices[0]].get_device())
)
else:
new_args.append(args[i])
args = new_args
if unpacked:
args = [args]
if broadcast and indices:
for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
args[i] = x
@ -475,12 +486,13 @@ def squeeze(x, dim=None):
assert isinstance(x, TensorBox)
if dim is None:
return TensorBox(SqueezeView.create(x.data))
dim = _validate_dim(x, dim, 0)
offset = len(x.get_size()) == 0
dim = _validate_dim(x, dim, offset)
new_shape = list(x.get_size())
removed = new_shape.pop(dim)
if V.graph.sizevars.maybe_guard_equals(removed, 1):
return view(x, new_shape)
if len(new_shape) > 0:
removed = new_shape.pop(dim)
if V.graph.sizevars.maybe_guard_equals(removed, 1):
return view(x, new_shape)
# squeeze does nothing if the size isn't 1
return x

View File

@ -988,6 +988,11 @@ class Scheduler:
node = self.name_to_node[name]
if node.can_free():
V.graph.wrapper_code.codegen_free(node.node)
elif name in V.graph.graph_inputs:
storage = V.graph.graph_inputs[name].data
assert storage.is_input_buffer()
V.graph.wrapper_code.codegen_free(storage.data)
self.buffer_names_to_free.clear()
def remove_kernel_local_buffers(self):

View File

@ -5,6 +5,7 @@ import json
import logging
import multiprocessing
import os.path
import re
import threading
from typing import List
@ -110,11 +111,12 @@ class CachingAutotuner(KernelInterface):
# set_device(current_device()) # TODO(jansel): is this needed?
grid_0, grid_1, grid_2 = grid(grid_meta)
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
stream, bin.cu_function, None, None, None,
{', '.join(call_args)})
stream, bin.cu_function, None, None, None,
{', '.join(call_args)})
""".lstrip(),
scope,
)
launcher = scope["launcher"]
launcher.config = cfg
return launcher
@ -160,11 +162,22 @@ class CachingAutotuner(KernelInterface):
launcher.config.pre_hook(
{**zip(self.arg_names, args), **launcher.config.kwargs}
)
return launcher(
*args,
grid=grid,
stream=stream,
)
try:
result = launcher(
*args,
grid=grid,
stream=stream,
)
except TypeError as e:
if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)):
raise RuntimeError(
"""Consider updating Triton with
`pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`"""
)
else:
raise e
return result
def hash_configs(configs: List[Config]):