Compare commits

...

2 Commits

Author SHA1 Message Date
ed5a98a5e2 temp hacks to remove functionalization + get bitwise equivalence with aot_eager on llama3
ghstack-source-id: 2f88fd467b70feea229a7148ea540f1eb649a752
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164577
2025-10-07 12:47:32 -07:00
7c176001e4 Use codegen for the boxed interpreters
Authored with Claude Code. The arg parsing is kind of horrible; open
to more suggestions.

Signed-off-by: Edward Yang <ezyang@meta.com>
ghstack-source-id: 16e6b40adf7178b2339996a4d5812a5104052540
Pull-Request: https://github.com/pytorch/pytorch/pull/164573
2025-10-07 08:11:44 -07:00
7 changed files with 137 additions and 34 deletions

View File

@ -153,8 +153,17 @@ def torchscript(
def boxed_nop(
fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> Callable[..., Any]:
from torch.fx.graph import _BoxedCodeGen
# Set the graph to use boxed codegen
fx_g.graph.set_codegen(_BoxedCodeGen())
fx_g.recompile()
# Wrap the forward method in a function so we can set _boxed_call attribute
forward_fn = fx_g.forward
def run(args: Any) -> Any:
return torch.fx.Interpreter(fx_g).boxed_run(args)
return forward_fn(args)
run._boxed_call = True # type: ignore[attr-defined]
return run
@ -166,9 +175,18 @@ def boxed_nop_with_mode(
*,
mode: torch.overrides.TorchFunctionMode,
) -> Callable[..., Any]:
from torch.fx.graph import _BoxedCodeGen
# Set the graph to use boxed codegen
fx_g.graph.set_codegen(_BoxedCodeGen())
fx_g.recompile()
# Create a wrapper that runs with the mode
forward_fn = fx_g.forward
def run(args: Any) -> Any:
with mode:
return torch.fx.Interpreter(fx_g).boxed_run(args)
return forward_fn(args)
run._boxed_call = True # type: ignore[attr-defined]
return run
@ -179,9 +197,18 @@ def fake_crossref_boxed_nop(
example_inputs: list[torch.Tensor],
ignore_op_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None,
) -> Callable[..., Any]:
from torch.fx.graph import _BoxedCodeGen
# Set the graph to use boxed codegen
fx_g.graph.set_codegen(_BoxedCodeGen())
fx_g.recompile()
# Create a wrapper that runs with the mode
forward_fn = fx_g.forward
def run(args: Any) -> Any:
with torch._subclasses.CrossRefFakeMode(ignore_op_fn):
return torch.fx.Interpreter(fx_g).boxed_run(args)
return forward_fn(args)
run._boxed_call = True # type: ignore[attr-defined]
return run
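All three wrappers above (boxed_nop, boxed_nop_with_mode, fake_crossref_boxed_nop) return a run callable tagged with _boxed_call = True: it receives a single list of arguments, and the callee is allowed to clear that list so the caller's references to the input tensors can be dropped early. A minimal sketch of that contract from the caller's side (call_compiled is an illustrative name, not an API from this PR):

def call_compiled(compiled_fn, args: list):
    # Boxed callees take the argument list itself and may clear it;
    # unboxed callees take ordinary positional arguments.
    if getattr(compiled_fn, "_boxed_call", False):
        return compiled_fn(args)
    return compiled_fn(*args)

The change from torch.fx.Interpreter(fx_g).boxed_run(args) to the codegen'd forward_fn(args) keeps this contract intact, presumably while avoiding per-node Interpreter overhead.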

View File

@ -13,23 +13,16 @@ import torch.utils.dlpack
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code
from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torchgen.utils import dataclass_repr
from .. import config
from .descriptors import AOTInput, BackwardTokenAOTInput
from .functional_utils import (
assert_functional_graph,
propagate_input_mutation_stacktraces,
)
from .graph_capture_wrappers import (
aot_dispatch_subclass,
create_functionalized_fn,
create_joint,
fn_input_mutations_to_outputs,
fn_prepped_for_autograd,
handle_effect_tokens_fn,
)
from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
from .utils import (
@ -72,12 +65,12 @@ def _create_graph(
with (
enable_python_dispatcher(),
FunctionalTensorMode(
pre_dispatch=aot_config.pre_dispatch,
export=aot_config.is_export,
# Allow token discovery for joint fn tracing as tokens can be used in backward.
_allow_token_discovery=True,
),
# FunctionalTensorMode(
# pre_dispatch=aot_config.pre_dispatch,
# export=aot_config.is_export,
# # Allow token discovery for joint fn tracing as tokens can be used in backward.
# _allow_token_discovery=True,
# ),
):
fx_g = make_fx(
inner_f,
@ -162,6 +155,11 @@ def aot_dispatch_base_graph(
keep_data_input_mutations=aot_config.keep_inference_input_mutations,
)
updated_flat_args, updated_flat_args_descs = (
flat_args,
flat_args_descs,
)
"""
fn_to_trace, updated_flat_args, updated_flat_args_descs = create_functionalized_fn(
fn_to_trace,
flat_args,
@ -170,6 +168,7 @@ def aot_dispatch_base_graph(
aot_config=aot_config,
trace_joint=False,
)
"""
# TODO: replace with AOTDispatchSubclassWrapper once we refactor
# fn_input_mutations_to_outputs and create_functionalized_fn
@ -188,6 +187,7 @@ def aot_dispatch_base_graph(
fw_only=flat_fn,
)
"""
(
fn_to_trace,
updated_flat_args_subclasses_desugared,
@ -199,6 +199,7 @@ def aot_dispatch_base_graph(
meta=fw_metadata,
trace_joint=False,
)
"""
aot_graphs_log.debug(
"aot_config id: %s, fw_metadata=%s,subclass_metadata=%s",
@ -265,12 +266,12 @@ def aot_dispatch_base_graph(
# As long as we opted to remove input mutations, then
# there should be *NO* mutating ops in the graph at this point.
copy_count = assert_functional_graph(fw_module.graph)
fw_module.graph.eliminate_dead_code()
fw_module.recompile()
# copy_count = assert_functional_graph(fw_module.graph)
# fw_module.graph.eliminate_dead_code()
# fw_module.recompile()
copy_count2 = assert_functional_graph(fw_module.graph)
propagate_input_mutation_stacktraces(fw_module.graph)
# copy_count2 = assert_functional_graph(fw_module.graph)
# propagate_input_mutation_stacktraces(fw_module.graph)
# See Note [Side-Effectful Tokens in AOTAutograd]
num_tokens = len(fw_metadata.tokens)
@ -283,7 +284,7 @@ def aot_dispatch_base_graph(
saved_updated_flat_args_subclasses_desugared_descs[num_tokens:]
)
assert copy_count == copy_count2
# assert copy_count == copy_count2
if aot_config.enable_log:
aot_graphs_log.info(
@ -373,8 +374,11 @@ def aot_dispatch_autograd_graph(
joint_fn_to_trace = create_joint(
fn_prepared_for_autograd, flat_args_descs, aot_config=aot_config
)
joint_fn_handle = joint_fn_to_trace.handle
# joint_fn_handle = joint_fn_to_trace.handle
updated_joint_inputs, updated_joint_inputs_descs = joint_inputs, joint_inputs_descs
"""
joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs = (
create_functionalized_fn(
joint_fn_to_trace,
@ -386,6 +390,7 @@ def aot_dispatch_autograd_graph(
joint_fn_handle=joint_fn_handle,
)
)
"""
# TODO: replace with AOTDispatchSubclassWrapper once we refactor
# fn_input_mutations_to_outputs and create_functionalized_fn
@ -403,6 +408,7 @@ def aot_dispatch_autograd_graph(
updated_joint_inputs = subclass_tracing_info.plain_tensor_args
updated_joint_inputs_descs = subclass_tracing_info.plain_tensor_args_descs
"""
(joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs) = (
handle_effect_tokens_fn(
joint_fn_to_trace,
@ -412,6 +418,7 @@ def aot_dispatch_autograd_graph(
trace_joint=True,
)
)
"""
# When we call _create_graph, this may mutate the metadata of joint
# inputs. But callers are expecting to get the original joint inputs. So
@ -441,13 +448,13 @@ def aot_dispatch_autograd_graph(
)
# There should be *NO* mutating ops in the graph at this point.
assert_functional_graph(fx_g.graph)
# assert_functional_graph(fx_g.graph)
# Redundant with the check above, but worth having in case tracing introduced
# a fake tensor. Unlikely.
# See Note: [Fake Modules and AOTAutograd]
torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
fx_g.graph.eliminate_dead_code()
# fx_g.graph.eliminate_dead_code()
copy_fwd_metadata_to_bw_nodes(fx_g)
fx_g.recompile()

View File

@ -16,6 +16,8 @@ from collections.abc import Callable
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import Any, cast, Optional, TypeVar, Union
from typing import Any, Callable, cast, Optional, TypeVar, Union
from typing import Any, Optional, TypeVar, Union
from unittest.mock import patch
import torch
@ -65,7 +67,6 @@ from .functional_utils import (
has_data_mutation,
has_metadata_mutation,
is_fun,
sync_functional_tensor,
to_fun,
was_inductor_storage_resized,
)
@ -243,7 +244,7 @@ def fn_prepped_for_autograd(
for arg in args_maybe_cloned:
if not isinstance(arg, Tensor):
continue
sync_functional_tensor(arg)
# sync_functional_tensor(arg)
return (fw_outs_to_return, out_grad_mask), (
fw_outs_to_return_descs,
@ -430,9 +431,14 @@ def create_joint(
with torch.autograd.detect_anomaly(check_nan=False):
return inner_fn(primals, tangents)
inner_fn_with_anomaly.handle = joint_fn_handle # type: ignore[attr-defined]
# inner_fn_with_anomaly.handle = joint_fn_handle # type: ignore[attr-defined]
return cast(JointTraceFn, inner_fn_with_anomaly) # deal with 'handle' property
# TODO: only need to skip this when turning off functionalization
# inner_fn_with_anomaly.handle = joint_fn_handle # type: ignore[attr-defined]
def joint_helper(primals, tangents):
return inner_fn_with_anomaly(primals, tangents)
return joint_helper
def create_functionalized_rng_ops_wrapper(

View File

@ -148,6 +148,9 @@ class OperatorBase:
"Please register a mode for the DispatchKey.Python key instead."
)
if k == DispatchKey.CompositeImplicitAutograd or k == DispatchKey.Autograd:
if torch._C._dispatch_has_kernel(self.name()) and torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k):
return fn
if k in self.py_kernels:
raise RuntimeError(
f"Trying to override a python impl for {k} on operator {self.name()}"

View File

@ -901,12 +901,29 @@ def proxy_call(
_maybe_record_pointwise_barrier(func, proxy_mode)
return r
def should_decompose(func, flat_args):
has_backend_registration = False
for a in flat_args:
if isinstance(a, torch.Tensor):
backend_key = torch._C._dispatch_key_for_device(a.device.type)
has_backend_registration = func.has_kernel_for_dispatch_key(backend_key)
# in theory we should take all backend keys and take the highest priority one
# to properly mimic the dispatcher,
# this just grabs the first tensor and takes its device key
break
return not has_backend_registration
# For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
if not pre_dispatch and func not in [
torch.ops.aten.size.default,
torch.ops.aten.stride.default,
torch.ops.aten.storage_offset.default,
]:
if (
not pre_dispatch
and func
not in [
torch.ops.aten.size.default,
torch.ops.aten.stride.default,
torch.ops.aten.storage_offset.default,
]
and should_decompose(func, flat_args_kwargs)
):
with proxy_mode:
r = func.decompose(*args, **kwargs)
if r is not NotImplemented:
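The comment inside should_decompose notes that a faithful version would consider every tensor argument's backend key and respect dispatch-key priority, rather than just the first tensor's device. A rough sketch of that wider check, reusing the same helpers as the diff (should_decompose_faithful is a hypothetical name, and it still ignores priority ordering; it assumes the module-level torch import of the surrounding file):

def should_decompose_faithful(func, flat_args):
    # Collect the backend dispatch key of every tensor argument instead of
    # only the first one encountered.
    backend_keys = {
        torch._C._dispatch_key_for_device(a.device.type)
        for a in flat_args
        if isinstance(a, torch.Tensor)
    }
    # Decompose only when no candidate backend has a kernel registered for
    # this op. (A fully faithful version would also pick the highest-priority
    # key the way the dispatcher does; this only widens the check.)
    return not any(func.has_kernel_for_dispatch_key(k) for k in backend_keys)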

View File

@ -848,6 +848,44 @@ class CodeGen:
# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec.
# Since we can't access .graph within the FX forward, we need to copy the attribute to the module.
# 3. We currently can't register the pytree imports with `add_global` - not sure why.
class _BoxedCodeGen(CodeGen):
"""
CodeGen subclass that generates code using the "boxed" calling convention.
The boxed calling convention takes a single list argument and clears it
after extracting the arguments, which allows for early deallocation of
input tensors.
"""
def gen_fn_def(
self, free_vars, maybe_return_annotation, *, expanded_def: bool = False
):
"""
Generate function definition for boxed calling convention.
Instead of taking individual arguments, the generated function takes
a single 'args_list' parameter, extracts placeholder values from it,
and clears the list.
"""
# Generate the function signature with args_list parameter
fn_def = f"def {self._func_name}(self, args_list){maybe_return_annotation}:"
if free_vars:
# This is horribly manual but we don't get the "raw" free vars
# without a bigger refactor.
placeholder_vars = [
v.split(":")[0].split("=")[0].strip() for v in free_vars if v != "self"
]
if placeholder_vars:
fn_def += "\n args_iter = iter(args_list)"
for var in placeholder_vars:
fn_def += f"\n {var} = next(args_iter)"
fn_def += "\n args_list.clear()"
return fn_def
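For a graph with two placeholders x and y, the prologue emitted by gen_fn_def above comes out roughly as follows (a sketch assembled from the string formatting in the diff, not captured output; the commented body line is illustrative):

def forward(self, args_list):
    args_iter = iter(args_list)
    x = next(args_iter)
    y = next(args_iter)
    args_list.clear()
    # ... graph body emitted by the base CodeGen follows, e.g.
    # add = torch.ops.aten.add.Tensor(x, y); return (add,)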
class _PyTreeCodeGen(CodeGen):
def __init__(self, pytree_info: _PyTreeInfo):
super().__init__()

View File

@ -18,6 +18,7 @@ from torch.package import Importer, PackageExporter, PackageImporter, sys_import
from ._compatibility import compatibility
from .graph import (
_BoxedCodeGen,
_custom_builtins,
_is_from_torch,
_override_sym_repr,
@ -554,6 +555,10 @@ class GraphModule(torch.nn.Module):
# Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
__jit_unused_properties__ = ["graph"]
@property
def _boxed_call(self) -> bool:
return isinstance(self._graph._codegen, _BoxedCodeGen)
@property
def graph(self) -> Graph:
"""