Compare commits

...

5 Commits

Author SHA1 Message Date
fbb98e20c0 Add buffer static input tests to cudagraph trees
ghstack-source-id: 35d65de29900376b26e41486cc201b7d1249249e
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130402
2024-07-12 12:20:14 -07:00
c104740ee7 Propagate buffer and parameter indices through AOT
ghstack-source-id: b4f05055d8d497fdc7eb6e687aea27d9b333cd65
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130393
2024-07-12 12:20:13 -07:00
5b03e3f7e1 Remove static param counting if inlining NN modules
ghstack-source-id: 52f6705e7c91c39433c5077e5e6507772bd381a3
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130503
2024-07-10 15:59:46 -07:00
4480745f1b Update mark_static_address for inlining NN modules
ghstack-source-id: b2fc3d21b821e1c95727ddfc0db1c412c968d25a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130392
2024-07-10 11:52:09 -07:00
fa278d234e Mark nn_module params and buffers as static in dynamo
ghstack-source-id: 7275ee83f62bb3f33b16bfcb83cfa239454e16af
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130391
2024-07-10 11:52:09 -07:00
14 changed files with 321 additions and 40 deletions
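Taken together, the stack has dynamo tag nn-module parameters and buffers (and any tensor marked via the decorator) as static inputs, and threads those indices through AOT autograd into inductor and cudagraph trees. A minimal sketch of the user-facing decorator the stack builds on; the eager backend is chosen only to keep the example self-contained:

import torch

@torch.compile(backend="eager")
def fn(x):
    return x * 2

inp = torch.ones(4)
# Tell dynamo the tensor's storage address is stable across calls.
# guard=True would additionally install an ID_MATCH guard on the input.
torch._dynamo.mark_static_address(inp, guard=False)
fn(inp)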

View File

@ -368,13 +368,43 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnt.op_count, 3)
def _test_mark_static_address(self, guarded):
# This test verifies that dynamo properly marks inputs as static
# when using the mark_static_address API.
# On the first compile, we expect the input to be marked as static, with
# guardedness determined by the `guarded` flag.
# On the second compile, we expect the input to be unmarked.
# If inlining NN modules, we expect metadata indicating the static address
# type of the input to be present on the tensor.
# If not inlining NN modules, we expect the tensor to be present in the
# buffers attribute of the graph.
compiles_with_buffers = 0
compiles = 0
def debug_compiler(gm, _):
nonlocal compiles_with_buffers
nonlocal compiles
compiles_with_buffers += len(gm._buffers) > 0
if torch._dynamo.config.inline_inbuilt_nn_modules:
input_node = [
n
for n in gm.graph.nodes
if n.op == "placeholder" and n.name == "l_x_"
]
self.assertEqual(len(input_node), 1)
input_node = input_node[0]
if compiles == 0:
self.assertEqual(
input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
"guarded" if guarded else "unguarded",
)
elif compiles == 1:
self.assertFalse(
"_dynamo_static_input_type" in input_node.meta["tensor_dict"]
)
else:
raise RuntimeError(f"Unexpected number of compiles: {compiles}")
else:
compiles_with_buffers += len(gm._buffers) > 0
compiles += 1
return gm
@ -387,7 +417,8 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
torch._dynamo.mark_static_address(inp, guard=guarded)
fn(inp)
self.assertEqual(compiles_with_buffers, 1)
if not torch._dynamo.config.inline_inbuilt_nn_modules:
self.assertEqual(compiles_with_buffers, 1)
inp2 = torch.ones(2)
@ -395,13 +426,22 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
# since it was not marked static, compiles_with_buffers
# should not be incremented
fn(inp2)
self.assertEqual(compiles_with_buffers, 1)
if not torch._dynamo.config.inline_inbuilt_nn_modules:
self.assertEqual(compiles_with_buffers, 1)
self.assertEqual(compiles, 2 if guarded else 1)
def test_mark_static_address_guarded(self):
with torch._dynamo.config.patch("inline_inbuilt_nn_modules", True):
self._test_mark_static_address(guarded=True)
self._test_mark_static_address(guarded=True)
def test_mark_static_address_unguarded(self):
with torch._dynamo.config.patch("inline_inbuilt_nn_modules", True):
self._test_mark_static_address(guarded=False)
self._test_mark_static_address(guarded=False)
def test_class_methods(self):
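The assertions above read the tag dynamo records in placeholder metadata. A small illustrative backend doing the same thing outside the test harness; the function name and the print-based reporting are my own:

import torch

def report_static_inputs(gm, example_inputs):
    # "tensor_dict" is populated by dynamo for tensor placeholders; the
    # "_dynamo_static_input_type" key is the tag the tests above assert on.
    for n in gm.graph.nodes:
        if n.op != "placeholder":
            continue
        tag = n.meta.get("tensor_dict", {}).get("_dynamo_static_input_type")
        if tag is not None:
            print(f"{n.name}: {tag}")  # "guarded" or "unguarded"
    return gm

@torch.compile(backend=report_static_inputs)
def fn(x):
    return x + 1

x = torch.ones(2)
torch._dynamo.mark_static_address(x, guard=False)
fn(x)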

View File

@ -2648,6 +2648,91 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
expected = mod(x)
self.assertEqual(actual, expected)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_mark_static_previously_seen_tensor(self):
# This test verifies that dynamo will mark
# the buffers/params of a module as static
# even if this param was previously seen
# (e.g., as a different input)
num_compiles = 0
def debug_compiler(gm, _):
nonlocal num_compiles
num_compiles += 1
input_nodes = [
n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_"
]
self.assertGreater(len(input_nodes), 0)
for input_node in input_nodes:
self.assertEqual(
input_node.meta["tensor_dict"]["_dynamo_static_input_type"],
"unguarded",
)
return gm
class TestModule(torch.nn.Module):
def __init__(self, buf) -> None:
super().__init__()
self.register_buffer("buf", buf)
def forward(self, x):
return self.buf * x
@torch._dynamo.optimize(backend=debug_compiler)
def fn(x, b, mod):
z = b + 1
return z * mod(x)
buf = torch.ones(2, 2)
inp = torch.ones(2)
mod = TestModule(buf)
fn(inp, buf, mod)
self.assertEqual(num_compiles, 1)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
@torch._inductor.config.patch("freezing", True)
@torch.no_grad()
def test_mark_static_with_freezing(self):
# This test verifies that dynamo will
# add buffers/params as attributes of the
# graph w/ guards if freezing is enabled
num_compiles = 0
def debug_compiler(gm, _):
nonlocal num_compiles
num_compiles += 1
input_nodes = [
n for n in gm.graph.nodes if n.op == "placeholder" and n.name == "l_b_"
]
self.assertEqual(len(input_nodes), 0)
self.assertEqual(len(list(gm.buffers())), 1)
return gm
class TestModule(torch.nn.Module):
def __init__(self, buf) -> None:
super().__init__()
self.register_buffer("buf", buf)
def forward(self, x):
return self.buf * x
@torch._dynamo.optimize(backend=debug_compiler)
def fn(x, mod):
return mod(x)
buf = torch.ones(2, 2)
inp = torch.ones(2)
mod = TestModule(buf)
fn(inp, mod)
self.assertEqual(num_compiles, 1)
mod.buf = torch.rand_like(buf)
fn(inp, mod)
self.assertEqual(num_compiles, 2)
def test_no_guard_on_torch_nn_modules(self):
# https://github.com/pytorch/pytorch/issues/110048
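A condensed sketch of the freezing path covered by test_mark_static_with_freezing above, assuming the same semantics; the module and backend names here are illustrative. With inlining and freezing enabled under no_grad, the buffer should be installed as a graph attribute rather than lifted to a placeholder:

import torch

def count_graph_buffers(gm, example_inputs):
    # Under freezing, the buffer lives on the graph module itself.
    print("graph buffers:", len(list(gm.buffers())))
    return gm

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.ones(2, 2))

    def forward(self, x):
        return self.buf * x

with torch.no_grad(), \
        torch._dynamo.config.patch("inline_inbuilt_nn_modules", True), \
        torch._inductor.config.patch("freezing", True):
    torch.compile(M(), backend=count_graph_buffers)(torch.ones(2))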

View File

@ -1526,6 +1526,45 @@ class GraphModule(torch.nn.Module):
out_test = compiled_f(view)
self.assertEqual(out_ref, out_test)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_mark_static_with_subclass_desugaring(self):
from typing import Any, Callable, Dict, List, Optional
from torch._dynamo.decorators import mark_static_address
from torch._inductor.compile_fx import compile_fx
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
from torch._inductor.utils import BoxedBool
x_inner = torch.ones(4)
x = TwoTensor(x_inner, x_inner)
mark_static_address(x, guard=False)
def inner_compile(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
cudagraphs: Optional[BoxedBool] = None,
static_input_idxs: Optional[List[int]] = None,
is_backward: bool = False,
graph_id: Optional[int] = None,
cpp_wrapper: bool = False,
aot_mode: bool = False,
is_inference: bool = False,
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
user_visible_outputs: Optional[Dict[str, None]] = None,
layout_opt: Optional[bool] = None,
extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None,
):
self.assertEqual(static_input_idxs, [1, 2])
return gm
compiler = functools.partial(compile_fx, inner_compile=inner_compile)
@torch.compile(backend=compiler)
def fn(t0, t1, t2):
return t0 + t1 + t2 + 2
fn(torch.ones(4), x, torch.ones(4))
instantiate_parametrized_tests(SubclassTests)
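For context on the expected static_input_idxs == [1, 2]: the TwoTensor at user position 1 was marked static, and desugaring it into two plain tensors expands that single index into two flattened positions. A tiny worked example; the helper name is mine:

def flattened_positions(num_plain_per_arg, user_index):
    # num_plain_per_arg[i] = number of plain tensors argument i desugars into
    start = sum(num_plain_per_arg[:user_index])
    return list(range(start, start + num_plain_per_arg[user_index]))

# fn(t0, x, t2) where x = TwoTensor(x_inner, x_inner) desugars into 2 tensors
assert flattened_positions([1, 2, 1], user_index=1) == [1, 2]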

View File

@ -1822,7 +1822,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
self.assertEqual(self.get_manager().new_graph_id().id, num_graphs)
def _module_test(self, mod):
def _module_test(self, mod, name="weight", param_wrapping=True):
with torch.device("cuda"):
def fn(x, mod):
@ -1845,11 +1845,14 @@ if HAS_CUDA and not TEST_WITH_ASAN:
self.assertEqual(exp_grad, compiled_grad)
run_test()
old = mod.weight.data
mod.weight.data = torch.rand_like(mod.weight.data)
old_attr = getattr(mod, name)
modified_attr = torch.rand_like(old_attr)
if param_wrapping:
modified_attr = torch.nn.Parameter(modified_attr)
setattr(mod, name, modified_attr)
run_test()
# Run original version to verify we reuse the other recording
mod.weight.data = old
setattr(mod, name, old_attr)
run_test()
# Fwd + bwd graphs for each version of the function => 4 graphs
@ -1876,6 +1879,18 @@ if HAS_CUDA and not TEST_WITH_ASAN:
# Note: Linear is a builtin module so we enable that config setting above
self._module_test(torch.nn.Linear(2, 3, device="cuda"))
@torch._inductor.config.patch("triton.cudagraphs", True)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_single_compile_builtin_module_buffers(self):
# Verify that we don't recompile when changing the buffer of a builtin module
# and that we record another cudagraph
self._module_test(
torch.nn.BatchNorm1d(2, device="cuda"),
name="running_mean",
param_wrapping=False,
)
@torch._inductor.config.patch("triton.cudagraphs", True)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
@ -1894,6 +1909,30 @@ if HAS_CUDA and not TEST_WITH_ASAN:
TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda")))
)
@torch._inductor.config.patch("triton.cudagraphs", True)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_custom_module_buffer(self):
# Test that we can correctly dispatch multiple graphs
# if buffers of a custom module change
class TestModule(torch.nn.Module):
def __init__(self, param, buf) -> None:
super().__init__()
self.weight = param
self.register_buffer("buf", buf)
def forward(self, x):
return x * self.weight + self.buf
self._module_test(
TestModule(
torch.nn.Parameter(torch.rand([2, 2], device="cuda")),
torch.rand([2, 2], device="cuda"),
),
name="buf",
param_wrapping=False,
)
@torch._inductor.config.patch("triton.cudagraphs", True)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
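A rough, forward-only sketch of the buffer-swapping behavior exercised by test_multi_dispatch_custom_module_buffer above (requires CUDA; the real _module_test also checks gradients): replacing a buffer with a tensor at a new address should record a second cudagraph rather than force a recompile, and restoring the original buffer should dispatch back to the first recording.

import torch

class BufModule(torch.nn.Module):
    def __init__(self, buf) -> None:
        super().__init__()
        self.register_buffer("buf", buf)

    def forward(self, x):
        return x * self.buf

with torch._dynamo.config.patch("inline_inbuilt_nn_modules", True), \
        torch._dynamo.config.patch("error_on_recompile", True), \
        torch._inductor.config.patch("triton.cudagraphs", True):
    mod = BufModule(torch.rand(2, 2, device="cuda"))
    opt = torch.compile(mod)
    x = torch.rand(2, 2, device="cuda")
    opt(x)                           # first cudagraph recording
    old = mod.buf
    mod.buf = torch.rand_like(old)   # new static address, no recompile expected
    opt(x)                           # second recording
    mod.buf = old
    opt(x)                           # should reuse the first recording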

View File

@ -2834,3 +2834,7 @@ def _disable_saved_tensors_hooks_during_tracing():
yield
finally:
torch._C._autograd._saved_tensors_hooks_set_tracing(prior)
def is_parameter_freezing():
return torch._inductor.config.freezing and not torch.is_grad_enabled()
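The new helper simply combines the inductor freezing flag with the current grad mode. A quick illustrative check, assuming the helper lands in torch._dynamo.utils (as the relative import added to the variable builder below suggests):

import torch
from torch._dynamo.utils import is_parameter_freezing

with torch._inductor.config.patch("freezing", True):
    with torch.no_grad():
        assert is_parameter_freezing()   # freezing enabled and grad disabled
    assert not is_parameter_freezing()   # grad enabled again
assert not is_parameter_freezing()       # freezing disabled by default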

View File

@ -47,6 +47,7 @@ from torch.fx.experimental.symbolic_shapes import (
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import TensorWeakRef
from .. import config, mutation_guard, replay_record, trace_rules
from ..device_interface import get_registered_device_interfaces
@ -90,6 +91,7 @@ from ..utils import (
is_function_or_wrapper,
is_lru_cache_wrapped_function,
is_namedtuple,
is_parameter_freezing,
is_typing,
is_utils_checkpoint,
is_wrapper_or_member_descriptor,
@ -1163,6 +1165,19 @@ class VariableBuilder:
else:
return RangeVariable(items, source=self.source)
def mark_static_input(self, value: torch.Tensor, guard: bool):
from ..decorators import mark_static_address
mark_static_address(value, guard=guard)
# Check if we've seen this tensor before and update graph metadata if needed.
# As long as this runs before AOT, this is sound.
if value in self.tx.output.side_effects:
var = self.tx.output.side_effects[value]
var.proxy.node.meta["tensor_dict"][
"_dynamo_static_input_type"
] = value._dynamo_static_input_type
def wrap_module(self, value: torch.nn.Module):
from ..eval_frame import OptimizedModule
@ -1216,23 +1231,19 @@ class VariableBuilder:
elif mutation_guard.is_dynamic_nn_module(value, self.tx.export):
# created dynamically, don't specialize on it
self.install_guards(GuardBuilder.TYPE_MATCH)
if (
torch._dynamo.config.inline_inbuilt_nn_modules
and torch._inductor.config.freezing
and not torch.is_grad_enabled()
):
from ..decorators import mark_static_address
if torch._dynamo.config.inline_inbuilt_nn_modules:
freezing = is_parameter_freezing()
for p in value.parameters():
mark_static_address(p)
self.mark_static_input(p, guard=freezing)
for b in value.buffers():
mark_static_address(b)
self.mark_static_input(b, guard=freezing)
# we need to add the module to tracing context
# in order to allow its params to get invalidated
# this will get cleaned up once compile ends
self.tx.output.nn_modules[self.name] = value
if freezing:
# we need to add the module to tracing context
# in order to allow its params to get invalidated
# this will get cleaned up once compile ends
self.tx.output.nn_modules[self.name] = value
result = UnspecializedNNModuleVariable(value, source=self.source)
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
@ -1298,9 +1309,21 @@ class VariableBuilder:
# it would have already been wrapped
assert value not in self.tx.output.side_effects
is_static_input = get_static_address_type(value) is not None
if (
source.guard_source().is_nn_module()
or get_static_address_type(value) is not None
config.inline_inbuilt_nn_modules
and not is_static_input
and isinstance(value, torch.nn.Parameter)
):
self.mark_static_input(value, guard=False)
make_graph_attribute = is_static_input and (
not config.inline_inbuilt_nn_modules or is_parameter_freezing()
)
if (
source.guard_source().is_nn_module() or make_graph_attribute
) and not source.guard_source().is_fsdp_module():
self.assert_not_wrapped_by_this_graph(value)
return self.tx.output.register_attr_or_module(
@ -1353,6 +1376,9 @@ class VariableBuilder:
if is_duplicate_tensor:
return self.tx.output.input_source_to_var[source]
if get_static_address_type(value) == "guarded":
self.install_guards(GuardBuilder.ID_MATCH)
# By this point, we should have deduplicated all tensors
self.assert_not_wrapped_by_this_graph(value)
@ -1413,9 +1439,11 @@ class VariableBuilder:
self.install_guards(
functools.partial(
guard_type,
value=value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value),
value=(
value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value)
),
)
)
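For reference, mark_static_input builds on the decorator's tensor-level tag: judging from how the builder consumes value._dynamo_static_input_type above, mark_static_address records "guarded" or "unguarded" directly on the tensor, and the builder then mirrors it into the placeholder's tensor_dict metadata. A small check of that tensor-level tag (treat the attribute as an internal detail):

import torch
from torch._dynamo.decorators import mark_static_address

t = torch.ones(3)
mark_static_address(t, guard=False)
print(t._dynamo_static_input_type)  # expected: "unguarded"

g = torch.ones(3)
mark_static_address(g, guard=True)
print(g._dynamo_static_input_type)  # expected: "guarded"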

View File

@ -11,7 +11,7 @@ a functionalized version of the graph under compilation.
import collections
import logging
from functools import wraps
from typing import Callable, DefaultDict, Dict, List
from typing import Callable, DefaultDict, Dict, List, Optional
import torch
import torch.utils._pytree as pytree
@ -25,6 +25,7 @@ from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
transform_subclass,
)
from .functional_utils import (
are_all_mutations_hidden_from_autograd,
are_all_mutations_under_no_grad_or_inference_mode,
@ -124,6 +125,8 @@ def run_functionalized_fw_and_collect_metadata(
keep_input_mutations: bool,
# TODO: refactor to kill this flag
is_train: bool = False,
# Note: this is guaranteed to be set when running under dynamo
static_input_indices: Optional[List[int]] = None,
pre_dispatch: bool = False,
) -> Callable[..., ViewAndMutationMeta]:
memo: Dict[Tensor, Tensor] = {}
@ -666,17 +669,15 @@ from a multi-output view call"
)
user_outs = pytree.tree_map(from_fun, f_output_tangents)
if (
torch._dynamo.config.inline_inbuilt_nn_modules
or torch._dynamo.compiled_autograd.in_compiled_autograd_region
):
static_parameter_input_indices = [
nonlocal static_input_indices
static_input_indices = static_input_indices or []
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
passed_indices = set(static_input_indices)
static_input_indices = [
i
for i, arg in enumerate(flat_args)
if isinstance(arg, torch.nn.Parameter)
if (isinstance(arg, torch.nn.Parameter) or i in passed_indices)
]
else:
static_parameter_input_indices = []
f_mutated_inputs = [
inp
@ -729,7 +730,7 @@ from a multi-output view call"
subclass_tangent_meta=create_subclass_meta(traced_tangents),
is_train=is_train,
grad_enabled_mutation=grad_enabled_mutation,
static_parameter_indices=static_parameter_input_indices,
static_input_indices=static_input_indices,
tokens=mode._tokens,
)
return metadata
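A standalone restatement of the index handling above, for readability: the indices passed in from dynamo are used as-is, and inside a compiled-autograd region they are widened to also cover any nn.Parameter arguments. The function name is mine:

from typing import List, Optional

import torch

def resolve_static_input_indices(
    flat_args,
    static_input_indices: Optional[List[int]],
    in_compiled_autograd_region: bool,
) -> List[int]:
    static_input_indices = static_input_indices or []
    if in_compiled_autograd_region:
        passed = set(static_input_indices)
        static_input_indices = [
            i
            for i, arg in enumerate(flat_args)
            if isinstance(arg, torch.nn.Parameter) or i in passed
        ]
    return static_input_indices

args = [torch.ones(2), torch.nn.Parameter(torch.ones(2))]
assert resolve_static_input_indices(args, [0], True) == [0, 1]
assert resolve_static_input_indices(args, [0], False) == [0]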

View File

@ -905,6 +905,7 @@ class AOTDedupeWrapper(CompilerWrapper):
if config.debug_assert:
ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
wrapped_flat_fn,
static_input_indices=aot_config.static_input_indices,
keep_input_mutations=fw_metadata.keep_input_mutations,
is_train=fw_metadata.is_train,
)(*deduped_flat_args)
@ -1094,6 +1095,7 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
if config.debug_assert:
ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
wrapped_flat_fn,
static_input_indices=aot_config.static_input_indices,
keep_input_mutations=fw_metadata.keep_input_mutations,
is_train=fw_metadata.is_train,
)(*flat_args_with_synthetic_bases)

View File

@ -328,7 +328,7 @@ class ViewAndMutationMeta:
deterministic: Optional[bool] = None
# Keeps track of which input indices store parameters (which we will treat as static)
static_parameter_indices: List[int] = field(default_factory=list)
static_input_indices: List[int] = field(default_factory=list)
# Map of effect type (ex. _EffectType.ORDERED) to token. If there are
# side-effectful operators, FunctionalTensorMode will populate this
@ -802,6 +802,7 @@ class AOTConfig:
no_tangents: bool = False
dynamic_shapes: bool = False
aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
static_input_indices: Optional[List[int]] = None
inference_compiler: Optional[Callable] = None
enable_log: bool = True
# this is always false outside of export.

View File

@ -131,6 +131,24 @@ def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
return unwrapped_args
def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices):
static_input_indices = set(static_input_indices)
new_ind = 0
remapped_static_indices = []
for i, arg in enumerate(wrapped_args):
num_indices = 1
if is_traceable_wrapper_subclass(arg):
num_indices = len(get_plain_tensors(arg))
for _ in range(num_indices):
if i in static_input_indices:
remapped_static_indices.append(new_ind)
new_ind += 1
return remapped_static_indices
# Turns a flattened list of tensor arguments into (maybe) subclass tensors.
# This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in.
def wrap_tensor_subclasses(
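A simplified, self-contained version of remap_unwrapped_subclass_arg_indices: the real function inspects traceable wrapper subclasses, while here each argument is represented only by the number of plain tensors it unwraps into:

def remap_static_indices(num_plain_per_arg, static_input_indices):
    static = set(static_input_indices)
    remapped, new_ind = [], 0
    for i, n in enumerate(num_plain_per_arg):
        for _ in range(n):
            if i in static:
                remapped.append(new_ind)
            new_ind += 1
    return remapped

# arg 0 is a subclass wrapping two tensors, args 1 and 2 are plain;
# args 0 and 2 were static in the wrapped calling convention.
assert remap_static_indices([2, 1, 1], [0, 2]) == [0, 1, 3]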

View File

@ -53,6 +53,7 @@ from .schemas import (
)
from .subclass_utils import (
create_subclass_meta,
remap_unwrapped_subclass_arg_indices,
requires_subclass_dispatch,
unwrap_tensor_subclasses,
wrap_tensor_subclasses_maybe_joint,
@ -702,6 +703,9 @@ def aot_dispatch_subclass(
args_unwrapped = unwrap_tensor_subclasses(
args, is_joint_structure=is_joint_structure
)
remapped_static_indices = remap_unwrapped_subclass_arg_indices(
args, meta.static_input_indices
)
if is_joint_structure:
primals_unwrapped = args_unwrapped[0]
@ -729,6 +733,7 @@ def aot_dispatch_subclass(
# See Note: [Partitioner handling for Subclasses, Part 2] for more info.
meta_updated = run_functionalized_fw_and_collect_metadata(
metadata_fn,
static_input_indices=remapped_static_indices,
keep_input_mutations=meta.keep_input_mutations,
is_train=meta.is_train,
)(*primals_unwrapped)

View File

@ -20,6 +20,7 @@ from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from . import config
from ._aot_autograd.autograd_cache import ( # noqa: F401
AOTAutogradCache,
@ -583,6 +584,7 @@ def create_aot_dispatcher_function(
with ctx:
fw_metadata = run_functionalized_fw_and_collect_metadata(
flat_fn,
static_input_indices=aot_config.static_input_indices,
keep_input_mutations=aot_config.keep_inference_input_mutations,
is_train=needs_autograd,
pre_dispatch=aot_config.pre_dispatch,
@ -618,6 +620,7 @@ def create_aot_dispatcher_function(
keep_input_mutations=aot_config.keep_inference_input_mutations,
is_train=False,
pre_dispatch=aot_config.pre_dispatch,
static_input_indices=aot_config.static_input_indices,
)(*fake_flat_args)
else:
fw_metadata = ViewAndMutationMeta(
@ -631,7 +634,7 @@ def create_aot_dispatcher_function(
subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
is_train=False,
tokens=fw_metadata.tokens,
static_parameter_indices=fw_metadata.static_parameter_indices,
static_input_indices=fw_metadata.static_input_indices,
)
if fw_metadata.num_intermediate_bases > 0:
@ -936,9 +939,10 @@ def aot_module_simplified(
# Next, the input args
full_args.extend(args)
static_input_indices = []
if hasattr(mod, "graph"):
# Non dynamo entrypoints can get to here...
for node in mod.graph.find_nodes(op="placeholder"):
for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
if hasattr(node, "_dynamo_source"):
# ... but not here!
if aot_autograd_arg_pos_to_source is None:
@ -948,6 +952,11 @@ def aot_module_simplified(
seen_sources.add(source)
aot_autograd_arg_pos_to_source.append(source)
if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
"_dynamo_static_input_type", None
):
static_input_indices.append(pos)
if aot_autograd_arg_pos_to_source is not None:
assert len(full_args) == len(aot_autograd_arg_pos_to_source)
@ -968,6 +977,7 @@ def aot_module_simplified(
keep_inference_input_mutations=keep_inference_input_mutations,
dynamic_shapes=dynamic_shapes,
aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
static_input_indices=static_input_indices,
is_export=False,
no_tangents=False,
cache_key=None,
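The harvesting loop above, extracted into a standalone sketch: placeholders are enumerated in graph order, and every node whose tensor_dict carries the _dynamo_static_input_type tag contributes its position. These positions then ride along on AOTConfig.static_input_indices as shown above.

def collect_static_input_indices(gm):
    static_input_indices = []
    for pos, node in enumerate(gm.graph.find_nodes(op="placeholder")):
        if node.meta.get("tensor_dict", {}).get("_dynamo_static_input_type"):
            static_input_indices.append(pos)
    return static_input_indices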

View File

@ -137,7 +137,7 @@ def get_static_input_idxs(num_fixed):
if not context or not context.fw_metadata:
return fixed
return fixed + context.fw_metadata.static_parameter_indices
return fixed + context.fw_metadata.static_input_indices
@functools.lru_cache(None)
@ -1254,7 +1254,7 @@ def fw_compiler_freezing(
params_flat[i] = None
if tracing_context.fw_metadata:
static_input_idxs += tracing_context.fw_metadata.static_parameter_indices
static_input_idxs += tracing_context.fw_metadata.static_input_indices
with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
optimized_function = inner_compile(

View File

@ -1538,6 +1538,15 @@ def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int)
num_rng_seed_offset_inputs = (
2 if torch._functorch.config.functionalize_rng_ops else 0
)
# AOT won't lift any parameters if we're inlining NN modules;
# however, desugaring subclasses will still add arguments,
# which previously resulted in extra fixed inputs: https://github.com/pytorch/pytorch/issues/130502
if (
torch._dynamo.config.inline_inbuilt_nn_modules
and not torch._dynamo.utils.is_parameter_freezing()
):
return 0
return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs
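A worked restatement of the arithmetic above, with the config checks made explicit parameters and functionalized RNG assumed off: the fixed count is the number of extra inputs AOT lifted beyond what dynamo traced, and it is forced to 0 when NN modules are inlined without parameter freezing, since nothing gets lifted in that case.

def num_fw_fixed_arguments_sketch(dynamo_gm_num_inputs, aot_fw_gm_num_inputs,
                                  inlining_nn_modules=False,
                                  parameter_freezing=False,
                                  num_rng_seed_offset_inputs=0):
    if inlining_nn_modules and not parameter_freezing:
        return 0
    return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs

assert num_fw_fixed_arguments_sketch(3, 5) == 2                       # two lifted params
assert num_fw_fixed_arguments_sketch(3, 3, inlining_nn_modules=True) == 0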