Add support for param mutation under inference mode (#159661)

Summary:
In the HF RWKV model, parameters are mutated in-place under inference mode, which should be safe. This PR does several things to make that work:
1. Execute global autograd-state mutations while tracing, so that we can actually trace through in-place parameter mutation.
2. Add support for parameter mutation under inference mode in AOTAutograd.
3. Add support for parameter mutation under inference mode in export.
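
For context, a minimal sketch of the pattern this enables, mirroring the new test_no_grad_param_inplace test below (module and shapes are illustrative):

import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.parameter = torch.nn.Parameter(torch.ones(4, 4))

    def forward(self, x):
        with torch.no_grad():       # grad disabled: in-place update of a leaf tensor is allowed
            self.parameter.div_(2)  # the parameter mutation this PR teaches export to handle
        return x + self.parameter

# Export now traces through the mutation; the mutated parameter is emitted as an
# extra graph output and recorded in ep.graph_signature.parameters_to_mutate.
ep = torch.export.export(Foo(), (torch.rand(4, 4),)).run_decompositions()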

Test Plan:
test

Rollback Plan:

Differential Revision: D79460136

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159661
Approved by: https://github.com/ydwu4
Tugsbayasgalan (Tugsuu) Manlaibaatar
2025-08-14 03:34:04 +00:00
committed by PyTorch MergeBot
parent 29d20d49f0
commit 194fcfcfbd
17 changed files with 305 additions and 31 deletions


@ -326,6 +326,52 @@ class TestDynamismExpression(TestCase):
dynamic_shapes=dynamic_shapes,
)
def test_no_grad_param_inplace(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.parameter = torch.nn.Parameter(torch.ones(4, 4))
def forward(self, x):
with torch.no_grad():
self.parameter.div_(2)
return x + self.parameter
foo_ep = Foo()
foo_eager = Foo()
ep = export(foo_ep, (torch.rand(4, 4),)).run_decompositions()
val = ep.graph_signature.parameters_to_mutate
self.assertExpectedInline(
str(ep.graph).strip(),
"""\
graph():
%p_parameter : [num_users=1] = placeholder[target=p_parameter]
%x : [num_users=1] = placeholder[target=x]
%div : [num_users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%p_parameter, 2), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %div), kwargs = {})
return (div, add)""",
)
self.assertTrue("div" in val.keys())
self.assertTrue("parameter" in val.values())
test_inp = torch.rand(4, 4)
res = foo_eager(test_inp)
# TODO We almost need to make the param mutation happen outside
# of the graph. Or wrap the param mutation in a no_grad HOP. Simply
# overriding gm.__call__ doesn't seem to work due to:
# 1. graph module does something weird to __call__ so it is not easy to override
# 2. We inspect module.forward to bind fake args when retracing
with self.assertRaisesRegex(RuntimeError, "leaf"):
res_export = ep.module()(torch.rand(4, 4))
with torch.no_grad():
res_export = ep.module()(test_inp)
self.assertTrue(torch.allclose(res, res_export))
def test_export_slice_unbacked_dim1(self):
class MySlice(torch.nn.Module):
def forward(self, x, seq_len):
@ -4000,6 +4046,17 @@ def forward(self, x):
inp = torch.randn(3, 3)
self.assertTrue(torch.allclose(ep.module()(inp)[0], inp + 1))
def test_set_grad_as_side_effect(self):
class Foo(torch.nn.Module):
def forward(self, x):
torch._C._set_grad_enabled(False)
return x.sum()
before = torch.is_grad_enabled()
ep = torch.export.export(Foo(), (torch.randn(4, 4),))
after = torch.is_grad_enabled()
self.assertEqual(before, after)
def test_derived_dim_out_of_order_simplified(self):
_dimz = torch.export.Dim("_dimz", min=6, max=8)
dimy = _dimz - 1


@ -280,6 +280,25 @@ def forward(self, x):
actual_out = loaded_ep.module()(*inp)
self.assertEqual(exp_out, actual_out)
def test_serialize_param_mutation(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.parameter = torch.nn.Parameter(torch.ones(4, 4))
def forward(self, x):
with torch.no_grad():
self.parameter.div_(2)
return x + self.parameter
foo = Foo()
ep = torch.export.export(foo, (torch.rand(4, 4),)).run_decompositions()
buffer = io.BytesIO()
save(ep, buffer)
loaded_ep = load(buffer)
val = loaded_ep.graph_signature.parameters_to_mutate
self.assertEqual({"div": "parameter"}, val)
def test_serialize_constant_outputs(self):
class MyModule(torch.nn.Module):
def __init__(self) -> None:


@ -5364,11 +5364,15 @@ def forward(self, arg0_1, arg1_1, arg2_1):
mod = M()
inp = torch.randn(2, requires_grad=True)
-with self.assertRaisesRegex(
-    RuntimeError,
-    "Found a graph input that requires gradients, and received a mutation",
-):
-    aot_export_module(mod, [inp], trace_joint=False)
+gm, _ = aot_export_module(mod, [inp], trace_joint=False)
+self.assertExpectedInline(
+    str(gm.graph).strip(),
+    """\
+graph():
+    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 4), kwargs = {})
+    return (add, add)""",
+)
def test_aot_export_input_mutation_on_parameter_banned(self):
def fn(p, x):
@ -5379,11 +5383,26 @@ def forward(self, arg0_1, arg1_1, arg2_1):
inp = torch.randn(2)
with self.assertRaisesRegex(
RuntimeError,
"Found a graph input that requires gradients, and received a mutation",
"aot_export_joint_simple does not support input mutations. ViewAndMutationMeta",
):
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
with self.assertRaisesRegex(
RuntimeError,
"Found a graph input that requires gradients, and received a mutation",
):
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
-    aot_export_module(mod, [inp], trace_joint=False)
+gm, _ = aot_export_module(mod, [inp], trace_joint=False)
+self.assertExpectedInline(
+    str(gm.graph).strip(),
+    """\
+graph():
+    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+    %mul : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, 2), kwargs = {})
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %arg1_1), kwargs = {})
+    return (mul, add)""",
+)
def test_aot_export_synthetic_bases_banned(self):
def fn(p, x, y):


@ -1,5 +1,5 @@
// @generated by update_schema.py
-// checksum<<e7f100132ac684ccc67fce91b241821062f1dfe496fdff4b9929aba4ac938b4f>>
+// checksum<<00d94226d15b290b97bd49f9ff12bbfe04b7252c75d2d1bae66d1756fd9b8517>>
namespace py3 torch._export
namespace cpp2 torch._export.schema
@ -254,6 +254,11 @@ struct BufferMutationSpec {
20: string buffer_name;
}
struct ParameterMutationSpec {
10: TensorArgument arg;
20: string parameter_name;
}
struct GradientToParameterSpec {
10: TensorArgument arg;
20: string parameter_name;
@ -281,6 +286,7 @@ union OutputSpec {
50: GradientToUserInputSpec gradient_to_user_input;
60: UserInputMutationSpec user_input_mutation;
70: OutputTokenSpec token;
80: ParameterMutationSpec parameter_mutation;
}
struct GraphSignature {


@ -327,6 +327,12 @@ class BufferMutationSpec:
buffer_name: Annotated[str, 20]
@dataclass
class ParameterMutationSpec:
arg: Annotated[TensorArgument, 10]
parameter_name: Annotated[str, 20]
@dataclass
class GradientToParameterSpec:
arg: Annotated[TensorArgument, 10]
@ -359,6 +365,7 @@ class OutputSpec(_Union):
gradient_to_user_input: Annotated[GradientToUserInputSpec, 50]
user_input_mutation: Annotated[UserInputMutationSpec, 60]
token: Annotated[OutputTokenSpec, 70]
parameter_mutation: Annotated[ParameterMutationSpec, 80]
@dataclass


@ -1,5 +1,5 @@
# @generated by update_schema.py
-# checksum<<afe0cc0f99e72d00aa05f1a94da938ecb619aabc5d131d3ade489b57799f1e5a>>
+# checksum<<face83b52f81c45eeaeccc97cee19e146b3f7416ed91e015b4510ada7549a72f>>
AOTInductorModelPickleData:
kind: struct
fields:
@ -383,11 +383,20 @@ OutputSpec:
type: UserInputMutationSpec
token:
type: OutputTokenSpec
parameter_mutation:
type: ParameterMutationSpec
OutputTokenSpec:
kind: struct
fields:
arg:
type: TokenArgument
ParameterMutationSpec:
kind: struct
fields:
arg:
type: TensorArgument
parameter_name:
type: str
RangeConstraint:
kind: struct
fields:


@ -69,6 +69,7 @@ from .schema import ( # type: ignore[attr-defined]
OptionalTensorArgument,
OutputSpec,
OutputTokenSpec,
ParameterMutationSpec,
RangeConstraint,
ScalarType,
SCHEMA_VERSION,
@ -1241,6 +1242,15 @@ class GraphModuleSerializer(metaclass=Final):
buffer_name=spec.target,
)
)
elif spec.kind == ep.OutputKind.PARAMETER_MUTATION:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
return OutputSpec.create(
parameter_mutation=ParameterMutationSpec(
arg=TensorArgument(name=spec.arg.name),
parameter_name=spec.target,
)
)
elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER:
assert spec.target is not None
assert isinstance(spec.arg, ep.TensorArgument)
@ -2199,6 +2209,12 @@ class GraphModuleDeserializer(metaclass=Final):
arg=ep.TensorArgument(name=o.buffer_mutation.arg.name),
target=o.buffer_mutation.buffer_name,
)
elif o.type == "parameter_mutation":
return ep.OutputSpec(
kind=ep.OutputKind.PARAMETER_MUTATION,
arg=ep.TensorArgument(name=o.parameter_mutation.arg.name),
target=o.parameter_mutation.parameter_name,
)
elif o.type == "gradient_to_parameter":
return ep.OutputSpec(
kind=ep.OutputKind.GRADIENT_TO_PARAMETER,
@ -3377,17 +3393,19 @@ def canonicalize(
idx, (_arg, spec) = out
assert isinstance(spec, OutputSpec)
if spec.type == "user_output":
-    return 3, None, idx
+    return 4, None, idx
elif spec.type == "loss_output":
-    return 3, None, idx
+    return 4, None, idx
+elif spec.type == "parameter_mutation":
+    return 1, spec.parameter_mutation.parameter_name, idx
elif spec.type == "buffer_mutation":
-    return 1, spec.buffer_mutation.buffer_name, idx
+    return 2, spec.buffer_mutation.buffer_name, idx
elif spec.type == "gradient_to_parameter":
-    return 4, spec.gradient_to_parameter.parameter_name, idx
+    return 5, spec.gradient_to_parameter.parameter_name, idx
elif spec.type == "gradient_to_user_input":
-    return 5, None, idx
+    return 6, None, idx
elif spec.type == "user_input_mutation":
-    return 2, None, idx
+    return 3, None, idx
elif spec.type == "token":
return 0, None, idx
else:
@ -3500,6 +3518,9 @@ def canonicalize(
elif spec.type == "buffer_mutation":
t = spec.buffer_mutation.arg
t.name = replace_table[t.name]
elif spec.type == "parameter_mutation":
t = spec.parameter_mutation.arg
t.name = replace_table[t.name]
elif spec.type == "gradient_to_parameter":
t = spec.gradient_to_parameter.arg
t.name = replace_table[t.name]


@ -463,7 +463,12 @@ def _verify_exported_program_signature(exported_program) -> None:
)
num_tokens = len(gs.output_tokens)
-end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
+end = (
+    len(gs.buffers_to_mutate)
+    + len(gs.parameters_to_mutate)
+    + len(gs.user_inputs_to_mutate)
+    + num_tokens
+)
mutate_nodes: list[str] = output_nodes[num_tokens:end]
user_output_nodes = output_nodes[end : end + len(gs.user_outputs)]
@ -475,6 +480,13 @@ def _verify_exported_program_signature(exported_program) -> None:
f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
f"Buffer nodes available: {gs.buffers} \n"
)
elif mutation_node in gs.parameters_to_mutate:
if gs.parameters_to_mutate[mutation_node] not in gs.parameters:
raise SpecViolationError(
f"Parameter output {mutation_node} does not point to a parameter that exists. \n"
f"Dict of parameters that are mutated, in order: {gs.parameters_to_mutate} \n"
f"Parameter nodes available: {gs.parameters} \n"
)
elif mutation_node in gs.user_inputs_to_mutate:
if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
raise SpecViolationError(


@ -460,6 +460,7 @@ def create_graph_signature(
named_buffers=buffer_names,
num_user_inputs=num_user_args,
num_user_outputs=num_user_fw_outs,
trace_joint=trace_joint,
loss_index=loss_index,
backward_signature=backward_signature,
)


@ -829,6 +829,7 @@ class GraphSignature:
# "graph outputs that correspond to updated buffers"
# to the FQN names of those mutated buffers.
buffers_to_mutate: dict[GraphOutputName, FQN]
parameters_to_mutate: dict[GraphOutputName, FQN]
user_inputs_to_mutate: dict[GraphOutputName, GraphInputName]
in_spec: pytree.TreeSpec
@ -852,6 +853,7 @@ class GraphSignature:
named_buffers: list[str],
num_user_inputs: int,
num_user_outputs: int,
trace_joint: bool,
loss_index: Optional[int],
backward_signature: Optional[BackwardSignature],
) -> GraphSignature:
@ -897,8 +899,9 @@ class GraphSignature:
mutations = []
for idx, input_info in enumerate(view_mutation_metadata.input_info):
if input_info.mutates_data:
-# Only buffers can be mutated, not parameters
-assert idx >= len(parameters)
+if trace_joint:
+    # Only buffers can be mutated, not parameters
+    assert idx >= len(parameters)
mutations.append(names[idx + num_tokens])
assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices
@ -911,12 +914,16 @@ class GraphSignature:
user_inputs_to_mutate = {}
buffers_to_mutate = {}
+parameters_to_mutate = {}
for output_name, mutation_name in outputs_to_mutations.items():
    if mutation_name in user_inputs:
        user_inputs_to_mutate[output_name] = mutation_name
    else:
-        assert mutation_name in buffers
-        buffers_to_mutate[output_name] = mutation_name
+        assert mutation_name in buffers or mutation_name in parameters
+        if mutation_name in buffers:
+            buffers_to_mutate[output_name] = mutation_name
+        else:
+            parameters_to_mutate[output_name] = mutation_name
start, stop = stop, stop + num_user_outputs
user_outputs = graph_outputs[start:stop]
@ -937,6 +944,7 @@ class GraphSignature:
inputs_to_parameters=inputs_to_parameters, # type: ignore[arg-type]
user_inputs_to_mutate=user_inputs_to_mutate,
buffers_to_mutate=buffers_to_mutate, # type: ignore[arg-type]
parameters_to_mutate=parameters_to_mutate, # type: ignore[arg-type]
in_spec=in_spec,
out_spec=out_spec,
backward_signature=backward_signature,
@ -983,6 +991,9 @@ class AOTConfig:
ignore_shape_env: bool = False
precompile_backend_id: Optional[str] = None
force_non_lazy_backward_lowering: bool = False
# This config makes sure we check certain things, such as
# mutating an input that requires grad, in export joint tracing.
export_trace_joint: bool = False
def __post_init__(self):
if self.pre_dispatch:


@ -672,6 +672,7 @@ fw_metadata={str(fw_metadata)}"""
]
)
!= 0
and aot_config.export_trace_joint
):
raise RuntimeError(
f"""\
@ -1448,6 +1449,7 @@ We require the output marked as the loss (at index {output_loss_index}) to be a
no_tangents=True,
pre_dispatch=pre_dispatch,
dynamic_shapes=dynamic_shapes,
trace_joint=trace_joint,
kwargs=kwargs,
)
@ -1550,6 +1552,7 @@ def aot_export_joint_simple(
func,
args,
decompositions=decompositions,
trace_joint=trace_joint,
)
in_spec, _kw_in_spec = in_spec.children_specs
# At this point, we can just directly return the (joint or inference graph) that we traced.
@ -1631,6 +1634,8 @@ def _aot_export_function(
# If None, `dynamic_shapes` will be inferred from inputs, but the inferred result might be wrong.
dynamic_shapes: Optional[bool] = None,
keep_input_mutations: bool = False,
# Under export, configures whether we are getting inference or training IR
trace_joint: bool = False,
kwargs=None,
) -> tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
kwargs = kwargs or {}
@ -1675,6 +1680,7 @@ def _aot_export_function(
is_export=True,
no_tangents=no_tangents,
pre_dispatch=pre_dispatch,
export_trace_joint=trace_joint,
)
if fake_mode is None:
fake_mode, shape_env = construct_fake_mode(flat_args, aot_config)


@ -1,5 +1,5 @@
// @generated by update_schema.py
-// checksum<<afe0cc0f99e72d00aa05f1a94da938ecb619aabc5d131d3ade489b57799f1e5a>>
+// checksum<<face83b52f81c45eeaeccc97cee19e146b3f7416ed91e015b4510ada7549a72f>>
// clang-format off
#pragma once
@ -158,6 +158,7 @@ class Node;
class OptionalTensorArgument;
class OutputSpec;
class OutputTokenSpec;
class ParameterMutationSpec;
class RangeConstraint;
class SchemaVersion;
class SymBool;
@ -2494,6 +2495,33 @@ class BufferMutationSpec {
friend void from_json(const nlohmann::json& nlohmann_json_j, BufferMutationSpec& nlohmann_json_t);
};
class ParameterMutationSpec {
private:
TensorArgument arg;
std::string parameter_name;
public:
const TensorArgument& get_arg() const {
return arg;
}
void set_arg(TensorArgument def) {
arg = std::move(def);
}
const std::string& get_parameter_name() const {
return parameter_name;
}
void set_parameter_name(std::string def) {
parameter_name = std::move(def);
}
friend void to_json(nlohmann::json& nlohmann_json_j, const ParameterMutationSpec& nlohmann_json_t);
friend void from_json(const nlohmann::json& nlohmann_json_j, ParameterMutationSpec& nlohmann_json_t);
};
class GradientToParameterSpec {
private:
TensorArgument arg;
@ -2598,11 +2626,11 @@ class OutputSpec {
public:
enum class Tag {
-  USER_OUTPUT, LOSS_OUTPUT, BUFFER_MUTATION, GRADIENT_TO_PARAMETER, GRADIENT_TO_USER_INPUT, USER_INPUT_MUTATION, TOKEN
+  USER_OUTPUT, LOSS_OUTPUT, BUFFER_MUTATION, GRADIENT_TO_PARAMETER, GRADIENT_TO_USER_INPUT, USER_INPUT_MUTATION, TOKEN, PARAMETER_MUTATION
};
private:
-std::variant<Void, UserOutputSpec, LossOutputSpec, BufferMutationSpec, GradientToParameterSpec, GradientToUserInputSpec, UserInputMutationSpec, OutputTokenSpec> variant_;
+std::variant<Void, UserOutputSpec, LossOutputSpec, BufferMutationSpec, GradientToParameterSpec, GradientToUserInputSpec, UserInputMutationSpec, OutputTokenSpec, ParameterMutationSpec> variant_;
Tag tag_;
public:
@ -2673,6 +2701,15 @@ class OutputSpec {
tag_ = Tag::TOKEN;
}
const ParameterMutationSpec& get_parameter_mutation() const {
return std::get<8>(variant_);
}
void set_parameter_mutation(ParameterMutationSpec def) {
variant_.emplace<8>(std::move(def));
tag_ = Tag::PARAMETER_MUTATION;
}
friend void to_json(nlohmann::json& nlohmann_json_j, const OutputSpec& nlohmann_json_t) {
if (nlohmann_json_t.tag_ == Tag::USER_OUTPUT) {
@ -2703,6 +2740,10 @@ class OutputSpec {
nlohmann_json_j["token"] = nlohmann_json_t.get_token();
return;
}
if (nlohmann_json_t.tag_ == Tag::PARAMETER_MUTATION) {
nlohmann_json_j["parameter_mutation"] = nlohmann_json_t.get_parameter_mutation();
return;
}
}
friend void from_json(const nlohmann::json& nlohmann_json_j, OutputSpec& nlohmann_json_t) {
@ -2742,6 +2783,11 @@ class OutputSpec {
nlohmann_json_t.tag_ = Tag::TOKEN;
return;
}
if (nlohmann_json_j.contains("parameter_mutation")) {
nlohmann_json_t.variant_.emplace<8>(nlohmann_json_j.at("parameter_mutation").template get<ParameterMutationSpec>());
nlohmann_json_t.tag_ = Tag::PARAMETER_MUTATION;
return;
}
}
};
@ -2754,6 +2800,7 @@ inline std::string_view printEnum(const OutputSpec::Tag& e) {
case OutputSpec::Tag::GRADIENT_TO_USER_INPUT: return "GRADIENT_TO_USER_INPUT";
case OutputSpec::Tag::USER_INPUT_MUTATION: return "USER_INPUT_MUTATION";
case OutputSpec::Tag::TOKEN: return "TOKEN";
case OutputSpec::Tag::PARAMETER_MUTATION: return "PARAMETER_MUTATION";
default:
throw std::runtime_error("Unknown enum value");
}
@ -2767,6 +2814,7 @@ inline void parseEnum(std::string_view s, OutputSpec::Tag& t) {
if (s == "GRADIENT_TO_USER_INPUT") { t = OutputSpec::Tag::GRADIENT_TO_USER_INPUT; return; }
if (s == "USER_INPUT_MUTATION") { t = OutputSpec::Tag::USER_INPUT_MUTATION; return; }
if (s == "TOKEN") { t = OutputSpec::Tag::TOKEN; return; }
if (s == "PARAMETER_MUTATION") { t = OutputSpec::Tag::PARAMETER_MUTATION; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s});
}
@ -3575,6 +3623,17 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, OutputTokenSpec& nl
nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg);
}
inline void to_json(nlohmann::json& nlohmann_json_j, const ParameterMutationSpec& nlohmann_json_t) {
nlohmann_json_j["arg"] = nlohmann_json_t.arg;
nlohmann_json_j["parameter_name"] = nlohmann_json_t.parameter_name;
}
inline void from_json(const nlohmann::json& nlohmann_json_j, ParameterMutationSpec& nlohmann_json_t) {
ParameterMutationSpec nlohmann_json_default_obj;
nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg);
nlohmann_json_t.parameter_name = nlohmann_json_j.value("parameter_name", nlohmann_json_default_obj.parameter_name);
}
inline void to_json(nlohmann::json& nlohmann_json_j, const RangeConstraint& nlohmann_json_t) {
nlohmann_json_j["min_val"] = nlohmann_json_t.min_val;
nlohmann_json_j["max_val"] = nlohmann_json_t.max_val;


@ -1674,7 +1674,24 @@ def _export_to_aten_ir_make_fx(
for k, (old_getattr, _) in tensor_type_to_old_getattribute.items():
k.__getattribute__ = old_getattr # type: ignore[method-assign, attr-defined]
-with ctx, override_getattribute_for_subclasses(flat_args):
+@contextmanager
+def _maybe_restore_grad_state():
+    """
+    If pre-dispatch export accidentally changes the global grad state, restore it.
+    This can happen when torch._C._set_grad_enabled is called directly in the
+    forward.
+    """
+    old_state = torch.is_grad_enabled()
+    try:
+        yield
+    finally:
+        torch._C._set_grad_enabled(old_state)
+
+with (
+    ctx,
+    override_getattribute_for_subclasses(flat_args),
+    _maybe_restore_grad_state(),
+):
gm = make_fx(
wrapped_fn,
record_module_stack=True,
@ -1738,6 +1755,7 @@ def _export_to_aten_ir_make_fx(
zip(input_names[param_len : param_len + buffer_len], named_buffers)
),
buffers_to_mutate={},
parameters_to_mutate={},
user_inputs_to_mutate={},
in_spec=in_spec,
out_spec=out_spec.spec,
@ -1900,6 +1918,9 @@ def _non_strict_export(
_strip_root, sig.inputs_to_parameters
)
sig.buffers_to_mutate = pytree.tree_map(_strip_root, sig.buffers_to_mutate)
sig.parameters_to_mutate = pytree.tree_map(
_strip_root, sig.parameters_to_mutate
)
for node in gm.graph.nodes:
if "nn_module_stack" in node.meta:


@ -447,7 +447,11 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.fx.Grap
(
out_spec.target
if out_spec.kind
-in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
+in (
+    OutputKind.BUFFER_MUTATION,
+    OutputKind.USER_INPUT_MUTATION,
+    OutputKind.PARAMETER_MUTATION,
+)
else None
)
for out_spec in ep.graph_signature.output_specs


@ -792,9 +792,9 @@ def _remove_unneccessary_copy_op_pass(
if node.op == "output":
args, _ = pytree.tree_flatten(node.args)
for out in args:
-if (
-    isinstance(out, torch.fx.Node)
-    and out.name in new_graph_signature.buffers_to_mutate
+if isinstance(out, torch.fx.Node) and (
+    out.name in new_graph_signature.buffers_to_mutate
+    or out.name in new_graph_signature.parameters_to_mutate
):
if (
out.op == "call_function"


@ -121,6 +121,7 @@ class OutputKind(Enum):
USER_OUTPUT = auto()
LOSS_OUTPUT = auto()
BUFFER_MUTATION = auto()
PARAMETER_MUTATION = auto()
GRADIENT_TO_PARAMETER = auto()
GRADIENT_TO_USER_INPUT = auto()
USER_INPUT_MUTATION = auto()
@ -406,6 +407,16 @@ class ExportGraphSignature:
and isinstance(s.target, str)
)
@property
def parameters_to_mutate(self) -> Mapping[str, str]:
return _immutable_dict(
(s.arg.name, s.target)
for s in self.output_specs
if s.kind == OutputKind.PARAMETER_MUTATION
and isinstance(s.arg, TensorArgument)
and isinstance(s.target, str)
)
@property
def user_inputs_to_mutate(self) -> Mapping[str, str]:
return _immutable_dict(
@ -601,6 +612,7 @@ def _convert_to_export_graph_signature(
inputs_to_buffers = graph_signature.inputs_to_buffers
user_outputs = set(graph_signature.user_outputs)
buffer_mutations = graph_signature.buffers_to_mutate
parameter_mutations = graph_signature.parameters_to_mutate
user_input_mutations = graph_signature.user_inputs_to_mutate
grad_params = (
graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr]
@ -662,13 +674,21 @@ def _convert_to_export_graph_signature(
if not isinstance(o, TensorArgument):
return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
name = o.name
-if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
+if idx < len(buffer_mutations) + len(parameter_mutations) + len(
+    user_input_mutations
+) + len(output_tokens):
if name in buffer_mutations:
return OutputSpec(
kind=OutputKind.BUFFER_MUTATION,
arg=o,
target=buffer_mutations[name], # type: ignore[index]
)
elif name in parameter_mutations:
return OutputSpec(
kind=OutputKind.PARAMETER_MUTATION,
arg=o,
target=parameter_mutations[name], # type: ignore[index]
)
elif name in user_input_mutations:
return OutputSpec(
kind=OutputKind.USER_INPUT_MUTATION,


@ -1430,9 +1430,11 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
torch.amp.autocast_mode._exit_autocast,
]:
node.meta["val"] = None
+# For autocast, the Python APIs run, so we don't have to run them again
+# here.
+if func is torch._C._set_grad_enabled:
+    func(*args, **kwargs)
return node
-# Don't actually run the function! We just want to trace the calls
-# into a graph. We don't actually want to change global autograd state.
return func(*args, **kwargs)
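
Taken together with the _maybe_restore_grad_state guard added above, the intended behavior is roughly the following sketch (it mirrors the new test_set_grad_as_side_effect; the module is illustrative):

import torch

class Foo(torch.nn.Module):
    def forward(self, x):
        # This call now executes for real during pre-dispatch tracing, so the
        # global grad state is correct for the ops traced after it...
        torch._C._set_grad_enabled(False)
        return x.sum()

before = torch.is_grad_enabled()
ep = torch.export.export(Foo(), (torch.randn(4, 4),))
after = torch.is_grad_enabled()
# ...and _maybe_restore_grad_state restores the caller's grad state afterwards.
assert before == after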