Compare commits

...

2 Commits

SHA1        Message / Date

6409ef12a3  [Reland] Add buffer static input tests to cudagraph trees
            ghstack-source-id: 29657f4716ee2412485cd100f46aa7ba69eb51bb
            Pull Request resolved: https://github.com/pytorch/pytorch/pull/130402
            2024-07-16 11:06:57 -07:00

89d4d16256  [Reland] Propagate buffer and parameter indices through AOT
            ghstack-source-id: 99b0f3442d5f122c95b5271e790f35d82774ec30
            Pull Request resolved: https://github.com/pytorch/pytorch/pull/130393
            2024-07-16 11:06:57 -07:00
9 changed files with 134 additions and 19 deletions
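For orientation, the scenario the new cudagraph-trees tests exercise looks roughly like the sketch below. This is not part of the diff; it is a minimal, hedged example assuming a CUDA build that contains both commits, with inline_inbuilt_nn_modules enabled the same way the tests enable it.

import torch

# With inline_inbuilt_nn_modules enabled, a builtin module's buffer becomes a
# static cudagraph input, so swapping the buffer should record another
# cudagraph instead of forcing a recompile -- the behavior the new tests assert.
torch._dynamo.config.inline_inbuilt_nn_modules = True

mod = torch.nn.BatchNorm1d(2, device="cuda")
fn = torch.compile(lambda x: mod(x), mode="reduce-overhead", fullgraph=True)

x = torch.rand(8, 2, device="cuda")
fn(x)
mod.running_mean = torch.rand_like(mod.running_mean)  # swap the buffer
fn(x)  # expected: a new cudagraph recording, no recompilation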

View File

@@ -1526,6 +1526,45 @@ class GraphModule(torch.nn.Module):
         out_test = compiled_f(view)
         self.assertEqual(out_ref, out_test)
 
+    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+    def test_mark_static_with_subclass_desugaring(self):
+        from typing import Any, Callable, Dict, List, Optional
+
+        from torch._dynamo.decorators import mark_static_address
+        from torch._inductor.compile_fx import compile_fx
+        from torch._inductor.cudagraph_utils import BoxedDeviceIndex
+        from torch._inductor.utils import BoxedBool
+
+        x_inner = torch.ones(4)
+        x = TwoTensor(x_inner, x_inner)
+        mark_static_address(x, guard=False)
+
+        def inner_compile(
+            gm: torch.fx.GraphModule,
+            example_inputs: List[torch.Tensor],
+            cudagraphs: Optional[BoxedBool] = None,
+            static_input_idxs: Optional[List[int]] = None,
+            is_backward: bool = False,
+            graph_id: Optional[int] = None,
+            cpp_wrapper: bool = False,
+            aot_mode: bool = False,
+            is_inference: bool = False,
+            boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
+            user_visible_outputs: Optional[Dict[str, None]] = None,
+            layout_opt: Optional[bool] = None,
+            extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None,
+        ):
+            self.assertEqual(static_input_idxs, [1, 2])
+            return gm
+
+        compiler = functools.partial(compile_fx, inner_compile=inner_compile)
+
+        @torch.compile(backend=compiler)
+        def fn(t0, t1, t2):
+            return t0 + t1 + t2 + 2
+
+        fn(torch.ones(4), x, torch.ones(4))
+
+
 instantiate_parametrized_tests(SubclassTests)

View File

@@ -1855,7 +1855,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             self.assertEqual(self.get_manager().new_graph_id().id, num_graphs)
 
-        def _module_test(self, mod):
+        def _module_test(self, mod, name="weight", param_wrapping=True):
             with torch.device("cuda"):
 
                 def fn(x, mod):
@@ -1878,11 +1878,14 @@ if HAS_CUDA and not TEST_WITH_ASAN:
                         self.assertEqual(exp_grad, compiled_grad)
 
                 run_test()
-                old = mod.weight.data
-                mod.weight.data = torch.rand_like(mod.weight.data)
+                old_attr = getattr(mod, name)
+                modified_attr = torch.rand_like(old_attr)
+                if param_wrapping:
+                    modified_attr = torch.nn.Parameter(modified_attr)
+                setattr(mod, name, modified_attr)
                 run_test()
                 # Run original version to verify we reuse the other recording
-                mod.weight.data = old
+                setattr(mod, name, old_attr)
                 run_test()
 
                 # Fwd + bwd graphs for each version of the function => 4 graphs
@@ -1907,6 +1910,18 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             # Note: Linear is a builtin module so we enable that config setting above
             self._module_test(torch.nn.Linear(2, 3, device="cuda"))
 
+        @torch._dynamo.config.patch("error_on_recompile", True)
+        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+        def test_multi_dispatch_single_compile_builtin_module_buffers(self):
+            # Verify that we don't recompile when changing the buffer of a builtin module
+            # and that we record another cudagraph
+            self._module_test(
+                torch.nn.BatchNorm1d(2, device="cuda"),
+                name="running_mean",
+                param_wrapping=False,
+            )
+
         @torch._inductor.config.patch("triton.cudagraphs", True)
         @torch._dynamo.config.patch("error_on_recompile", True)
         @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
         def test_multi_dispatch_custom_module(self):
@@ -1924,6 +1939,30 @@ if HAS_CUDA and not TEST_WITH_ASAN:
                 TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda")))
             )
 
+        @torch._dynamo.config.patch("error_on_recompile", True)
+        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+        def test_multi_dispatch_custom_module_buffer(self):
+            # Test that we can correctly dispatch multiple graphs
+            # if buffers of a custom module change
+            class TestModule(torch.nn.Module):
+                def __init__(self, param, buf) -> None:
+                    super().__init__()
+                    self.weight = param
+                    self.register_buffer("buf", buf)
+
+                def forward(self, x):
+                    return x * self.weight + self.buf
+
+            self._module_test(
+                TestModule(
+                    torch.nn.Parameter(torch.rand([2, 2], device="cuda")),
+                    torch.rand([2, 2], device="cuda"),
+                ),
+                name="buf",
+                param_wrapping=False,
+            )
+
         @torch._inductor.config.patch("triton.cudagraphs", True)
         @torch._dynamo.config.patch("error_on_recompile", True)
         @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
         def test_multi_dispatch_child_node(self):

View File

@@ -11,7 +11,7 @@ a functionalized version of the graph under compilation.
 import collections
 import logging
 from functools import wraps
-from typing import Callable, DefaultDict, Dict, List
+from typing import Callable, DefaultDict, Dict, List, Optional
 
 import torch
 import torch.utils._pytree as pytree
@@ -25,6 +25,7 @@ from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
     transform_subclass,
 )
+
 from .functional_utils import (
     are_all_mutations_hidden_from_autograd,
     are_all_mutations_under_no_grad_or_inference_mode,
@@ -124,6 +125,8 @@ def run_functionalized_fw_and_collect_metadata(
     keep_input_mutations: bool,
     # TODO: refactor to kill this flag
     is_train: bool = False,
+    # Note: this is guaranteed to be set when running under dynamo
+    static_input_indices: Optional[List[int]] = None,
     pre_dispatch: bool = False,
 ) -> Callable[..., ViewAndMutationMeta]:
     memo: Dict[Tensor, Tensor] = {}
@@ -666,17 +669,15 @@ from a multi-output view call"
         )
         user_outs = pytree.tree_map(from_fun, f_output_tangents)
 
-        if (
-            torch._dynamo.config.inline_inbuilt_nn_modules
-            or torch._dynamo.compiled_autograd.in_compiled_autograd_region
-        ):
-            static_parameter_input_indices = [
+        nonlocal static_input_indices
+        static_input_indices = static_input_indices or []
+        if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+            passed_indices = set(static_input_indices)
+            static_input_indices = [
                 i
                 for i, arg in enumerate(flat_args)
-                if isinstance(arg, torch.nn.Parameter)
+                if (isinstance(arg, torch.nn.Parameter) or i in passed_indices)
             ]
-        else:
-            static_parameter_input_indices = []
 
         f_mutated_inputs = [
             inp
@@ -729,7 +730,7 @@ from a multi-output view call"
             subclass_tangent_meta=create_subclass_meta(traced_tangents),
             is_train=is_train,
             grad_enabled_mutation=grad_enabled_mutation,
-            static_parameter_indices=static_parameter_input_indices,
+            static_input_indices=static_input_indices,
             tokens=mode._tokens,
         )
         return metadata
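Read in isolation, the index-selection rule the hunk above introduces amounts to the standalone sketch below. It is an illustration only, not the actual implementation; the function name and parameters are hypothetical restatements of the diff.

from typing import List, Optional, Sequence

import torch


def select_static_input_indices(
    flat_args: Sequence[object],
    passed_indices: Optional[List[int]],
    in_compiled_autograd_region: bool,
) -> List[int]:
    # Start from whatever the caller (normally dynamo) passed through.
    indices = passed_indices or []
    if in_compiled_autograd_region:
        # Under compiled autograd, parameters are still treated as static,
        # in addition to any indices that were passed in.
        passed = set(indices)
        indices = [
            i
            for i, arg in enumerate(flat_args)
            if isinstance(arg, torch.nn.Parameter) or i in passed
        ]
    return indices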

View File

@@ -905,6 +905,7 @@ class AOTDedupeWrapper(CompilerWrapper):
         if config.debug_assert:
             ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
                 wrapped_flat_fn,
+                static_input_indices=aot_config.static_input_indices,
                 keep_input_mutations=fw_metadata.keep_input_mutations,
                 is_train=fw_metadata.is_train,
             )(*deduped_flat_args)
@@ -1094,6 +1095,7 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
         if config.debug_assert:
             ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
                 wrapped_flat_fn,
+                static_input_indices=aot_config.static_input_indices,
                 keep_input_mutations=fw_metadata.keep_input_mutations,
                 is_train=fw_metadata.is_train,
             )(*flat_args_with_synthetic_bases)

View File

@@ -329,7 +329,7 @@ class ViewAndMutationMeta:
     deterministic: Optional[bool] = None
 
     # Keeps track of which input indices store parameters (which we will treat as static)
-    static_parameter_indices: List[int] = field(default_factory=list)
+    static_input_indices: List[int] = field(default_factory=list)
 
     # Map of effect type (ex. _EffectType.ORDERED) to token. If there are
     # side-effectful operators, FunctionalTensorMode will populate this
@@ -803,6 +803,7 @@ class AOTConfig:
     no_tangents: bool = False
     dynamic_shapes: bool = False
     aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
+    static_input_indices: Optional[List[int]] = None
     inference_compiler: Optional[Callable] = None
     enable_log: bool = True
     # this is always false outside of export.

View File

@@ -136,6 +136,24 @@ def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
     return unwrapped_args
 
 
+def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices):
+    static_input_indices = set(static_input_indices)
+    new_ind = 0
+    remapped_static_indices = []
+    for i, arg in enumerate(wrapped_args):
+        num_indices = 1
+        if is_traceable_wrapper_subclass(arg):
+            num_indices = len(get_plain_tensors(typing.cast(Tensor, arg)))
+
+        for _ in range(num_indices):
+            if i in static_input_indices:
+                remapped_static_indices.append(new_ind)
+            new_ind += 1
+
+    return remapped_static_indices
+
+
 # Turns a flattened list of tensor arguments into (maybe) subclass tensors.
 # This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in.
 def wrap_tensor_subclasses(
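To see what remap_unwrapped_subclass_arg_indices computes: once wrapper subclasses are desugared into their plain tensors, each static index on the wrapped argument list has to be re-expressed against the unwrapped list. Below is a standalone mimic of that rule, using per-argument plain-tensor counts instead of real subclass inspection; the helper name and its inputs are hypothetical and for illustration only.

from typing import List, Sequence


def remap_static_indices(
    plain_counts: Sequence[int], static_input_indices: List[int]
) -> List[int]:
    # plain_counts[i] = number of plain tensors wrapped_args[i] unwraps into
    # (1 for an ordinary tensor, 2 for a TwoTensor, and so on).
    static = set(static_input_indices)
    remapped, new_ind = [], 0
    for i, count in enumerate(plain_counts):
        for _ in range(count):
            if i in static:
                remapped.append(new_ind)
            new_ind += 1
    return remapped


# fn(t0, x, t2) with x a TwoTensor marked static: x unwraps into two plain
# tensors, so wrapped index 1 maps to unwrapped indices [1, 2] -- the same
# [1, 2] the new subclass test asserts for static_input_idxs.
assert remap_static_indices([1, 2, 1], [1]) == [1, 2]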

View File

@@ -53,6 +53,7 @@ from .schemas import (
 )
 from .subclass_utils import (
     create_subclass_meta,
+    remap_unwrapped_subclass_arg_indices,
     requires_subclass_dispatch,
     unwrap_tensor_subclasses,
     wrap_tensor_subclasses_maybe_joint,
@@ -702,6 +703,9 @@ def aot_dispatch_subclass(
     args_unwrapped = unwrap_tensor_subclasses(
         args, is_joint_structure=is_joint_structure
     )
+    remapped_static_indices = remap_unwrapped_subclass_arg_indices(
+        args, meta.static_input_indices
+    )
 
     if is_joint_structure:
         primals_unwrapped = args_unwrapped[0]
@@ -729,6 +733,7 @@ def aot_dispatch_subclass(
     # See Note: [Partitioner handling for Subclasses, Part 2] for more info.
     meta_updated = run_functionalized_fw_and_collect_metadata(
         metadata_fn,
+        static_input_indices=remapped_static_indices,
         keep_input_mutations=meta.keep_input_mutations,
         is_train=meta.is_train,
     )(*primals_unwrapped)

View File

@@ -20,6 +20,7 @@ from torch._subclasses import FakeTensor, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
 from . import config
 from ._aot_autograd.autograd_cache import (  # noqa: F401
     AOTAutogradCache,
@@ -588,6 +589,7 @@ def create_aot_dispatcher_function(
     with ctx:
         fw_metadata = run_functionalized_fw_and_collect_metadata(
             flat_fn,
+            static_input_indices=aot_config.static_input_indices,
             keep_input_mutations=aot_config.keep_inference_input_mutations,
             is_train=needs_autograd,
             pre_dispatch=aot_config.pre_dispatch,
@@ -623,6 +625,7 @@
                     keep_input_mutations=aot_config.keep_inference_input_mutations,
                     is_train=False,
                     pre_dispatch=aot_config.pre_dispatch,
+                    static_input_indices=aot_config.static_input_indices,
                 )(*fake_flat_args)
             else:
                 fw_metadata = ViewAndMutationMeta(
@@ -636,7 +639,7 @@ def create_aot_dispatcher_function(
                     subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
                     is_train=False,
                     tokens=fw_metadata.tokens,
-                    static_parameter_indices=fw_metadata.static_parameter_indices,
+                    static_input_indices=fw_metadata.static_input_indices,
                 )
 
         if fw_metadata.num_intermediate_bases > 0:
@@ -941,9 +944,10 @@ def aot_module_simplified(
     # Next, the input args
     full_args.extend(args)
 
+    static_input_indices = []
     if hasattr(mod, "graph"):
         # Non dynamo entrypoints can get to here...
         # ... but not here!
-        for node in mod.graph.find_nodes(op="placeholder"):
+        for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
             if hasattr(node, "_dynamo_source"):
                 # ... but not here!
                 if aot_autograd_arg_pos_to_source is None:
@@ -953,6 +957,11 @@ def aot_module_simplified(
                 seen_sources.add(source)
                 aot_autograd_arg_pos_to_source.append(source)
 
+                if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
+                    "_dynamo_static_input_type", None
+                ):
+                    static_input_indices.append(pos)
+
     if aot_autograd_arg_pos_to_source is not None:
         assert len(full_args) == len(aot_autograd_arg_pos_to_source)
@@ -973,6 +982,7 @@ def aot_module_simplified(
         keep_inference_input_mutations=keep_inference_input_mutations,
         dynamic_shapes=dynamic_shapes,
         aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
+        static_input_indices=static_input_indices,
         is_export=False,
         no_tangents=False,
         cache_key=None,
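The placeholder scan added above can be read as the following standalone helper; this is a sketch, not code from the diff, and the helper name is hypothetical. It collects the positions of graph inputs that dynamo tagged as static through the "_dynamo_static_input_type" entry in node.meta["tensor_dict"].

from typing import List

import torch


def collect_static_input_indices(gm: torch.fx.GraphModule) -> List[int]:
    # A placeholder counts as a static input when dynamo recorded
    # "_dynamo_static_input_type" in its tensor_dict metadata.
    indices = []
    for pos, node in enumerate(gm.graph.find_nodes(op="placeholder")):
        tensor_dict = node.meta.get("tensor_dict", {})
        if tensor_dict.get("_dynamo_static_input_type", None):
            indices.append(pos)
    return indices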

View File

@@ -136,7 +136,7 @@ def get_static_input_idxs(num_fixed):
     if not context or not context.fw_metadata:
         return fixed
 
-    return fixed + context.fw_metadata.static_parameter_indices
+    return fixed + context.fw_metadata.static_input_indices
 
 
 @functools.lru_cache(None)
@@ -1246,7 +1246,7 @@ def fw_compiler_freezing(
             params_flat[i] = None
 
     if tracing_context.fw_metadata:
-        static_input_idxs += tracing_context.fw_metadata.static_parameter_indices
+        static_input_idxs += tracing_context.fw_metadata.static_input_indices
 
     with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
         optimized_function = inner_compile(