Compare commits

...

5 Commits

Author SHA1 Message Date
03c6e90ad9 bench 2025-03-17 18:04:08 -07:00
61f5027048 [ca] fix accumulate grad polyfill when different strides between param and grad
ghstack-source-id: 89945a77c6f45cf2429a0b26787017e65e74fc75
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149367
2025-03-17 18:00:00 -07:00
1d5f0102ca [ca] use torch.compile ca API for benchmarks
ghstack-source-id: 7eb97b044bb276fb7ce2d03e9a3554637783db80
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148694
2025-03-17 17:59:57 -07:00
5398bbc31d [aot][ca] store a deepcopy of the bw graph in CompiledFunction
ghstack-source-id: 3f3e02e3cd4f8f50cf438948c3a55814ec9ed44d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149229
2025-03-17 17:59:56 -07:00
44b6c7914f [ca] fix dce for side-effects
ghstack-source-id: b6f8f4aef0595e8f43054a44ebe66013c4a0c6c3
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149336
2025-03-17 17:59:56 -07:00
7 changed files with 98 additions and 81 deletions

View File

@@ -49,18 +49,10 @@ from torch._logging.scribe import open_source_signpost
try:
from torch._dynamo.utils import (
clone_inputs,
graph_break_reasons,
maybe_enable_compiled_autograd,
)
from torch._dynamo.utils import clone_inputs, graph_break_reasons
from torch._inductor.utils import fresh_inductor_cache
except ImportError:
from _dynamo.utils import (
clone_inputs,
graph_break_reasons,
maybe_enable_compiled_autograd,
)
from _dynamo.utils import clone_inputs, graph_break_reasons
import torch._functorch.config
from torch._functorch.aot_autograd import set_model_name
@@ -916,14 +908,7 @@ def latency_experiment(args, model_iter_fn, model, example_inputs, mark, **kwarg
# inputs will incur high penalty then the next one.
maybe_mark_step(args)
with (
maybe_mark_profile(p=p, mark=mark),
maybe_enable_compiled_autograd(
args.compiled_autograd,
fullgraph=args.nopython,
dynamic=args.dynamic_shapes,
),
):
with maybe_mark_profile(p=p, mark=mark):
timings[rep], actual_output = timed(
model,
model_iter_fn,
@@ -1093,14 +1078,7 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
# call mark_step between the 2 calls to make the comparison fair.
maybe_mark_step(args)
with (
maybe_mark_profile(p=p, mark="actual"),
maybe_enable_compiled_autograd(
args.compiled_autograd,
fullgraph=args.nopython,
dynamic=args.dynamic_shapes,
),
):
with maybe_mark_profile(p=p, mark="actual"):
timings[rep, 1], actual_output = timed(
model,
frozen_model_iter_fn,
@@ -2234,14 +2212,9 @@ class BenchmarkRunner:
new_result = optimized_model_iter_fn(model_copy, example_inputs)
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
with maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
):
new_result = self.run_n_iterations(
model_copy, example_inputs, optimized_model_iter_fn
)
new_result = self.run_n_iterations(
model_copy, example_inputs, optimized_model_iter_fn
)
except Exception as e:
log.exception("")
print(
@@ -2463,15 +2436,8 @@ class BenchmarkRunner:
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
with (
maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
),
maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
),
with maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
):
dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
optimized_model_iter_fn, model, example_inputs, "dynamo"
@@ -2489,12 +2455,7 @@ class BenchmarkRunner:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU]
) as prof:
with maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
):
warmup(optimized_model_iter_fn, model, example_inputs, "dynamo")
warmup(optimized_model_iter_fn, model, example_inputs, "dynamo")
events = list(
filter(
@@ -2623,15 +2584,8 @@ class BenchmarkRunner:
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
with (
maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
),
maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
),
with maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
):
dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
optimized_model_iter_fn, model, example_inputs, "dynamo"
@@ -2649,12 +2603,7 @@ class BenchmarkRunner:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU]
) as prof:
with maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
):
warmup(optimized_model_iter_fn, model, example_inputs, "dynamo")
warmup(optimized_model_iter_fn, model, example_inputs, "dynamo")
events = list(
filter(
@@ -3241,6 +3190,7 @@ def parse_args(args=None):
"--compiled-autograd",
action="store_true",
help="Enables compiled autograd on compiled benchmark",
default=True,
)
parser.add_argument(
@@ -3520,6 +3470,7 @@ def run(runner, args, original_dir=None):
args.exclude = args.exclude or [r"^$"]
args.exclude_exact = args.exclude_exact or []
torch._functorch.config.enable_autograd_cache = False
if args.inductor:
assert args.backend is None
args.backend = "inductor"
@@ -3532,6 +3483,8 @@ def run(runner, args, original_dir=None):
if args.dynamic_shapes:
if not args.dynamic_batch_only:
torch._dynamo.config.assume_static_by_default = False
if args.compiled_autograd:
torch._dynamo.config.compiled_autograd = True
if args.propagate_real_tensors:
# TODO: Separate flag for data dependent
torch._dynamo.config.capture_scalar_outputs = True

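Note: the hunks above drop the per-region maybe_enable_compiled_autograd(...) wrappers (which also forwarded fullgraph/dynamic) and instead rely on the flag set once in run(): torch._dynamo.config.compiled_autograd = True when --compiled-autograd is passed, now enabled by default. A minimal sketch of the flag-based usage, assuming a recent build where this config flag exists; the model and shapes are illustrative, not taken from the benchmark suite:

import torch

torch._dynamo.config.compiled_autograd = True  # replaces the per-call context manager

model = torch.nn.Linear(8, 8)

@torch.compile
def train_step(x):
    loss = model(x).sum()
    loss.backward()  # the backward pass is also captured and compiled
    return loss

train_step(torch.randn(4, 8))

This mirrors the runner change: compiled autograd is enabled once for the whole process instead of being entered around every timed call.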
View File

@@ -3924,6 +3924,71 @@ class CompiledAutograd1(torch.nn.Module):
fn, count=[1, 5], compiler_fn=make_compiler_fn(fullgraph=False)
)
def test_dont_dce_side_effects(self):
class SideEffectfulBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, gO):
torch.randn(10, 10)
return gO
x = torch.randn(10, 10, requires_grad=True)
@torch.compile(backend="aot_eager")
def fn(x):
return SideEffectfulBackward.apply(x).sum()
gm = None
def extract(ca_gm):
nonlocal gm
gm = ca_gm
return ca_gm
with compiled_autograd._enable(extract):
fn(x).backward()
self.assertTrue("aten.randn" in str(gm))
def test_aot_bwd_gm_runnable(self):
# This test ensures that the bw_module saved in
# CompiledFunction._lazy_backward_info is executable,
# by ensuring post grad passes have not run on it.
post_grad_graphs = []
def post_grad_pass(graph):
nonlocal post_grad_graphs
post_grad_graphs.append(graph)
return graph
x = torch.randn(10, 10, requires_grad=True)
y = torch.randn(10, 10, requires_grad=True)
# forces symints to be saved for backward
# and forces aot compilation of the backward
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(y, 1)
@torch.compile
def fn(x, y):
return torch.matmul(x, y).sum()
with inductor_config.patch(post_grad_custom_post_pass=post_grad_pass):
loss = fn(x, y)
self.assertEqual(len(post_grad_graphs), 2) # 1 fwd and 1 bwd
self.assertEqual(loss.grad_fn.name(), "CompiledFunctionBackward")
self.assertIsNot(
post_grad_graphs[1],
loss.grad_fn._forward_cls._lazy_backward_info.bw_module.graph,
)
with compiled_autograd._enable(lambda gm: gm):
loss.backward()
def load_test_module(name):
testdir = Path(__file__).absolute().parent.parent

View File

@@ -760,16 +760,11 @@ class AutogradCompilerInstance:
assert i == len(_graph_placeholders) - 1
def is_impure(node):
return (
node in unpack_nodes
or node.op == "placeholder"
or node.op == "output"
or (node.op == "call_function" and node.target in _impure_targets)
or (
node.op == "call_function"
and node.target in torch.fx.node._side_effectful_functions
)
)
if node in unpack_nodes or (
node.op == "call_function" and node.target in _impure_targets
):
return True
return node.is_impure()
before = len(self.fx_tracer.graph.nodes)
self.fx_tracer.graph.eliminate_dead_code(is_impure)

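The rewritten predicate keeps unpack nodes and the tracer's known-impure targets, and defers everything else (placeholders, outputs, _side_effectful_functions) to Node.is_impure(). A toy sketch of the underlying FX mechanism, assuming a PyTorch build where Graph.eliminate_dead_code accepts an impurity callback as it is used above; the module and the torch.randn check are invented for illustration:

import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        torch.randn(2, 2)  # unused result: survives DCE only if flagged as impure
        return x + 1

gm = torch.fx.symbolic_trace(M())

def is_impure(node: torch.fx.Node) -> bool:
    # illustrative stand-in for the tracer's _impure_targets check
    if node.op == "call_function" and node.target is torch.randn:
        return True
    return node.is_impure()

gm.graph.eliminate_dead_code(is_impure)
gm.recompile()
print(gm.code)  # the side-effectful randn call is still present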
View File

@@ -757,12 +757,15 @@ class OptimizeContext(_TorchDynamoContext):
)
if config.compiled_autograd:
_dynamic = self._dynamic
if _dynamic is None:
_dynamic = not torch._dynamo.config.assume_static_by_default
def call_compiled_autograd():
assert rebuild_ctx is not None
compiler_fn = rebuild_ctx()
ctx = torch._dynamo.compiled_autograd._enable(
compiler_fn, dynamic=self._dynamic
compiler_fn, dynamic=_dynamic
)
ctx.__enter__()
return functools.partial(ctx.__exit__, None, None, None)

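The new fallback means torch.compile(dynamic=None) no longer forwards None into the compiled autograd context; the flag is derived from the global dynamo default instead. Roughly, with names taken from the hunk and the surrounding class stripped away:

import torch

_dynamic = None  # self._dynamic when torch.compile(...) is called without dynamic=
if _dynamic is None:
    _dynamic = not torch._dynamo.config.assume_static_by_default
# with stock settings assume_static_by_default is True, so compiled autograd is
# entered with dynamic=False unless dynamic shapes are requested explicitly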
View File

@@ -77,7 +77,7 @@ def radians(x):
def accumulate_grad(x, new_grad):
if new_grad is None:
return
new_grad = torch.clone(new_grad)
new_grad = torch.clone(new_grad).as_strided(x.size(), x.stride())
if x.grad is None:
x.grad = new_grad
else:

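The added as_strided call restores the parameter's size and strides on the cloned gradient, covering the case where the incoming grad has a different layout than the param, e.g. a channels_last parameter receiving a contiguous gradient. A small illustration of that mismatch, mirroring the fixed polyfill rather than invoking it:

import torch

param = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last).requires_grad_()
grad = torch.randn(2, 3, 4, 4)  # contiguous, so its strides differ from param's

print(param.stride(), grad.stride())  # (48, 1, 12, 3) vs (48, 16, 4, 1)

new_grad = torch.clone(grad).as_strided(param.size(), param.stride())
assert new_grad.stride() == param.stride()  # layout now matches the parameter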
View File

@@ -1184,8 +1184,11 @@ def aot_dispatch_autograd(
compiled_bw_func = None
if num_symints_saved_for_bw > 0:
try:
# backends may mutate the bw_module and leave
# it in a non-runnable state, so we lower it
# with a copy
compiled_bw_func = aot_config.bw_compiler(
bw_module, placeholder_list
copy.deepcopy(bw_module), placeholder_list
)
except Exception as e:
exc = e

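The deepcopy guards the bw_module that CompiledFunction keeps around for later use: the backend compiler may rewrite the GraphModule in place. A toy demonstration of the failure mode; this is not AOTAutograd code, and the mutating backend is invented for illustration:

import copy
import torch
import torch.fx

def mutating_backend(gm: torch.fx.GraphModule):
    # stand-in for a lowering pass that edits the graph it was handed
    for node in gm.graph.nodes:
        if node.op == "call_function":
            node.target = torch.mul
    gm.recompile()
    return gm

saved = torch.fx.symbolic_trace(lambda x: x + 1)   # the graph we need to keep runnable
compiled = mutating_backend(copy.deepcopy(saved))  # the backend only sees a copy

x = torch.ones(3)
print(saved(x))     # tensor([2., 2., 2.])  original semantics preserved
print(compiled(x))  # tensor([1., 1., 1.])  backend's rewritten graph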
View File

@@ -8,6 +8,7 @@ This module defines runtime wrappers, which, based on previous analysis attempts
"""
import builtins
import collections
import copy
import itertools
import pprint
from contextlib import nullcontext
@@ -2184,13 +2185,10 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
):
CompileEventLogger.compilation_metric(is_forward=False)
CompiledFunction.compiled_bw = aot_config.bw_compiler(
bw_module, placeholder_list
copy.deepcopy(bw_module), placeholder_list
)
# Maybe save cache entry
if try_save_cache_entry is not None:
# CompiledFunction.metadata
# CompiledFunction.maybe_subclass_metadata
# bw_module
try_save_cache_entry(
CompiledFunction.compiled_bw,
fw_metadata,