Compare commits

...

4 Commits

Author SHA1 Message Date
bbf5ebbd4b [compiled autograd] torch.compile API
ghstack-source-id: 28065ffb2ed01641b4dcd31fb8fc0e729192f9ec
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125880
2024-05-16 01:00:03 -07:00
463913e679 [compiled autograd] clear compiled_autograd_verbose once test is done
ghstack-source-id: f817bd618a06b97a86ef1262dd457cf19879c548
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126148
2024-05-14 08:30:58 -07:00
643f57a782 [inductor] Clear cache on ctx manager exit
ghstack-source-id: 9146aea2680868b25af31a9271a7aa0a396668af
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126146
2024-05-14 08:30:57 -07:00
f13bfd8d87 [compiled autograd] Fix flaky tests
ghstack-source-id: 9e999edf4e9a1e41c381fdf20063338a6eb2f313
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126144
2024-05-14 08:30:57 -07:00
7 changed files with 231 additions and 6 deletions

View File

@ -4934,6 +4934,22 @@ def forward(self, primals_1, primals_2):
        opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager")
        self.assertEqual(opt_ladder(data), ladder(data))

    def test_issue126128(self):
        def fn():
            x = torch.randn(1, 10)
            y = torch.randn(10, 1)
            return torch.mm(x, y).sum()

        def fn2():
            x = torch.randn(10, 100)
            y = torch.randn(100, 10)
            return torch.mm(x, y).sum()

        with torch._inductor.utils.fresh_inductor_cache():
            torch.compile(fn)()

        torch.compile(fn2)()


instantiate_parametrized_tests(ReproTests)

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: inductor"]
import functools
import logging
import re
import sys
import unittest
@ -10,7 +11,7 @@ from unittest import mock
import torch
import torch.nn as nn
from torch import _inductor as inductor
from torch._dynamo import compiled_autograd
from torch._dynamo import compiled_autograd, config
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
@ -51,6 +52,18 @@ def hook3(gI, gO):
class TestCompiledAutograd(TestCase):
    def setUp(self) -> None:
        super().setUp()
        torch._logging.set_logs(compiled_autograd_verbose=False)
        config.compiled_autograd = False
        compiled_autograd.reset()

    def tearDown(self) -> None:
        super().tearDown()
        torch._logging.set_logs(compiled_autograd_verbose=False)
        config.compiled_autograd = False
        compiled_autograd.reset()

    def check_output_and_recompiles(
        self, fn, count=1, compiler_fn=compiler_fn, compile_fn=False
    ):
@ -221,6 +234,115 @@ main()
        self.check_output_and_recompiles(fn)

    def test_torch_compile_api_inductor(self):
        def fn():
            torch.manual_seed(123)
            model = torch.nn.Sequential(
                torch.nn.Linear(4, 4),
                torch.nn.Sigmoid(),
            )
            res = []
            for _ in range(3):
                x = torch.randn([1, 4])
                result = model(x).sum()
                result.backward()
                res.append(model[0].weight.grad)
                res.append(model[0].bias.grad)
                model.zero_grad()
            return res

        expected = fn()
        with config.patch(compiled_autograd=True):
            compiled_fn = torch.compile(fn)
        actual = compiled_fn()
        self.assertEqual(expected, actual)
        self.assertEqual(counters["compiled_autograd"]["captures"], 1)

    def test_torch_compile_api_aot_eager(self):
        def fn():
            torch.manual_seed(123)
            model = torch.nn.Sequential(
                torch.nn.Linear(4, 4),
                torch.nn.Sigmoid(),
            )
            res = []
            for _ in range(3):
                x = torch.randn([1, 4])
                result = model(x).sum()
                result.backward()
                res.append(model[0].weight.grad)
                res.append(model[0].bias.grad)
                model.zero_grad()
            return res

        expected = fn()
        with config.patch(compiled_autograd=True):
            compiled_fn = torch.compile(fn, backend="aot_eager")
        actual = compiled_fn()
        self.assertEqual(expected, actual)
        self.assertEqual(counters["compiled_autograd"]["captures"], 1)

    def test_torch_compile_api_eager(self):
        def fn():
            torch.manual_seed(123)
            model = torch.nn.Sequential(
                torch.nn.Linear(4, 4),
                torch.nn.Sigmoid(),
            )
            res = []
            for _ in range(3):
                x = torch.randn([1, 4])
                result = model(x).sum()
                result.backward()
                res.append(model[0].weight.grad)
                res.append(model[0].bias.grad)
                model.zero_grad()
            return res

        expected = fn()
        with config.patch(compiled_autograd=True):
            compiled_fn = torch.compile(fn, backend="eager")
        actual = compiled_fn()
        self.assertEqual(expected, actual)
        self.assertEqual(counters["compiled_autograd"]["captures"], 1)

    def test_multiple_torch_compile(self):
        model = torch.nn.Sequential(
            torch.nn.Linear(4, 4),
            torch.nn.Sigmoid(),
        )
        x = torch.randn([1, 4])

        def fn():
            result = model(x).sum()
            result.backward()

        model2 = torch.nn.Linear(4, 4)
        x2 = torch.randn([1, 4])

        def fn2():
            result = model2(x2).sum()
            result.backward()

        no_ca1 = torch.compile(fn)
        no_ca1()
        self.assertEqual(counters["compiled_autograd"]["captures"], 0)
        counters.clear()

        with config.patch(compiled_autograd=True):
            withca = torch.compile(fn2)
        withca()
        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
        counters.clear()

        no_ca2 = torch.compile(fn)
        no_ca2()
        self.assertEqual(counters["compiled_autograd"]["captures"], 0)

    def test_dynamo_boxed(self):
        def get_placeholders(gm_):
            placeholders = []
@ -322,6 +444,7 @@ main()
        handle.remove()

    def test_inputs_aliasing_bytecode_stack_restore(self):
        logging.getLogger().setLevel(logging.WARNING)
        from torch.testing._internal.logging_tensor import LoggingTensor

        # Create a graph that allows inputs stealing
@ -752,6 +875,52 @@ main()
        self.check_output_and_recompiles(fn, count=2)

    @unittest.skipIf(not HAS_CUDA, "requires cuda")
    def test_logging_tensor_flaky(self) -> None:
        # This reproduces a flaky failure seen when a test that uses triton runs
        # first and then test_inputs_aliasing_bytecode_stack_restore runs, resulting in:
        # - pytest: `TypeError: unsupported operand type(s) for +: 'Tensor' and 'LoggingTensor'`
        # - python: `TypeError: not all arguments converted during string formatting`

        # 1. some test involving triton
        def fn():
            def _fn(x):
                return x

            x = torch.arange(
                1, 10, requires_grad=True, dtype=torch.float16, device="cuda"
            )
            out = _fn(x)
            loss = out.sum()
            loss.backward()

        with compiled_autograd.enable(compiler_fn):
            fn()

        logging.getLogger().setLevel(
            logging.WARNING
        )  # triton setup overwrote it to INFO

        # 2. test_inputs_aliasing_bytecode_stack_restore
        from torch.testing._internal.logging_tensor import LoggingTensor

        def forward(inputs):
            add = inputs[0] + 1
            add_1 = add + inputs[1]
            out = add_1.cpu()
            return (out,)

        gm = torch.fx.symbolic_trace(forward)
        print(gm.print_readable())
        torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"])
        compiled_fn = torch.compile(gm)

        inputs = [
            torch.ones(1000000, dtype=torch.float32),
            LoggingTensor(torch.ones(1)),
        ]

        compiled_fn(inputs)

    @unittest.skipIf(not HAS_CUDA, "requires cuda")
    def test_custom_fn_output_metadata(self):
        def my_compiler_fn(gm):
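Note that test_logging_tensor_flaky above drives compiled autograd through the pre-existing compiled_autograd.enable(...) context manager rather than the new config flag. A hedged sketch of that lower-level entry point, with a toy model and compiler that are illustrative only and not part of this diff:

import torch
from torch._dynamo import compiled_autograd

def compiler_fn(gm):
    # Compile the captured backward graph like any other FX graph.
    return torch.compile(gm, backend="eager")

model = torch.nn.Linear(4, 4)

with compiled_autograd.enable(compiler_fn):
    model(torch.randn(1, 4)).sum().backward()  # backward graph is captured and compiled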

View File

@ -319,3 +319,11 @@ def disable():
if prior:
compiled_autograd_enabled = True
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
# return to starting state of a new process
def reset() -> None:
    global compiled_autograd_enabled
    compiled_autograd_enabled = False
    assert compiled_autograd_enabled_count == 0
    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
    torch._C._dynamo.compiled_autograd.set_verbose_logging(False)

View File

@ -438,6 +438,10 @@ fake_tensor_cache_crosscheck_enabled = (
# WARNING: this is an experimental flag and is subject to change.
_experimental_support_context_fn_in_torch_utils_checkpoint = False
# Enables the Compiled Autograd engine to trace .backward() calls made under torch.compile().
# Note: AOT Autograd will still trace joint graphs.
compiled_autograd = False
if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
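A minimal sketch of how this flag is meant to be exercised, mirroring the tests in this stack; the model and input shapes below are arbitrary and not part of the diff:

import torch
from torch._dynamo import config

model = torch.nn.Linear(4, 4)

def train_step(x):
    loss = model(x).sum()
    loss.backward()  # captured by Compiled Autograd when the flag is set
    return loss

# Patch the flag so that .backward() inside the compiled region is traced too.
with config.patch(compiled_autograd=True):
    compiled_step = torch.compile(train_step)
    compiled_step(torch.randn(2, 4))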

View File

@ -490,6 +490,7 @@ class OptimizeContext(_TorchDynamoContext):
        export=False,
        dynamic=None,
        compiler_config=None,
        rebuild_ctx: Optional[Callable[[Callable], OptimizeContext]] = None,
    ):
        def on_enter():
            install_generation_tagging_init()
@ -505,6 +506,17 @@ class OptimizeContext(_TorchDynamoContext):
            compiler_config=compiler_config,
        )

        if config.compiled_autograd:
            assert rebuild_ctx

            def call_compiled_autograd():
                compiler_fn = rebuild_ctx()
                ctx = torch._dynamo.compiled_autograd.enable(compiler_fn)
                ctx.__enter__()
                return functools.partial(ctx.__exit__, None, None, None)

            self.enter_exit_hooks.append(call_compiled_autograd)


class RunOnlyContext(_TorchDynamoContext):
    def __init__(self):
@ -527,6 +539,7 @@ def _optimize_catch_errors(
    export=False,
    dynamic=None,
    compiler_config=None,
    rebuild_ctx=None,
):
    return OptimizeContext(
        convert_frame.catch_errors_wrapper(compile_fn, hooks),
@ -535,6 +548,7 @@ def _optimize_catch_errors(
        export=export,
        dynamic=dynamic,
        compiler_config=compiler_config,
        rebuild_ctx=rebuild_ctx,
    )
@ -585,7 +599,15 @@ def is_inductor_supported():
        return False


def optimize(
def optimize(*args, **kwargs):
    def rebuild_ctx():
        return optimize(*args, **kwargs)

    return _optimize(rebuild_ctx, *args, **kwargs)


def _optimize(
    rebuild_ctx: Callable[[Callable], OptimizeContext],
    backend="inductor",
    *,
    nopython=False,
@ -593,7 +615,7 @@ def optimize(
    guard_fail_fn=None,
    disable=False,
    dynamic=None,
):
) -> OptimizeContext:
    """
    The main entrypoint of TorchDynamo. Do graph capture and call
    backend() to optimize extracted graphs.
@ -641,6 +663,7 @@ def optimize(
            backend,
            dynamic=dynamic,
            hooks=hooks,
            rebuild_ctx=rebuild_ctx,
        )
    # The backend function is stashed in the callable returned by
    # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
@ -653,6 +676,7 @@ def optimize(
        compiler_config=backend.get_compiler_config()
        if hasattr(backend, "get_compiler_config")
        else None,
        rebuild_ctx=rebuild_ctx,
    )
@ -1407,6 +1431,7 @@ def optimize_assert(
    export=False,
    export_constraints=None,
    dynamic=None,
    rebuild_ctx=None,
):
    """
    The same as `torch._dynamo.optimize(backend, nopython=True)`
@ -1424,6 +1449,7 @@ def optimize_assert(
        backend_ctx_ctor,
        export=export,
        dynamic=dynamic,
        rebuild_ctx=rebuild_ctx,
    )
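The call_compiled_autograd hook added above follows an enter-now/exit-later pattern: OptimizeContext calls each enter_exit_hook on __enter__, and each hook returns a callable that undoes its effect on __exit__. A self-contained sketch of just that pattern; demo() is a stand-in for compiled_autograd.enable(rebuild_ctx()), not the actual implementation:

import contextlib
import functools

@contextlib.contextmanager
def demo():
    print("enter")  # e.g. enable compiled autograd
    yield
    print("exit")   # e.g. restore the prior compiler

def hook():
    # Enter the context manager immediately and hand back its exit call,
    # mirroring call_compiled_autograd above.
    ctx = demo()
    ctx.__enter__()
    return functools.partial(ctx.__exit__, None, None, None)

exit_fn = hook()  # prints "enter"; the context object would stash exit_fn
exit_fn()         # prints "exit" when the optimized region is left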

View File

@ -721,6 +721,8 @@ def fresh_inductor_cache(cache_entries=None):
    except Exception:
        log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
        raise
    finally:
        clear_inductor_caches()


def argsort(seq) -> List[int]:
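The added finally clause makes the in-memory cache clearing unconditional, presumably so later compiles cannot reuse entries that point into the removed temporary directory, which is the scenario test_issue126128 above appears to exercise. A generic sketch of the same try/finally contextmanager shape, using a placeholder cache rather than Inductor's internals:

import contextlib
import tempfile

_in_memory_cache = {}

@contextlib.contextmanager
def fresh_cache():
    with tempfile.TemporaryDirectory() as tmp_dir:
        try:
            yield tmp_dir  # callers compile against the temporary directory
        finally:
            _in_memory_cache.clear()  # runs on success and on error alike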

View File

@ -11,6 +11,7 @@ from torch.utils.weak import WeakTensorKeyDictionary
import functools
from torch._C._profiler import gather_traceback, symbolize_tracebacks
logger = logging.getLogger("LoggingTensor")
_dtype_abbrs = {
torch.bfloat16: "bf16",
@ -135,8 +136,8 @@ class LoggingTensorHandler(logging.Handler):
        if self.tracebacks_list is not None:
            self.tracebacks_list.append(record.traceback)


def log_input(name: str, var: object):
    logging.getLogger("LoggingTensor").info("input", (name,), {}, var)  # noqa: PLE1205
def log_input(name: str, var: object) -> None:
    logger.info("input", (name,), {}, var)  # noqa: PLE1205


class GatherTraceback(logging.Filter):
    def __init__(self, python=True, script=True, cpp=False):
@ -151,7 +152,6 @@ class GatherTraceback(logging.Filter):
@contextlib.contextmanager
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]:
    collect_traceback = python_tb or script_tb or cpp_tb
    logger = logging.getLogger("LoggingTensor")
    log_list: List[str] = []
    tracebacks_list: List[str] = []
    handler = LoggingTensorHandler(