Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)

Sync changes from pytorch/torchdynamo, enable tests (#86950)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86950
Approved by: https://github.com/Chillee
Committed by: PyTorch MergeBot
Parent: 78ef40973c
Commit: 8f71e8de7e
@@ -198,12 +198,22 @@ class OperatorInputsMode(TorchDispatchMode):

def map_to_device(e, device):
    return e.to(device) if isinstance(e, torch.Tensor) else e
    if isinstance(e, torch.Tensor):
        return e.to(device)
    elif isinstance(e, torch.device):
        return device
    elif isinstance(e, str):
        if e == "cuda" or e == "cpu":
            return device.type
    else:
        return e


def map_to_dtype(e, dtype):
    if isinstance(e, torch.Tensor) and e.is_floating_point():
        return e.to(dtype)
    elif isinstance(e, torch.dtype):
        return dtype
    else:
        return e
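A minimal sketch of how a remapping helper like the new map_to_device is typically applied across nested sample inputs (not part of this commit; the tree_map usage and the sample args are illustrative assumptions):

import torch
from torch.utils._pytree import tree_map

def map_to_device(e, device):
    # mirrors the helper above: tensors move, device objects and "cuda"/"cpu"
    # strings are rewritten, everything else passes through unchanged
    if isinstance(e, torch.Tensor):
        return e.to(device)
    if isinstance(e, torch.device):
        return device
    if isinstance(e, str) and e in ("cuda", "cpu"):
        return device.type
    return e

args = {"x": torch.randn(4), "device_str": "cuda", "dev": torch.device("cuda:0")}
cpu_args = tree_map(lambda e: map_to_device(e, torch.device("cpu")), args)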
@@ -2,7 +2,6 @@
import click
import numpy as np
import torch
import triton
from operator_inp_utils import OperatorInputsLoader

from torch._dynamo.optimizations.backends import cudagraphs_inner

@@ -17,7 +16,7 @@ aten = torch.ops.aten


def compute_speedups(
    operator, models, example_inputs, repeats, accuracy_checking=False
    operator, models, example_inputs, repeats, accuracy_checking=False, device="cuda"
):
    expected = models[0](*example_inputs)
    if accuracy_checking:

@@ -35,10 +34,19 @@ def compute_speedups(
    for rep in range(repeats):
        # interleave the runs to handle frequency scaling and load changes
        for m, model in enumerate(models):
            # do_bench() clears L2 cache to hide the latency of CPU launch time
            # along with cuda synchronization
            median_ms, _, _ = triton.testing.do_bench(lambda: model(*example_inputs))
            timings[rep, m] = median_ms
            if device == "cuda":
                import triton

                # do_bench() clears L2 cache to hide the latency of CPU launch time
                # along with cuda synchronization
                median_ms, _, _ = triton.testing.do_bench(
                    lambda: model(*example_inputs)
                )
                timings[rep, m] = median_ms
            else:
                from torch._inductor.utils import timed

                timings[rep, m] = timed(model, example_inputs)
    return np.median(timings, axis=0)

@@ -64,15 +72,19 @@ def convert_to_jit(gm, gm_args):


def microbenchmark(
    operator, args, kwargs, dtype, accuracy_checking, repeats, measure_nvfuser
    operator, args, kwargs, dtype, accuracy_checking, repeats, measure_nvfuser, device
):
    gm, gm_args = gen_gm_and_inputs(operator, args, kwargs)
    torch.jit._builtins._register_builtin(
        torch.ops.aten.convolution_backward.default, "aten::convolution_backward"
    )
    cudagraphs_eager = cudagraphs_inner(gm, gm_args, copy_outputs=False)
    compiled_fn = compile_fx(gm, gm_args)
    compiled = [cudagraphs_eager, compiled_fn]
    if device == "cuda":
        cudagraphs_eager = cudagraphs_inner(gm, gm_args, copy_outputs=False)
        compiled_fn = compile_fx(gm, gm_args)
        compiled = [cudagraphs_eager, compiled_fn]
    else:
        compiled_fn = compile_fx(gm, gm_args)
        compiled = [gm, compiled_fn]
    if measure_nvfuser:
        g = convert_to_jit(gm, gm_args)
        cudagraphs_jit = cudagraphs_inner(g, gm_args, copy_outputs=False)

@@ -80,7 +92,9 @@ def microbenchmark(
    if accuracy_checking:
        repeats = 1

    medians = compute_speedups(operator, compiled, gm_args, repeats, accuracy_checking)
    medians = compute_speedups(
        operator, compiled, gm_args, repeats, accuracy_checking, device
    )
    return medians

@@ -147,8 +161,9 @@ def skip_operator(operator):
@click.option(
    "--measure-nvfuser", help="default we only measure inductor", default=False
)
@click.option("--device", help="cpu or cuda", default="cuda")
def benchmark(
    suite, op, dtype, max_samples, accuracy_checking, repeats, measure_nvfuser
    suite, op, dtype, max_samples, accuracy_checking, repeats, measure_nvfuser, device
):
    assert suite in ("timm", "huggingface", "torchbench"), f"got {suite}"
    if suite == "timm":

@@ -176,7 +191,7 @@ def benchmark(
            continue

        print(f"Running {operator}")
        inp_gen = loader.get_inputs_for_operator(operator, dtype=dtype)
        inp_gen = loader.get_inputs_for_operator(operator, dtype=dtype, device=device)
        timings = []

        for i in range(min(max_samples, 1000000)):

@@ -199,6 +214,7 @@ def benchmark(
                        accuracy_checking,
                        repeats,
                        measure_nvfuser,
                        device,
                    )
                )
            except Exception as e:
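The new CPU branch swaps triton.testing.do_bench for torch._inductor.utils.timed. A rough sketch of what such a wall-clock timer does (an assumption for illustration, not the library's exact implementation; note do_bench reports a median in milliseconds while a plain timer like this returns seconds):

import time
import torch

def timed_sketch(model, example_inputs, times=1):
    torch.manual_seed(1337)
    t0 = time.perf_counter()
    for _ in range(times):
        result = model(*example_inputs)  # eager invocation being measured
    t1 = time.perf_counter()
    return t1 - t0  # wall-clock seconds for `times` invocations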
@@ -237,6 +237,8 @@ class TorchBenchmarkRunner(BenchmarkRunner):
        if batch_size is None and is_training and model_name in USE_SMALL_BATCH_SIZE:
            batch_size = USE_SMALL_BATCH_SIZE[model_name]

        # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
        torch.backends.__allow_nonbracketed_mutation_flag = True
        if is_training:
            benchmark = benchmark_cls(
                test="train", device=device, jit=False, batch_size=batch_size
@@ -4,6 +4,7 @@ import functools
import torch

import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.optimizations.training import is_aot_autograd_safe_to_run
from torch._dynamo.testing import rand_strided

@@ -13,7 +14,7 @@ def compiler_safe_fn(gm, example_inputs, is_safe):
    return gm.forward


class AotAutogradFallbackTests(torch._dynamo.testing.TestCase):
class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
    def test_LSTM(self):
        # https://github.com/pytorch/torchdynamo/issues/1147
        class Repro(torch.nn.Module):

@@ -60,6 +61,65 @@ class AotAutogradFallbackTests(torch._dynamo.testing.TestCase):
        aot_fn(x, y)
        self.assertTrue(not is_safe[0])

    def test_mutation1(self):
        def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
            getitem = diagonal_chunked_attention_scores[
                (
                    slice(None, None, None),
                    slice(None, None, None),
                    slice(None, 256, None),
                    slice(None, 257, None),
                )
            ]
            _stack0[
                (
                    slice(None, None, None),
                    slice(None, -1, None),
                    slice(None, None, None),
                    slice(256, None, None),
                )
            ] = getitem
            view = _stack0.view(1, 12, 1024, 513)
            return (view,)

        x = torch.randn(torch.Size([12, 4, 256, 513]))
        y = torch.randn(torch.Size([12, 3, 512, 513]))
        is_safe = [True]
        compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
        aot_fn = torch._dynamo.optimize(compiler_fn)(fn)
        aot_fn(x, y)
        self.assertTrue(not is_safe[0])

    def test_negative_testing_mutation(self):
        def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor):
            getitem = diagonal_chunked_attention_scores[
                (
                    slice(None, None, None),
                    slice(None, None, None),
                    slice(None, 256, None),
                    slice(None, 257, None),
                )
            ]
            _stack0 = torch.sin(_stack0)
            _stack0[
                (
                    slice(None, None, None),
                    slice(None, -1, None),
                    slice(None, None, None),
                    slice(256, None, None),
                )
            ] = getitem
            view = _stack0.view(1, 12, 1024, 513)
            return (view,)

        x = torch.randn(torch.Size([12, 4, 256, 513]))
        y = torch.randn(torch.Size([12, 3, 512, 513]))
        is_safe = [True]
        compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
        aot_fn = torch._dynamo.optimize(compiler_fn)(fn)
        aot_fn(x, y)
        self.assertTrue(is_safe[0])

    def test_negative_testing(self):
        def fn(x, y):
            return torch.sin(x).add_(y)

@@ -74,6 +134,6 @@ class AotAutogradFallbackTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -7,8 +7,10 @@ from unittest.mock import patch
import torch

import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
from torch.testing._internal.common_utils import TEST_WITH_ROCM


def composed(*decs):

@@ -43,6 +45,7 @@ def assert_aot_autograd_counter(ok=True):

def patch_all(ok=True):
    return composed(
        unittest.skipIf(TEST_WITH_ROCM, "ROCm not supported"),
        patch("torch._dynamo.config.verify_correctness", True),
        assert_aot_autograd_counter(ok),
    )

@@ -52,7 +55,7 @@ N_ITERS = 5


@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda")
class TestAotCudagraphs(torch._dynamo.testing.TestCase):
class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
    @patch_all()
    def test_basic(self):
        def model(x, y):

@@ -201,6 +204,6 @@ class TestAotCudagraphs(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -7,6 +7,7 @@ import pytest
import torch

import torch._dynamo
import torch._dynamo.test_case
import torch.distributed as dist
from torch import nn
from torch._dynamo import config

@@ -43,7 +44,7 @@ def skip_if_no_active_ddp():


@pytest.mark.skip("Module hangs in PyTorch CI")
class TestDistributed(torch._dynamo.testing.TestCase):
class TestDistributed(torch._dynamo.test_case.TestCase):
    """
    Test harness initializes dist process group
    """

@@ -209,6 +210,63 @@ class TestDistributed(torch._dynamo.testing.TestCase):
        opt_outputs.sum().backward()
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_custom_layer(self):
        """
        Just ensures that the appropriate number of splits happen (based on
        bucket size and model parameters) - verifies the number of times
        the user-provided compiler is called by the DDPOptimizer which is
        doing the graph splitting
        """
        from torch.nn.parallel import DistributedDataParallel as DDP

        skip_if_no_active_ddp()

        class MyCustomLinear(torch.nn.Module):
            def __init__(self):
                super(MyCustomLinear, self).__init__()
                self.weight = nn.Parameter(torch.randn(512, 512))

            def forward(self, x):
                return torch.mm(x, self.weight.t())

        class MyLinear(torch.nn.Module):
            def __init__(self):
                super(MyLinear, self).__init__()
                self.linear = torch.nn.Linear(512, 512)

            def forward(self, x):
                return self.linear(x)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                mods = [
                    (MyLinear(), torch.nn.ReLU()),
                    # sandwitch the custom in the middle so it comes before and after
                    (MyCustomLinear(), torch.nn.ReLU()),
                    (MyLinear(), torch.nn.ReLU()),
                ]
                self.seq = torch.nn.Sequential(*[x for items in mods for x in items])

            def forward(self, x):
                return self.seq(x)

        m = MyModule().to(self.device)
        inputs = torch.randn((512, 512)).to(self.device)
        correct_outputs = m(inputs)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

    def test_empty_graph(self):
        def fn():
            get_world_size = torch.distributed.distributed_c10d.get_world_size()
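The assertion on compiler_called relies on a small counting backend. A hedged sketch of what such a compiler looks like (CheckSplitsCompiler itself is defined elsewhere in the test file; this stand-in only illustrates the counting idea):

import torch

class CountingCompiler:
    def __init__(self):
        self.compiler_called = 0

    def compile_fn(self, gm: torch.fx.GraphModule, example_inputs):
        # one call per subgraph produced by the DDPOptimizer's bucket splits
        self.compiler_called += 1
        return gm.forward  # run the split eagerly; only the count matters here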
@@ -25,6 +25,6 @@ DynamicShapesNNModuleTests = make_dynamic_cls(test_modules.NNModuleTests)
DynamicShapesUnspecTests = make_dynamic_cls(test_unspec.UnspecTests)

if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -3,12 +3,13 @@ from unittest.mock import patch

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import make_fx


class ExportTests(torch._dynamo.testing.TestCase):
class ExportTests(torch._dynamo.test_case.TestCase):
    # TODO(voz): Refactor to a shared test function.
    # The tests in this file are a little redundant,
    # They all take a func, run it with eager, then export it, then compare

@@ -1423,6 +1424,6 @@ class ExportTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -9,6 +9,7 @@ from typing import Any

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch import sub
from torch._dynamo.testing import requires_static_shapes

@@ -53,7 +54,7 @@ def inline_unused(x):
    return x + 5.6


class FunctionTests(torch._dynamo.testing.TestCase):
class FunctionTests(torch._dynamo.test_case.TestCase):
    @make_test
    def test_inline_jit_annotations(x):
        x = inline_script_if_tracing(x)

@@ -670,6 +671,6 @@ class FunctionTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -1,6 +1,7 @@
# Owner(s): ["module: dynamo"]
import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same

@@ -43,7 +44,7 @@ def reset_name():
    _name = 0


class TestGlobals(torch._dynamo.testing.TestCase):
class TestGlobals(torch._dynamo.test_case.TestCase):
    def test_store_global_1(self):
        def fn(x):
            global g_counter

@@ -227,6 +228,6 @@ class TestGlobals(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -6,6 +6,7 @@ from unittest.mock import patch
import torch

import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.optimizations.backends import create_backend

@@ -23,7 +24,7 @@ class MockModule(torch.nn.Module):
        return x


class MinfierTests(torch._dynamo.testing.TestCase):
class MinfierTests(torch._dynamo.test_case.TestCase):
    def test_after_dynamo(self):
        @create_backend
        def bad_dynamo_backend(subgraph):

@@ -92,6 +93,6 @@ class MinfierTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -16,6 +16,7 @@ from unittest.mock import patch
import numpy as np
import torch

import torch._dynamo.test_case
import torch._dynamo.testing
import torch.onnx.operators
from torch._dynamo import bytecode_transformation

@@ -34,7 +35,7 @@ def my_custom_function(x):
    return x + 1


class MiscTests(torch._dynamo.testing.TestCase):
class MiscTests(torch._dynamo.test_case.TestCase):
    def test_boolarg(self):
        def boolarg(aa, bb, flag):
            if flag:

@@ -2719,6 +2720,6 @@ class TestTracer(JitTestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -4,6 +4,7 @@ import unittest.mock

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same

@@ -22,7 +23,7 @@ def maybe_skip(fn):
    return fn


class TestHFPretrained(torch._dynamo.testing.TestCase):
class TestHFPretrained(torch._dynamo.test_case.TestCase):
    @maybe_skip
    def test_pretrained(self):
        def fn(a, tmp):

@@ -38,7 +39,7 @@ class TestHFPretrained(torch._dynamo.testing.TestCase):
        self.assertTrue(same(ref, res))


class TestModelOutput(torch._dynamo.testing.TestCase):
class TestModelOutput(torch._dynamo.test_case.TestCase):
    @maybe_skip
    def test_mo_create(self):
        def fn(a, b):

@@ -160,6 +161,6 @@ class TestModelOutput(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -5,6 +5,7 @@ from unittest.mock import patch

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.eval_frame import unsupported
from torch._dynamo.mutation_guard import GenerationTracker

@@ -604,7 +605,7 @@ def make_test(fn, expected_ops=None):
    return test_fn


class NNModuleTests(torch._dynamo.testing.TestCase):
class NNModuleTests(torch._dynamo.test_case.TestCase):
    test_seq = make_test(Seq())
    test_basicmodule1 = make_test(BasicModule())
    test_basicmodule2 = make_test(BasicModule())

@@ -884,6 +885,6 @@ class NNModuleTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()

@@ -24,6 +24,6 @@ NoFakeTensorsNNModuleTests = make_no_fake_cls(test_modules.NNModuleTests)
NoFakeTensorsUnspecTests = make_no_fake_cls(test_unspec.UnspecTests)

if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -1,6 +1,7 @@
# Owner(s): ["module: dynamo"]
import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import eval_frame

@@ -35,7 +36,7 @@ with_debug_nops = eval_frame._optimize_catch_errors(
)


class NopTests(torch._dynamo.testing.TestCase):
class NopTests(torch._dynamo.test_case.TestCase):
    @with_debug_nops
    def test1(self):
        self.assertEqual(fn1(1, 2), -7)

@@ -66,6 +67,6 @@ class NopTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -8,6 +8,7 @@ from unittest.mock import patch
import torch

import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.optimizations import backends
from torch._dynamo.optimizations.analysis import has_mutation
from torch._dynamo.optimizations.log_args import conv_args_analysis

@@ -64,7 +65,7 @@ class Conv_Bn_Relu(torch.nn.Module):
        return self.relu(self.bn(self.conv(x)))


class TestOptimizations(torch._dynamo.testing.TestCase):
class TestOptimizations(torch._dynamo.test_case.TestCase):
    def test_inplacifier(self):
        gm = torch.fx.symbolic_trace(Seq())
        normalize(gm)

@@ -183,7 +184,7 @@ class TestOptimizations(torch._dynamo.testing.TestCase):
        self.assertEqual(r2.dtype, torch.bfloat16)


class NormalizeIRTests(torch._dynamo.testing.TestCase):
class NormalizeIRTests(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(not has_functorch(), "requires functorch")
    def test_inplace_normalize(self):
        def fn(a, b):

@@ -202,6 +203,6 @@ class NormalizeIRTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -6,6 +6,7 @@ import unittest
import torch

import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing

input = torch.ones([10, 10])

@@ -37,7 +38,7 @@ def make_test(optim_cls, exp_frame_cnt=1, closure=None, **kwargs):
    return test_fn


class OptimizerTests(torch._dynamo.testing.TestCase):
class OptimizerTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

@@ -97,6 +98,6 @@ for opt in optimizers:
    setattr(OptimizerTests, "test_" + opt.__name__.lower(), make_test(opt))

if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -4,7 +4,8 @@ from typing import Callable, Dict, List, NamedTuple, Optional
import torch

import torch._dynamo
from torch._dynamo.testing import CompileCounter, same, TestCase
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same

"""
This is an example of a pure-python version of autograd implemented by

@@ -283,6 +284,4 @@ class TestPythonAutograd(TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests

    run_tests()
@@ -6,10 +6,11 @@ import torch

import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing


class RecompileUxTests(torch._dynamo.testing.TestCase):
class RecompileUxTests(torch._dynamo.test_case.TestCase):
    # TODO(whc) dynamo actualy recompiles one more time than the cache limit
    cache_limit = 1
@@ -6,6 +6,7 @@ import unittest

import torch

import torch._dynamo.test_case
import torch._dynamo.testing

try:

@@ -16,7 +17,7 @@ except ImportError:
requires_dill = unittest.skipIf(dill is None, "requires dill")


class ReplayRecordTests(torch._dynamo.testing.TestCase):
class ReplayRecordTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

@@ -181,6 +182,6 @@ class ReplayRecordTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -14,6 +14,7 @@ from unittest.mock import patch
import numpy as np
import torch

import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch import nn

@@ -749,7 +750,7 @@ class TestModule(torch.nn.Module):
        return self.inner_fn(tensor.shape, (1, 2, 3))


class ReproTests(torch._dynamo.testing.TestCase):
class ReproTests(torch._dynamo.test_case.TestCase):
    def test_do_paste_mask(self):
        torch._dynamo.utils.counters.clear()
        opt__do_paste_mask = torch._dynamo.optimize(

@@ -1712,6 +1713,6 @@ class ReproTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -4,10 +4,11 @@ from unittest.mock import patch
import torch

import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.testing import CompileCounter


class SkipNonTensorTests(torch._dynamo.testing.TestCase):
class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
    def test_add_tensor1(self):
        def fn(a, b):
            return a + b

@@ -107,6 +108,6 @@ class SkipNonTensorTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -4,6 +4,7 @@ from unittest.mock import patch

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import unsupported

@@ -17,7 +18,7 @@ def indirectly_unsupported(a, b):
    return unsupported(a, c)


class SubGraphTests(torch._dynamo.testing.TestCase):
class SubGraphTests(torch._dynamo.test_case.TestCase):
    def _common(self, fn, frame_count, op_count):
        torch._dynamo.reset()
        v1 = torch.ones(10)

@@ -528,6 +529,6 @@ class SubGraphTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -7,6 +7,7 @@ from unittest.mock import patch
import numpy as np
import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same

@@ -51,7 +52,7 @@ UnspecNNModuleTests = make_unspec_cls(test_modules.NNModuleTests)


@patch.object(torch._dynamo.config, "specialize_int_float", False)
class UnspecTests(torch._dynamo.testing.TestCase):
class UnspecTests(torch._dynamo.test_case.TestCase):
    def test_numpy_correctness(self):
        def fn(x, y, z):
            xy = [x + y, y, False]

@@ -221,6 +222,6 @@ class UnspecTests(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -8,6 +8,7 @@ import torch

import torch._dynamo
import torch._dynamo.config as config
import torch._dynamo.test_case
from torch._dynamo.optimizations import backends
from torch._dynamo.testing import same

@@ -77,7 +78,7 @@ def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    return gm


class TestVerifyCorrectness(torch._dynamo.testing.TestCase):
class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
    @patch.object(config, "verify_correctness", True)
    def test_example_inputs(self):
        def fn(a, bc, d):

@@ -169,6 +170,6 @@ class TestVerifyCorrectness(torch._dynamo.testing.TestCase):


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -6,6 +6,7 @@ import importlib
import random
import sys
import unittest
import weakref
from unittest.mock import patch

import torch

@@ -19,6 +20,7 @@ from torch.testing._internal.common_utils import (
    TEST_WITH_ASAN,
    TestCase as TorchTestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_unflatten

try:

@@ -74,7 +76,6 @@ if torch.cuda.is_available():
        pass

requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")

torch._inductor.config.triton.autotune = False  # too slow


@@ -242,7 +243,6 @@ def check_model(
    # for graph in exp[2]:
    #     print("Graph", graph)
    assert called, "Ran graph without calling compile_fx"

    assert type(actual) == type(correct)

    correct_flat, correct_spec = tree_flatten(correct)

@@ -2004,7 +2004,11 @@ class CommonTemplate:
    def test_cat(self):
        def fn(a):
            tmp = a * 2
            return torch.cat((a, a[:, :4] + 1, a + 2), -1), torch.cat((tmp, tmp), 0)
            return (
                torch.cat((a, a[:, :4] + 1, a + 2), -1),
                torch.cat((tmp, tmp), 0),
                torch.cat((tmp, tmp.double()), 0),
            )

        self.common(
            fn,

@@ -3171,6 +3175,15 @@ class CommonTemplate:
            ],
        )

    # issue #1150
    def test_dense_mask_index(self):
        def fn(x, y):
            y = torch.ops.aten.select.int(y, 0, 2)
            z = x * y
            return z.sum()

        self.common(fn, [torch.randn(102400), torch.randn(3)])

    def test_new_empty_strided(self):
        def fn(a):
            return aten.new_empty_strided(a, [1, 128, 128], [16384, 128, 1]).fill_(123)

@@ -3714,10 +3727,10 @@ class CommonTemplate:
        )
        compiled = compile_fx_inner(traced, [torch.randn(8, 4, device=self.device)])

        out = compiled(torch.randn(8, 4, device=self.device))
        out = compiled([torch.randn(8, 4, device=self.device)])
        self.assertEqual(out[0].shape, (16, 2))

        out = compiled(torch.randn(12, 4, device=self.device))
        out = compiled([torch.randn(12, 4, device=self.device)])
        self.assertEqual(out[0].shape, (24, 2))

    @requires_cuda()
@@ -3744,6 +3757,63 @@ class CommonTemplate:
        )
        self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))

    @patch.object(config.triton, "mm", "aten")
    def test_list_clearing(self):

        if self.device == "cpu":
            contexts = [contextlib.nullcontext]
        else:
            contexts = [
                contextlib.nullcontext,
                lambda: patch.object(config.triton, "cudagraphs", True),
            ]

        for context in contexts:
            with context():
                inps = [
                    torch.rand([5, 5]).to(self.device),
                    torch.rand([5, 5]).to(self.device),
                ]
                inp_refs = [weakref.ref(inp) for inp in inps]

                def fn(x, y):
                    a = x + y
                    return (a @ a,)

                fn_fx = make_fx(fn)(inps[0], inps[1])
                fn_compiled = compile_fx_inner(fn_fx, inps)

                test_self = self
                matmul_seen = False

                class TestRefMode(TorchDispatchMode):
                    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                        kwargs = kwargs if kwargs else {}

                        nonlocal inps
                        nonlocal inp_refs
                        nonlocal test_self
                        nonlocal matmul_seen

                        # by matmul, inputs should be deallocated
                        if func is aten.mm.out:
                            matmul_seen = True
                            test_self.assertEqual(len(inps), 0)
                            test_self.assertIsNone(inp_refs[0]())
                            test_self.assertIsNone(inp_refs[1]())

                        return func(*args, **kwargs)

                with TestRefMode():
                    fn_compiled(inps)

                # for some reason, TorchDispatch doesnt capture the
                # cuda mm call (even without cudagraphs)
                if self.device == "cpu":
                    self.assertTrue(matmul_seen)
                else:
                    self.assertEqual(len(inps), 0)


if HAS_CPU:

@@ -3781,7 +3851,7 @@ if HAS_CPU:
        fn_fx = make_fx(fn)(x1, y)
        fn_compiled = compile_fx_inner(fn_fx, [x1, y])
        fn(x2, y)
        fn_compiled(x3, y)
        fn_compiled([x3, y])
        assert same(x2, x3)

    def test_no_op_squeeze(self):

@@ -3872,7 +3942,7 @@ if HAS_CUDA:
        ]
        mod = make_fx(forward)(*inps)
        compiled = compile_fx_inner(mod, inps)
        compiled(*inps)
        compiled(inps)

    @patch.object(config, "fallback_random", True)
    def test_dtype_factory_issue(self):

@@ -3888,7 +3958,7 @@ if HAS_CUDA:

        mod = make_fx(forward)()
        compiled = compile_fx_inner(mod, ())
        assert compiled()[0].device.type == "cuda"
        assert compiled([])[0].device.type == "cuda"

    @patch.object(config.triton, "cudagraphs", True)
    def test_expanded_inputs_cudagraphs(self):

@@ -3952,6 +4022,7 @@ if HAS_CUDA:


if __name__ == "__main__":
    from torch._dynamo.testing import run_tests
    from torch._dynamo.test_case import run_tests

    run_tests(needs="filelock")
    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")
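Several call sites above change from compiled(*inps) to compiled(inps): inductor-compiled callables now take a single "boxed" list and may clear it so the caller's references die early, which is what test_list_clearing asserts. A minimal sketch of that calling convention (this is not the generated code itself, only an illustration of the contract):

import torch

def boxed_compiled(args: list):
    x, y = args[0], args[1]
    args.clear()  # drop the caller's list entries so inputs can be freed early
    return (x + y,)

inps = [torch.rand(5, 5), torch.rand(5, 5)]
out = boxed_compiled(inps)
assert len(inps) == 0  # the callee emptied the caller's list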
@@ -33,7 +33,7 @@ try:
        from .test_torchinductor import check_model, check_model_cuda
    except ImportError:
        from test_torchinductor import check_model, check_model_cuda
except (unittest.SkipTest, ImportError) as e:
except (unittest.SkipTest, ImportError, AssertionError) as e:
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)

@@ -154,6 +154,9 @@ inductor_skips["cuda"] = {
    "jiterator_binary": {b8, f16, f32, f64, i32, i64},
    "jiterator_binary_return_by_ref": {b8, f16, f32, f64, i32, i64},
    "jiterator_unary": {b8, f16, f32, f64, i32, i64},
    # Triton bug leads to segfault
    "nn.functional.softplus": {f64},
    "nn.functional.mish": {f64},
}

inductor_expected_failures_single_sample = defaultdict(dict)

@@ -372,24 +375,13 @@ inductor_expected_failures_single_sample["cuda"] = {
inductor_gradient_expected_failures_single_sample = defaultdict(dict)

inductor_gradient_expected_failures_single_sample["cuda"] = {
    "amax": {f16, f32, f64},
    "amin": {f16, f32, f64},
    "asin": {f16},
    "cumprod": {f16},
    "linalg.vector_norm": {f64, f64},
    "linalg.householder_product": {f32},
    "linalg.lu": {f32, f64},
    "kron": {f16},
    "masked.amax": {f16, f32, f64},
    "masked.amin": {f16, f32, f64},
    "max.reduction_no_dim": {f16, f32, f64},
    "median": {f16, f32, f64},
    "min.reduction_no_dim": {f16, f32, f64},
    "nan_to_num": {f16, f32, f64},
    "nanmean": {f16, f32, f64},
    "nanmedian": {f16, f32, f64},
    "nanquantile": {f32, f64},
    "nansum": {f16, f32, f64},
    "native_batch_norm": {f16, f32, f64},
    "native_layer_norm": {f16, f32, f64},
    "nn.functional._scaled_dot_product_attention": {f16},

@@ -446,6 +438,7 @@ inductor_override_kwargs = {
    "new_empty_strided": {"assert_equal": False},
    "randn": {"assert_equal": False},
    ("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001},
    ("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002},
    "gradient": {"check_gradient": False},  # segfault on check_gradient
    # Following tests failed, and causing subsequent tests failing with unrecoverable CUDA error
    "linalg.solve_triangular": {"check_gradient": False},

@@ -461,6 +454,8 @@ inductor_all_samples = {
    "index_copy",
    "scatter_reduce.sum",
    "select_scatter",
    "squeeze",
    "unsqueeze",
}
@@ -54,8 +54,6 @@ constant_functions = {
    torch.onnx.is_in_onnx_export: False,
}

# root folder of the project
base_dir = dirname(dirname(dirname(abspath(__file__))))

# don't specialize on shapes and strides and put shape ops in graph
dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"

@@ -152,6 +150,12 @@ dynamo_import = __name__.replace(".config", "")
# How to import torchinductor, either torchinductor or torch.inductor
inductor_import = dynamo_import.replace("dynamo", "inductor")

# root folder of the project
if "torch." in dynamo_import:
    base_dir = dirname(dirname(dirname(abspath(__file__))))
else:
    base_dir = dirname(dirname(abspath(__file__)))


class _AccessLimitingConfig(ModuleType):
    def __setattr__(self, name, value):
@@ -228,7 +228,7 @@ def save_graph_repro(fd, gm, args, compiler_name):
        textwrap.dedent(
            f"""
                compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args)
                compiled(*args)
                compiled(args)
            """
        )
    )

@@ -293,7 +293,7 @@ def inductor_fails(fx_g, args, check_str=None):

    try:
        compile_mod = compile_fx_inner(fx_g, args)
        compile_mod(*args)
        compile_mod(args)
    except Exception as e:
        if check_str is not None and check_str not in repr(e):
            return False

@@ -385,7 +385,7 @@ def wrap_compiler_debug(compiler_fn, compiler_name: str):
        orig_graph = copy.deepcopy(gm.graph)
        assert config.repro_after in ("dynamo", "aot", None)

        def deferred_for_real_inputs(*real_inputs):
        def deferred_for_real_inputs(real_inputs):
            """
            Aot Autograd fw_compiler and bw_compiler can have fake tensors. So,
            example_inputs can be fake tensors. We can call compiler_fn (which is

@@ -420,14 +420,14 @@ def wrap_compiler_debug(compiler_fn, compiler_name: str):
                    raise ValueError("Bad accuracy detected")
                else:
                    # Call the compiled function with real inputs
                    return compiled_fn(*real_inputs)
                    return compiled_fn(real_inputs)
            else:
                try:
                    # Call the compiler_fn - which is either aot_autograd or inductor
                    # with fake inputs
                    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
                    # Call the compiled function with real inputs
                    return compiled_fn(*real_inputs)
                    return compiled_fn(real_inputs)
                except Exception as e:
                    if config.repro_level == 1:
                        dump_compiler_graph_state(

@@ -441,6 +441,7 @@ def wrap_compiler_debug(compiler_fn, compiler_name: str):

        if config.repro_after == "aot":
            compiled_fn = deferred_for_real_inputs
            compiled_fn._boxed_call = True
        else:
            compiled_fn = compiler_fn(gm, example_inputs, **kwargs)

@@ -453,6 +454,8 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False):
    """
    Runs a forward and possibly backward iteration for a given mod and args.
    """
    from functorch._src.aot_autograd import make_boxed_func

    from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass

    gm = copy.deepcopy(gm)

@@ -465,7 +468,14 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False):

    if hasattr(gm, "zero_grad"):
        gm.zero_grad(True)
    out = gm(*args)

    # TorchInductor returned callable expects lists. So, boxing the call.
    if not hasattr(gm, "_boxed_call"):
        orig_named_parameters = gm.named_parameters
        gm = make_boxed_func(gm)
        gm.named_parameters = orig_named_parameters

    out = gm(args)
    if only_fwd:
        return out
    if requires_bwd_pass(out):

@@ -775,7 +785,7 @@ def wrap_backend_debug(compiler_fn, compiler_name: str):
        else:
            try:
                compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
                run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs))
                run_fwd_maybe_bwd(compiled_gm, example_inputs)
            except Exception as exc:
                log.warning(
                    "Compiled Fx GraphModule failed with following error. Setting up minifier."

@@ -815,7 +825,7 @@ def dynamo_minifier_backend(gm, example_inputs, compiler_name):

    try:
        compiled_gm = compiler_fn(gm, example_inputs)
        run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs))
        run_fwd_maybe_bwd(compiled_gm, example_inputs)
        raise ValueError("No issue was detected")
    except Exception as exc:
        orig_failure = str(exc)
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from .. import config
|
||||
from ..utils import fake_tensors_available
|
||||
from ..utils import clone_inputs, fake_tensors_available
|
||||
|
||||
if fake_tensors_available:
|
||||
from torch._subclasses import FakeTensorMode # noqa: F401
|
||||
@ -52,6 +52,14 @@ class ShapeAliasingAndMutationProp(ShapeProp):
|
||||
n.meta["alias_groups"] = {
|
||||
self.tensor_alias_group(obj) for obj in self.extract_tensors(result)
|
||||
}
|
||||
|
||||
if (
|
||||
not n.meta["alias_groups"]
|
||||
and n.op == "call_function"
|
||||
and n.target == operator.setitem
|
||||
):
|
||||
n.meta["alias_groups"] = {self.tensor_alias_group(tensor_args[0])}
|
||||
|
||||
n.meta["mutates_alias_groups"] = {
|
||||
self.tensor_alias_group(tensor)
|
||||
for tensor, v1, v2 in zip(tensor_args, input_versions1, input_versions2)
|
||||
@ -113,6 +121,10 @@ def has_mutation(gm, example_inputs, inputs_only=False):
|
||||
true, we only check for mutation of inputs"""
|
||||
# TODO - moco gives bad accuracy with Aliasing. gm is getting mutated in a bad way.
|
||||
|
||||
# Clone the inputs such that intermediate tensors (not leaf tensors) with
|
||||
# requires_grad to True are now converted to False to avoid Runtime Error
|
||||
# like "leaf variable that requires grad is inplace modified"
|
||||
example_inputs = clone_inputs(example_inputs)
|
||||
if fake_tensors_available and config.fake_tensor_propagation:
|
||||
with FakeTensorMode() as fake_mode:
|
||||
pass
|
||||
|
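The clone_inputs call guards the mutation analysis against the runtime error quoted in the new comment. A short illustration of the failure mode it avoids (standard PyTorch behaviour, shown here only as a sketch):

import torch

leaf = torch.randn(3, requires_grad=True)
try:
    leaf.add_(1.0)  # in-place on a leaf that requires grad raises a RuntimeError
except RuntimeError as err:
    print("in-place on a leaf fails:", err)

safe_copy = leaf.clone()  # a non-leaf copy can be mutated during analysis
safe_copy.add_(1.0)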
@@ -681,9 +681,7 @@ def tvm_compile_inner(
        elif tuning_option == "meta_schedule":
            from os import path as osp

            from tvm.meta_schedule import TuneConfig
            from tvm.meta_schedule.database import JSONDatabase
            from tvm.meta_schedule.tune import tune_relay
            from tvm.contrib.torch import optimize_torch

            with tempfile.TemporaryDirectory() as work_dir:
                if log_file is not None:

@@ -691,22 +689,16 @@ def tvm_compile_inner(
                        log_file
                    ), "TVM's meta_schedule requires a directory for storing log files."
                    work_dir = log_file
                lib: tvm.runtime.Module = tune_relay(
                    mod=mod,
                    params=params,
                    target=target,
                    config=TuneConfig(
                        strategy="evolutionary",
                        num_trials_per_iter=64,
                        max_trials_per_task=trials,
                        max_trials_global=trials,
                    ),

                lib = optimize_torch(
                    jit_mod,
                    example_inputs,
                    max_trials_global=20000,
                    work_dir=work_dir,
                    database=JSONDatabase(
                        osp.join(work_dir, "workload.json"),
                        osp.join(work_dir, "records.json"),
                    ),
                    target=target,
                    max_trials_per_task=64,
                )

        elif tuning_option is None:
            # no autotuning (for debugging)
            with tvm.transform.PassContext(opt_level=10):

@@ -716,34 +708,41 @@ def tvm_compile_inner(
                "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. "
                "There are three available options including None, auto_scheduler and meta_schedule."
            )
        if tune_option != "meta_schedule":
            m = graph_executor.GraphModule(lib["default"](dev))

        m = graph_executor.GraphModule(lib["default"](dev))

        def to_torch_tensor(nd_tensor):
            """A helper function to transfer a NDArray to torch.tensor."""
            if nd_tensor.dtype == "bool":
                # DLPack does not support boolean so it can't be handled by
                # torch.utils.dlpack.from_pack. Workaround by going through
                # numpy, although this brings additional data copy overhead.
                return torch.from_numpy(nd_tensor.numpy())
            return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())

            def to_torch_tensor(nd_tensor):
                """A helper function to transfer a NDArray to torch.tensor."""
                if nd_tensor.dtype == "bool":
                    # DLPack does not support boolean so it can't be handled by
                    # torch.utils.dlpack.from_pack. Workaround by going through
                    # numpy, although this brings additional data copy overhead.
                    return torch.from_numpy(nd_tensor.numpy())
                return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())

        def exec_tvm(*args):
            args = [a.contiguous() for a in args]
            for idx, arg in enumerate(args, 0):
                if arg.dim() != 0:
                    if arg.requires_grad:
                        arg = arg.detach()
                    m.set_input(
                        f"inp_{idx}",
                        tvm.nd.array(arg.numpy(), dev),
                    )
            m.run()
            return [
                to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())
            ]

            def exec_tvm(*args):
                args = [a.contiguous() for a in args]
                for idx, arg in enumerate(args, 0):
                    if arg.dim() != 0:
                        if arg.requires_grad:
                            arg = arg.detach()
                        m.set_input(
                            f"inp_{idx}",
                            tvm.nd.array(arg.numpy(), dev),
                        )
                m.run()
                return [
                    to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())
                ]

        else:

            def exec_tvm(*args):
                args = [a.contiguous() for a in args]
                return lib(*args)

        return exec_tvm

    except Exception:
        log.exception("tvm error")
        return jit_mod  # explicit fall back to eager
@@ -43,14 +43,14 @@ class DDPOptimizer:
        bucket_actual_sizes = []
        node_splits = [[]]
        for node in reversed(gm.graph.nodes):
            if node.op == "output" or node.op == "placeholder":
                continue

            if bucket_bytes >= self.bucket_bytes_cap:
                bucket_actual_sizes.insert(0, bucket_bytes)
                bucket_bytes = 0
                node_splits.insert(0, [])

            if node.op == "output" or node.op == "placeholder":
                continue

            elif node.op == "call_module":
                target = gm.get_submodule(node.target)
                params_size_b = sum(

@@ -62,6 +62,10 @@ class DDPOptimizer:
                )
                bucket_bytes += params_size_b
                # print(f"accumulated {params_size_b} b from {node}")
            elif node.op == "get_attr":
                maybe_param = getattr(gm, node.target)
                if maybe_param.requires_grad:
                    bucket_bytes += maybe_param.storage().nbytes()
            else:
                # TODO(whc) confirm this:
                # (e.g. call_method, call_function aren't expected to 'have' parameters)
|
||||
return wrap_name(k)
|
||||
|
||||
# create a new unique name
|
||||
name = re.sub(r"[^a-zA-Z0-9]", "_", "_".join(map(str, names)))
|
||||
name = "_".join(map(str, names))
|
||||
# e.g. repalce abc.xyz[123].qkv with abc.xyz_123.qkv
|
||||
name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
|
||||
# e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
|
||||
name = re.sub(r"[^a-zA-Z0-9]", "_", name)
|
||||
|
||||
if not name or not name[0].isalpha():
|
||||
name = "sub" + name
|
||||
base = name
|
||||
|
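The two regex passes replace the old single substitution so that indexed sub-module paths stay readable. Applied to an illustrative path (a sketch, not code from the commit):

import re

name = "abc.xyz[123].qkv"
name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)  # -> abc.xyz_123.qkv
name = re.sub(r"[^a-zA-Z0-9]", "_", name)     # -> abc_xyz_123_qkv
print(name)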
torch/_dynamo/test_case.py (new file, 71 lines)
@@ -0,0 +1,71 @@
import contextlib
import importlib
import sys
from unittest.mock import patch

import torch
import torch.testing
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    TEST_WITH_CROSSREF,
    TEST_WITH_ROCM,
    TEST_WITH_TORCHDYNAMO,
    TestCase as TorchTestCase,
)

from . import config, reset, utils


def run_tests(needs=()):
    from torch.testing._internal.common_utils import run_tests

    if (
        TEST_WITH_TORCHDYNAMO
        or IS_WINDOWS
        or TEST_WITH_CROSSREF
        or TEST_WITH_ROCM
        or sys.version_info >= (3, 11)
    ):
        return  # skip testing

    if isinstance(needs, str):
        needs = (needs,)
    for need in needs:
        if need == "cuda" and not torch.cuda.is_available():
            return
        else:
            try:
                importlib.import_module(need)
            except ImportError:
                return
    run_tests()


class TestCase(TorchTestCase):
    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack = contextlib.ExitStack()
        cls._exit_stack.enter_context(
            patch.object(config, "raise_on_backend_error", True)
        )
        cls._exit_stack.enter_context(
            patch.object(config, "raise_on_ctx_manager_usage", True)
        )

    def setUp(self):
        super().setUp()
        reset()
        utils.counters.clear()

    def tearDown(self):
        for k, v in utils.counters.items():
            print(k, v.most_common())
        reset()
        utils.counters.clear()
        super().tearDown()
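A hedged sketch of how the migrated test files use this new module (the backend string and the toy test are assumptions for illustration, not part of the commit):

import torch
import torch._dynamo
import torch._dynamo.test_case


class ExampleDynamoTests(torch._dynamo.test_case.TestCase):
    def test_add_one(self):
        fn = torch._dynamo.optimize("eager")(lambda x: x + 1)
        self.assertTrue(torch.equal(fn(torch.ones(2)), torch.ones(2) + 1))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()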
@@ -1,19 +1,16 @@
import contextlib
import dis
import functools
import importlib
import logging
import os.path
import sys
import types
import unittest
from unittest.mock import patch

import torch
import torch.testing._internal.common_utils
from torch import fx

from . import config, eval_frame, optimize_assert, reset, utils
from . import config, eval_frame, optimize_assert, reset
from .bytecode_transformation import (
    create_instruction,
    debug_checks,

@@ -29,37 +26,6 @@ three = 3
log = logging.getLogger(__name__)


def run_tests(needs=()):
    return  # TEMPORARY: disable all tests

    from torch.testing._internal.common_utils import (
        IS_WINDOWS,
        run_tests,
        TEST_WITH_CROSSREF,
        TEST_WITH_TORCHDYNAMO,
    )

    if (
        TEST_WITH_TORCHDYNAMO
        or IS_WINDOWS
        or TEST_WITH_CROSSREF
        or sys.version_info >= (3, 11)
    ):
        return  # skip testing

    if isinstance(needs, str):
        needs = (needs,)
    for need in needs:
        if need == "cuda" and not torch.cuda.is_available():
            return
        else:
            try:
                importlib.import_module(need)
            except ImportError:
                return
    run_tests()


def clone_me(x):
    if x is None:
        return None

@@ -229,36 +195,6 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None)
    self.assertEqual(actual.op_count, expected_ops)


class TestCase(torch.testing._internal.common_utils.TestCase):
    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack = contextlib.ExitStack()
        cls._exit_stack.enter_context(
            patch.object(config, "raise_on_backend_error", True)
        )
        cls._exit_stack.enter_context(
            patch.object(config, "raise_on_ctx_manager_usage", True)
        )

    def setUp(self):
        super().setUp()
        reset()
        utils.counters.clear()

    def tearDown(self):
        for k, v in utils.counters.items():
            print(k, v.most_common())
        reset()
        utils.counters.clear()
        super().tearDown()


def dummy_fx_compile(gm: fx.GraphModule, example_inputs):
    return gm.forward
|
||||
from .lists import ShapeVariable, SizeVariable
|
||||
|
||||
|
||||
class _missing:
|
||||
pass
|
||||
|
||||
|
||||
class TensorVariable(VariableTracker):
|
||||
"""A torch.Tensor input or an intermediate value in the FX graph"""
|
||||
|
||||
@ -189,8 +193,9 @@ class TensorVariable(VariableTracker):
|
||||
elif istype(example_value, int) and proxy.node.target in (
|
||||
torch.seed,
|
||||
operator.mod,
|
||||
torch.distributed.get_rank,
|
||||
torch.distributed.get_world_size,
|
||||
# some mac builds are missing torch.distributed.get_rank()
|
||||
getattr(torch.distributed, "get_rank", _missing),
|
||||
getattr(torch.distributed, "get_world_size", _missing),
|
||||
):
|
||||
proxy.node.meta["example_value"] = example_value
|
||||
return DynamicShapeVariable(proxy, type(example_value), **options)
|
||||
|
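The _missing sentinel keeps the membership test valid on builds where torch.distributed lacks these functions. The pattern in isolation (a self-contained sketch; _FakeDistributed is a made-up stand-in):

class _missing:
    pass


class _FakeDistributed:
    pass  # stand-in for a torch.distributed build without get_rank/get_world_size


targets = (
    getattr(_FakeDistributed, "get_rank", _missing),
    getattr(_FakeDistributed, "get_world_size", _missing),
)
assert all(t is _missing for t in targets)  # lookups degrade gracefully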
@ -654,6 +654,7 @@ class TritonKernel(Kernel):
|
||||
def indexing(
|
||||
self,
|
||||
index: sympy.Expr,
|
||||
*,
|
||||
copy_shape=None,
|
||||
dense_indexing=False,
|
||||
):
|
||||
@ -686,9 +687,11 @@ class TritonKernel(Kernel):
|
||||
mask.append(f"{tree.prefix}mask")
|
||||
dense_mask.append(f"{tree.prefix}mask")
|
||||
|
||||
if (need_dense and not have_dense) or index == 0:
|
||||
if (need_dense and not have_dense) or isinstance(
|
||||
index, sympy.core.numbers.Integer
|
||||
):
|
||||
index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
|
||||
if index == 0:
|
||||
if isinstance(index, sympy.core.numbers.Integer):
|
||||
return index_str, "None"
|
||||
else:
|
||||
mask = dense_mask
|
||||
@ -779,7 +782,7 @@ class TritonKernel(Kernel):
|
||||
|
||||
def store(self, name, index, value, mode=None):
|
||||
var = self.args.output(name)
|
||||
index, mask = self.indexing(index, value, dense_indexing=True)
|
||||
index, mask = self.indexing(index, dense_indexing=True)
|
||||
if mode is None:
|
||||
line = f"tl.store({var} + ({index}), {value}, {mask})"
|
||||
elif mode == "atomic_add":
|
||||
@ -861,7 +864,7 @@ class TritonKernel(Kernel):
|
||||
var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)]
|
||||
self.suffix.writeline(f"{result_var} = {var_name}")
|
||||
self.inside_reduction = False
|
||||
index, mask = self.indexing(index, result_var)
|
||||
index, mask = self.indexing(index)
|
||||
assert "rmask" not in index
|
||||
self.inside_reduction = True
|
||||
self.outside_loop_vars.add(result_var)
|
||||
|
@ -77,7 +77,9 @@ class TritonTemplateKernel(TritonKernel):
|
||||
|
||||
def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=True):
|
||||
# use dense_indexing for TritonTemplateKernel to avoid map::at error
|
||||
return super().indexing(index, copy_shape, dense_indexing)
|
||||
return super().indexing(
|
||||
index, copy_shape=copy_shape, dense_indexing=dense_indexing
|
||||
)
|
||||
|
||||
def codegen_body(
|
||||
self, name, fuse, could_remove_kernel_buf, kernel_buf_replace_name
|
||||
|
@ -217,15 +217,20 @@ class WrapperCodeGen(CodeGen):
|
||||
)
|
||||
|
||||
self.prefix.splice(
|
||||
f"""
|
||||
"""
|
||||
|
||||
async_compile.wait(globals())
|
||||
del async_compile
|
||||
|
||||
def call({', '.join(V.graph.graph_inputs.keys())}):
|
||||
def call(args):
|
||||
"""
|
||||
)
|
||||
with self.prefix.indent():
|
||||
inp_len = len(V.graph.graph_inputs.keys())
|
||||
if inp_len != 0:
|
||||
lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
|
||||
self.prefix.writeline(f"{lhs} = args")
|
||||
self.prefix.writeline("args.clear()")
|
||||
for name in V.graph.randomness_seeds:
|
||||
self.prefix.writeline(
|
||||
f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})"
|
||||
@ -275,6 +280,12 @@ class WrapperCodeGen(CodeGen):
|
||||
|
||||
def codegen_free(self, buffer):
|
||||
name = buffer.get_name()
|
||||
|
||||
# can be freed but not reused
|
||||
if isinstance(buffer, ir.InputBuffer):
|
||||
self.writeline(f"del {name}")
|
||||
return
|
||||
|
||||
if not self.can_reuse(buffer):
|
||||
return
|
||||
self.freed.add(name)
|
||||
@ -380,7 +391,7 @@ class WrapperCodeGen(CodeGen):
|
||||
)
|
||||
|
||||
output.writeline(
|
||||
f"print_performance(lambda: call({', '.join(V.graph.graph_inputs.keys())}))"
|
||||
f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
|
||||
)
|
||||
|
||||
def define_kernel(self, name: str, kernel: str):
|
||||
|
@ -5,7 +5,8 @@ import logging
from typing import List

import functorch
from functorch.compile import make_boxed_compiler, min_cut_rematerialization_partition
from functorch._src.aot_autograd import make_boxed_func
from functorch.compile import min_cut_rematerialization_partition

import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
@ -86,7 +87,7 @@ def compile_fx_inner(
graph_id=None,
):
if dynamo_utils.count_calls(gm.graph) == 0:
return gm
return make_boxed_func(gm.forward)

_step_logger()(
logging.INFO,
@ -137,6 +138,9 @@ def compile_fx_inner(
f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
f"graph {graph_id}",
)

# aot autograd needs to know to pass in inputs as a list
result._boxed_call = True
return result

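Both changes in `compile_fx_inner` serve the same purpose: the callables handed back to AOTAutograd are now "boxed", meaning they are called with one mutable argument list rather than `*args`, either by wrapping the untouched graph with `make_boxed_func` or by tagging the compiled result with `_boxed_call = True`. A minimal sketch of what the convention means (simplified, not the functorch source):

```python
def make_boxed(f):
    # Wrap a normal *args function so callers hand over a single argument list.
    def boxed(args):
        return f(*args)

    boxed._boxed_call = True
    return boxed


def forward(x, y):
    return x + y


fn = make_boxed(forward)
print(fn([1, 2]))  # 3; callers check `_boxed_call` to decide how to invoke fn
```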
@ -159,14 +163,15 @@ def align_inputs(model, inputs, static_input_idxs=()):
if len(check_inputs) == 0:
return model

def run(*new_inputs):
def run(new_inputs):
for i in check_inputs:
if new_inputs[i].data_ptr() % ALIGNMENT:
if isinstance(new_inputs, tuple):
new_inputs = list(new_inputs)
new_inputs[i] = clone_preserve_strides(new_inputs[i])
new_inputs = [x.to("cuda") if is_unspec_input(x) else x for x in new_inputs]
return model(*new_inputs)
new_inputs_to_cuda = [
x.to("cuda") if is_unspec_input(x) else x for x in new_inputs
]
new_inputs.clear()
return model(new_inputs_to_cuda)

return run

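Under the list convention, `run` fixes misaligned tensors in place in the incoming list, moves any unspecialized inputs to the GPU, and clears the list before calling the model so the caller's references die as early as possible. A standalone sketch of the alignment part (the constant and the clone helper are stand-ins for the real ones):

```python
import torch

ALIGNMENT = 16  # assumed value for this sketch


def clone_preserve_strides(x):
    # stand-in for the real helper; a plain clone is enough to illustrate the idea
    return x.clone()


def run(model, new_inputs, check_inputs):
    for i in check_inputs:
        # copy tensors whose storage is not suitably aligned into a fresh buffer
        if new_inputs[i].data_ptr() % ALIGNMENT:
            new_inputs[i] = clone_preserve_strides(new_inputs[i])
    moved = list(new_inputs)
    new_inputs.clear()  # release the caller's references before running the model
    return model(moved)
```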
@ -179,13 +184,13 @@ def cudagraphify(model, inputs, static_input_idxs=()):

compiled_fn = None

def run(*new_inputs):
def run(new_inputs):
nonlocal compiled_fn
if compiled_fn is None:
with dynamo_utils.preserve_rng_state():
compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)

return compiled_fn(*new_inputs)
return compiled_fn(new_inputs)

return run

@ -239,8 +244,9 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()):
torch.cuda.synchronize()
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
# copy static_inputs because it will be cleared in model
with torch.cuda.stream(stream):
model(*static_inputs)
model(list(static_inputs))
stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()
@ -248,13 +254,13 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()):
# record
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
static_outputs = model(*static_inputs)
static_outputs = model(list(static_inputs))
if not isinstance(static_outputs, (list, tuple)):
static_outputs = (static_outputs,)

if config.size_asserts:

def run(*new_inputs):
def run(new_inputs):
assert len(static_inputs) == len(new_inputs)
for idx, (dst, src, expanded_dims) in enumerate(
zip(static_inputs, new_inputs, inps_expanded_dims)
@ -268,6 +274,7 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()):
dst = index_expanded_dims(dst, expanded_dims)
src = index_expanded_dims(src, expanded_dims)
dst.copy_(src)
new_inputs.clear()
graph.replay()
return static_outputs

@ -276,11 +283,12 @@ def cudagraphify_impl(model, inputs, static_input_idxs=()):
idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
]

def run(*new_inputs):
def run(new_inputs):
for idx in copy_indices:
src = index_expanded_dims(static_inputs[idx], inps_expanded_dims[idx])
dst = index_expanded_dims(new_inputs[idx], inps_expanded_dims[idx])
dst.copy_(src)
new_inputs.clear()
graph.replay()
return static_outputs

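Every `run` variant now follows the same shape: copy the fresh inputs into the buffers that were captured with the graph, clear the incoming list, and replay. A compact sketch of that idea (it assumes a CUDA device and skips the warm-up and expanded-dims handling of the real code):

```python
import torch


def cudagraph_replay_sketch(model, example_inputs):
    static_inputs = [x.clone() for x in example_inputs]
    # a real implementation warms the model up on a side stream before capture
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_outputs = model(*static_inputs)

    def run(new_inputs):
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)   # refresh the captured buffers
        new_inputs.clear()   # let the caller's tensors be freed
        graph.replay()
        return static_outputs

    return run
```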
@ -359,8 +367,8 @@ def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]
return aot_autograd(
model_,
example_inputs_,
fw_compiler=make_boxed_compiler(fw_compiler),
bw_compiler=make_boxed_compiler(bw_compiler),
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
decompositions=select_decomp_table(),
partition_fn=functools.partial(
min_cut_rematerialization_partition, compiler="inductor"
@ -3,8 +3,6 @@ import logging
import math
import numbers

from functorch._src.aot_autograd import aot_autograd_decompositions

import torch
import torch._decomp as decomp
from torch import Tensor
@ -98,9 +96,10 @@ decompositions = get_decompositions(
aten.tril.default,
aten.upsample_bilinear2d.vec,
aten.upsample_nearest2d_backward,
aten.softplus,
aten.softplus_backward,
]
)
decompositions.update(aot_autograd_decompositions)


def register_decomposition(ops):
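With the functorch-specific decomposition set dropped, `aten.softplus` and `aten.softplus_backward` are pulled from the shared table instead. For orientation, softplus decomposes to roughly the following reference (a hedged restatement of the ATen semantics, not the registered decomposition itself):

```python
import torch


def softplus_ref(x, beta=1.0, threshold=20.0):
    scaled = x * beta
    # above the threshold softplus is treated as linear, which avoids exp() overflow
    return torch.where(scaled > threshold, x, torch.log1p(torch.exp(scaled)) / beta)


x = torch.randn(8)
torch.testing.assert_close(softplus_ref(x), torch.nn.functional.softplus(x))
```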
@ -154,6 +154,11 @@ def _register_lowering(
@functools.wraps(decomp_fn)
def wrapped(*args, **kwargs):
args = list(args)
unpacked = False
# TODO maybe we need to use pytrees here
if len(args) == 1 and isinstance(args[0], (list, tuple)):
unpacked = True
args = args[0]
# Only look at args that are Tensors
indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
# kwargs tensors not supported yet
@ -170,14 +175,20 @@ def _register_lowering(
dtype = get_promoted_dtype(
*promoting_args, type_promotion_kind=type_promotion_kind
)
for i in indices:
args[i] = to_dtype(args[i], dtype)
# sometimes args are an immutable list so we can't mutate them
new_args = []
for i in range(len(args)):
if isinstance(args[i], ir.Constant):
args[i] = ir.Constant(
args[i].value, dtype, args[indices[0]].get_device()
if i in indices:
new_args.append(to_dtype(args[i], dtype))
elif isinstance(args[i], ir.Constant):
new_args.append(
ir.Constant(args[i].value, dtype, args[indices[0]].get_device())
)

else:
new_args.append(args[i])
args = new_args
if unpacked:
args = [args]
if broadcast and indices:
for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
args[i] = x
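Two things happen in this lowering wrapper: a single list or tuple argument is unpacked (and re-packed at the end), and type promotion now builds a fresh `new_args` list instead of assigning into `args`, which may be an immutable container. A reduced sketch of the rebuild pattern:

```python
def promote(args, indices, convert):
    # never mutate `args`: it may be a tuple or some other immutable sequence
    new_args = []
    for i, a in enumerate(args):
        new_args.append(convert(a) if i in indices else a)
    return new_args


args = (1, 2.5, 3)                   # immutable input
print(promote(args, {0, 2}, float))  # [1.0, 2.5, 3.0]
```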
@ -475,12 +486,13 @@ def squeeze(x, dim=None):
assert isinstance(x, TensorBox)
if dim is None:
return TensorBox(SqueezeView.create(x.data))

dim = _validate_dim(x, dim, 0)
offset = len(x.get_size()) == 0
dim = _validate_dim(x, dim, offset)
new_shape = list(x.get_size())
removed = new_shape.pop(dim)
if V.graph.sizevars.maybe_guard_equals(removed, 1):
return view(x, new_shape)
if len(new_shape) > 0:
removed = new_shape.pop(dim)
if V.graph.sizevars.maybe_guard_equals(removed, 1):
return view(x, new_shape)

# squeeze does nothing if the size isn't 1
return x
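The extra `offset` mirrors eager semantics: a 0-d tensor accepts `squeeze(0)` (or `-1`) and comes back unchanged, so dimension validation has to allow one index past the empty shape and the `pop` has to be guarded. Eager behaviour, for comparison:

```python
import torch

scalar = torch.tensor(3.0)                 # 0-d tensor, shape ()
print(scalar.squeeze(0).shape)             # torch.Size([]); allowed, a no-op
print(torch.zeros(1, 3).squeeze(0).shape)  # torch.Size([3])
```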
@ -988,6 +988,11 @@ class Scheduler:
node = self.name_to_node[name]
if node.can_free():
V.graph.wrapper_code.codegen_free(node.node)
elif name in V.graph.graph_inputs:
storage = V.graph.graph_inputs[name].data
assert storage.is_input_buffer()
V.graph.wrapper_code.codegen_free(storage.data)

self.buffer_names_to_free.clear()

def remove_kernel_local_buffers(self):
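Together with the `codegen_free` change in the wrapper, dead graph inputs are now released too: when a name to free refers to a graph input rather than a scheduler node, the scheduler frees the input's underlying storage, which the wrapper will only `del` and never recycle. A reduced sketch of the dispatch:

```python
def free_dead_buffers(names, name_to_node, graph_inputs, codegen_free):
    for name in names:
        if name in name_to_node:
            node = name_to_node[name]
            if node.can_free():
                codegen_free(node.node)
        elif name in graph_inputs:
            # graph inputs are freed through their storage and never reused
            codegen_free(graph_inputs[name].data)
    names.clear()
```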
@ -5,6 +5,7 @@ import json
import logging
import multiprocessing
import os.path
import re
import threading
from typing import List

@ -110,11 +111,12 @@ class CachingAutotuner(KernelInterface):
# set_device(current_device()) # TODO(jansel): is this needed?
grid_0, grid_1, grid_2 = grid(grid_meta)
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
stream, bin.cu_function, None, None, None,
{', '.join(call_args)})
stream, bin.cu_function, None, None, None,
{', '.join(call_args)})
""".lstrip(),
scope,
)

launcher = scope["launcher"]
launcher.config = cfg
return launcher
@ -160,11 +162,22 @@ class CachingAutotuner(KernelInterface):
launcher.config.pre_hook(
{**zip(self.arg_names, args), **launcher.config.kwargs}
)
return launcher(
*args,
grid=grid,
stream=stream,
)
try:
result = launcher(
*args,
grid=grid,
stream=stream,
)
except TypeError as e:
if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)):
raise RuntimeError(
"""Consider updating Triton with
`pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`"""
)
else:
raise e

return result


def hash_configs(configs: List[Config]):
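The new `re` import above exists for this check: when a generated launcher is called with an argument count that the installed Triton build does not expect, the failure surfaces as a generic `TypeError`, so the autotuner matches that message and re-raises with an upgrade hint. The pattern in isolation (a sketch, with a placeholder message):

```python
import re


def call_with_upgrade_hint(launcher, *args, **kwargs):
    try:
        return launcher(*args, **kwargs)
    except TypeError as e:
        # an arity mismatch from an older Triton build looks like a plain TypeError
        if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)):
            raise RuntimeError("Consider updating Triton") from e
        raise
```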