mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add support for param mutation under inference mode (#159661)
Summary: In HF model rwkv, we have parameter mutation under inference mode which should be safe. This PR does multiple things to make sure it works: 1. We execute global autograd mutation while tracing so that we can actually trace through parameter inplace mutation 2. Add support for parameter mutation under inference mode in AOTAutograd 3. Add support for parameter mutation under inference mode in export. Test Plan: test Rollback Plan: Differential Revision: D79460136 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159661 Approved by: https://github.com/ydwu4
This commit is contained in:
committed by
PyTorch MergeBot
parent
29d20d49f0
commit
194fcfcfbd
@ -326,6 +326,52 @@ class TestDynamismExpression(TestCase):
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
def test_no_grad_param_inplace(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.parameter = torch.nn.Parameter(torch.ones(4, 4))
|
||||
|
||||
def forward(self, x):
|
||||
with torch.no_grad():
|
||||
self.parameter.div_(2)
|
||||
return x + self.parameter
|
||||
|
||||
foo_ep = Foo()
|
||||
foo_eager = Foo()
|
||||
ep = export(foo_ep, (torch.rand(4, 4),)).run_decompositions()
|
||||
val = ep.graph_signature.parameters_to_mutate
|
||||
self.assertExpectedInline(
|
||||
str(ep.graph).strip(),
|
||||
"""\
|
||||
graph():
|
||||
%p_parameter : [num_users=1] = placeholder[target=p_parameter]
|
||||
%x : [num_users=1] = placeholder[target=x]
|
||||
%div : [num_users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%p_parameter, 2), kwargs = {})
|
||||
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %div), kwargs = {})
|
||||
return (div, add)""",
|
||||
)
|
||||
|
||||
self.assertTrue("div" in val.keys())
|
||||
self.assertTrue("parameter" in val.values())
|
||||
|
||||
test_inp = torch.rand(4, 4)
|
||||
|
||||
res = foo_eager(test_inp)
|
||||
|
||||
# TODO We almost need to make the param mutation happen outside
|
||||
# of the graph. Or wrap the param mutation in a no_grad HOP. Simply
|
||||
# overriding gm.__call__ doesn't seem to work due to:
|
||||
# 1. graph module does something weird to __call__ so it is not easy to override
|
||||
# 2. We inspect module.forward to bind fake args when retracing
|
||||
with self.assertRaisesRegex(RuntimeError, "leaf"):
|
||||
res_export = ep.module()(torch.rand(4, 4))
|
||||
|
||||
with torch.no_grad():
|
||||
res_export = ep.module()(test_inp)
|
||||
|
||||
self.assertTrue(torch.allclose(res, res_export))
|
||||
|
||||
def test_export_slice_unbacked_dim1(self):
|
||||
class MySlice(torch.nn.Module):
|
||||
def forward(self, x, seq_len):
|
||||
@ -4000,6 +4046,17 @@ def forward(self, x):
|
||||
inp = torch.randn(3, 3)
|
||||
self.assertTrue(torch.allclose(ep.module()(inp)[0], inp + 1))
|
||||
|
||||
def test_set_grad_as_side_effect(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
torch._C._set_grad_enabled(False)
|
||||
return x.sum()
|
||||
|
||||
before = torch.is_grad_enabled()
|
||||
ep = torch.export.export(Foo(), (torch.randn(4, 4),))
|
||||
after = torch.is_grad_enabled()
|
||||
self.assertEqual(before, after)
|
||||
|
||||
def test_derived_dim_out_of_order_simplified(self):
|
||||
_dimz = torch.export.Dim("_dimz", min=6, max=8)
|
||||
dimy = _dimz - 1
|
||||
|
@ -280,6 +280,25 @@ def forward(self, x):
|
||||
actual_out = loaded_ep.module()(*inp)
|
||||
self.assertEqual(exp_out, actual_out)
|
||||
|
||||
def test_serialize_param_mutation(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.parameter = torch.nn.Parameter(torch.ones(4, 4))
|
||||
|
||||
def forward(self, x):
|
||||
with torch.no_grad():
|
||||
self.parameter.div_(2)
|
||||
return x + self.parameter
|
||||
|
||||
foo = Foo()
|
||||
ep = torch.export.export(foo, (torch.rand(4, 4),)).run_decompositions()
|
||||
buffer = io.BytesIO()
|
||||
save(ep, buffer)
|
||||
loaded_ep = load(buffer)
|
||||
val = loaded_ep.graph_signature.parameters_to_mutate
|
||||
self.assertEqual({"div": "parameter"}, val)
|
||||
|
||||
def test_serialize_constant_outputs(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
@ -5364,11 +5364,15 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
|
||||
mod = M()
|
||||
inp = torch.randn(2, requires_grad=True)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Found a graph input that requires gradients, and received a mutation",
|
||||
):
|
||||
aot_export_module(mod, [inp], trace_joint=False)
|
||||
gm, _ = aot_export_module(mod, [inp], trace_joint=False)
|
||||
self.assertExpectedInline(
|
||||
str(gm.graph).strip(),
|
||||
"""\
|
||||
graph():
|
||||
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
|
||||
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 4), kwargs = {})
|
||||
return (add, add)""",
|
||||
)
|
||||
|
||||
def test_aot_export_input_mutation_on_parameter_banned(self):
|
||||
def fn(p, x):
|
||||
@ -5379,11 +5383,26 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
inp = torch.randn(2)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Found a graph input that requires gradients, and received a mutation",
|
||||
"aot_export_joint_simple does not support input mutations. ViewAndMutationMeta",
|
||||
):
|
||||
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Found a graph input that requires gradients, and received a mutation",
|
||||
):
|
||||
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
|
||||
aot_export_module(mod, [inp], trace_joint=False)
|
||||
|
||||
gm, _ = aot_export_module(mod, [inp], trace_joint=False)
|
||||
self.assertExpectedInline(
|
||||
str(gm.graph).strip(),
|
||||
"""\
|
||||
graph():
|
||||
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
|
||||
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
|
||||
%mul : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, 2), kwargs = {})
|
||||
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %arg1_1), kwargs = {})
|
||||
return (mul, add)""",
|
||||
)
|
||||
|
||||
def test_aot_export_synthetic_bases_banned(self):
|
||||
def fn(p, x, y):
|
||||
|
@ -1,5 +1,5 @@
|
||||
// @generated by update_schema.py
|
||||
// checksum<<e7f100132ac684ccc67fce91b241821062f1dfe496fdff4b9929aba4ac938b4f>>
|
||||
// checksum<<00d94226d15b290b97bd49f9ff12bbfe04b7252c75d2d1bae66d1756fd9b8517>>
|
||||
|
||||
namespace py3 torch._export
|
||||
namespace cpp2 torch._export.schema
|
||||
@ -254,6 +254,11 @@ struct BufferMutationSpec {
|
||||
20: string buffer_name;
|
||||
}
|
||||
|
||||
struct ParameterMutationSpec {
|
||||
10: TensorArgument arg;
|
||||
20: string parameter_name;
|
||||
}
|
||||
|
||||
struct GradientToParameterSpec {
|
||||
10: TensorArgument arg;
|
||||
20: string parameter_name;
|
||||
@ -281,6 +286,7 @@ union OutputSpec {
|
||||
50: GradientToUserInputSpec gradient_to_user_input;
|
||||
60: UserInputMutationSpec user_input_mutation;
|
||||
70: OutputTokenSpec token;
|
||||
80: ParameterMutationSpec parameter_mutation;
|
||||
}
|
||||
|
||||
struct GraphSignature {
|
||||
|
@ -327,6 +327,12 @@ class BufferMutationSpec:
|
||||
buffer_name: Annotated[str, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParameterMutationSpec:
|
||||
arg: Annotated[TensorArgument, 10]
|
||||
parameter_name: Annotated[str, 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GradientToParameterSpec:
|
||||
arg: Annotated[TensorArgument, 10]
|
||||
@ -359,6 +365,7 @@ class OutputSpec(_Union):
|
||||
gradient_to_user_input: Annotated[GradientToUserInputSpec, 50]
|
||||
user_input_mutation: Annotated[UserInputMutationSpec, 60]
|
||||
token: Annotated[OutputTokenSpec, 70]
|
||||
parameter_mutation: Annotated[ParameterMutationSpec, 80]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -1,5 +1,5 @@
|
||||
# @generated by update_schema.py
|
||||
# checksum<<afe0cc0f99e72d00aa05f1a94da938ecb619aabc5d131d3ade489b57799f1e5a>>
|
||||
# checksum<<face83b52f81c45eeaeccc97cee19e146b3f7416ed91e015b4510ada7549a72f>>
|
||||
AOTInductorModelPickleData:
|
||||
kind: struct
|
||||
fields:
|
||||
@ -383,11 +383,20 @@ OutputSpec:
|
||||
type: UserInputMutationSpec
|
||||
token:
|
||||
type: OutputTokenSpec
|
||||
parameter_mutation:
|
||||
type: ParameterMutationSpec
|
||||
OutputTokenSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: TokenArgument
|
||||
ParameterMutationSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
arg:
|
||||
type: TensorArgument
|
||||
parameter_name:
|
||||
type: str
|
||||
RangeConstraint:
|
||||
kind: struct
|
||||
fields:
|
||||
|
@ -69,6 +69,7 @@ from .schema import ( # type: ignore[attr-defined]
|
||||
OptionalTensorArgument,
|
||||
OutputSpec,
|
||||
OutputTokenSpec,
|
||||
ParameterMutationSpec,
|
||||
RangeConstraint,
|
||||
ScalarType,
|
||||
SCHEMA_VERSION,
|
||||
@ -1241,6 +1242,15 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
buffer_name=spec.target,
|
||||
)
|
||||
)
|
||||
elif spec.kind == ep.OutputKind.PARAMETER_MUTATION:
|
||||
assert spec.target is not None
|
||||
assert isinstance(spec.arg, ep.TensorArgument)
|
||||
return OutputSpec.create(
|
||||
parameter_mutation=ParameterMutationSpec(
|
||||
arg=TensorArgument(name=spec.arg.name),
|
||||
parameter_name=spec.target,
|
||||
)
|
||||
)
|
||||
elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER:
|
||||
assert spec.target is not None
|
||||
assert isinstance(spec.arg, ep.TensorArgument)
|
||||
@ -2199,6 +2209,12 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
arg=ep.TensorArgument(name=o.buffer_mutation.arg.name),
|
||||
target=o.buffer_mutation.buffer_name,
|
||||
)
|
||||
elif o.type == "parameter_mutation":
|
||||
return ep.OutputSpec(
|
||||
kind=ep.OutputKind.PARAMETER_MUTATION,
|
||||
arg=ep.TensorArgument(name=o.parameter_mutation.arg.name),
|
||||
target=o.parameter_mutation.parameter_name,
|
||||
)
|
||||
elif o.type == "gradient_to_parameter":
|
||||
return ep.OutputSpec(
|
||||
kind=ep.OutputKind.GRADIENT_TO_PARAMETER,
|
||||
@ -3377,17 +3393,19 @@ def canonicalize(
|
||||
idx, (_arg, spec) = out
|
||||
assert isinstance(spec, OutputSpec)
|
||||
if spec.type == "user_output":
|
||||
return 3, None, idx
|
||||
return 4, None, idx
|
||||
elif spec.type == "loss_output":
|
||||
return 3, None, idx
|
||||
return 4, None, idx
|
||||
elif spec.type == "parameter_mutation":
|
||||
return 1, spec.parameter_mutation.parameter_name, idx
|
||||
elif spec.type == "buffer_mutation":
|
||||
return 1, spec.buffer_mutation.buffer_name, idx
|
||||
return 2, spec.buffer_mutation.buffer_name, idx
|
||||
elif spec.type == "gradient_to_parameter":
|
||||
return 4, spec.gradient_to_parameter.parameter_name, idx
|
||||
return 5, spec.gradient_to_parameter.parameter_name, idx
|
||||
elif spec.type == "gradient_to_user_input":
|
||||
return 5, None, idx
|
||||
return 6, None, idx
|
||||
elif spec.type == "user_input_mutation":
|
||||
return 2, None, idx
|
||||
return 3, None, idx
|
||||
elif spec.type == "token":
|
||||
return 0, None, idx
|
||||
else:
|
||||
@ -3500,6 +3518,9 @@ def canonicalize(
|
||||
elif spec.type == "buffer_mutation":
|
||||
t = spec.buffer_mutation.arg
|
||||
t.name = replace_table[t.name]
|
||||
elif spec.type == "parameter_mutation":
|
||||
t = spec.parameter_mutation.arg
|
||||
t.name = replace_table[t.name]
|
||||
elif spec.type == "gradient_to_parameter":
|
||||
t = spec.gradient_to_parameter.arg
|
||||
t.name = replace_table[t.name]
|
||||
|
@ -463,7 +463,12 @@ def _verify_exported_program_signature(exported_program) -> None:
|
||||
)
|
||||
|
||||
num_tokens = len(gs.output_tokens)
|
||||
end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
|
||||
end = (
|
||||
len(gs.buffers_to_mutate)
|
||||
+ len(gs.parameters_to_mutate)
|
||||
+ len(gs.user_inputs_to_mutate)
|
||||
+ num_tokens
|
||||
)
|
||||
mutate_nodes: list[str] = output_nodes[num_tokens:end]
|
||||
user_output_nodes = output_nodes[end : end + len(gs.user_outputs)]
|
||||
|
||||
@ -475,6 +480,13 @@ def _verify_exported_program_signature(exported_program) -> None:
|
||||
f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
|
||||
f"Buffer nodes available: {gs.buffers} \n"
|
||||
)
|
||||
elif mutation_node in gs.parameters_to_mutate:
|
||||
if gs.parameters_to_mutate[mutation_node] not in gs.parameters:
|
||||
raise SpecViolationError(
|
||||
f"Parameter output {mutation_node} does not point to a parameter that exists. \n"
|
||||
f"Dict of parameters that are mutated, in order: {gs.parameters_to_mutate} \n"
|
||||
f"Parameter nodes available: {gs.parameters} \n"
|
||||
)
|
||||
elif mutation_node in gs.user_inputs_to_mutate:
|
||||
if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
|
||||
raise SpecViolationError(
|
||||
|
@ -460,6 +460,7 @@ def create_graph_signature(
|
||||
named_buffers=buffer_names,
|
||||
num_user_inputs=num_user_args,
|
||||
num_user_outputs=num_user_fw_outs,
|
||||
trace_joint=trace_joint,
|
||||
loss_index=loss_index,
|
||||
backward_signature=backward_signature,
|
||||
)
|
||||
|
@ -829,6 +829,7 @@ class GraphSignature:
|
||||
# "graph outputs that correspond to updated buffers"
|
||||
# to the FQN names of those mutated buffers.
|
||||
buffers_to_mutate: dict[GraphOutputName, FQN]
|
||||
parameters_to_mutate: dict[GraphOutputName, FQN]
|
||||
user_inputs_to_mutate: dict[GraphOutputName, GraphInputName]
|
||||
|
||||
in_spec: pytree.TreeSpec
|
||||
@ -852,6 +853,7 @@ class GraphSignature:
|
||||
named_buffers: list[str],
|
||||
num_user_inputs: int,
|
||||
num_user_outputs: int,
|
||||
trace_joint: bool,
|
||||
loss_index: Optional[int],
|
||||
backward_signature: Optional[BackwardSignature],
|
||||
) -> GraphSignature:
|
||||
@ -897,8 +899,9 @@ class GraphSignature:
|
||||
mutations = []
|
||||
for idx, input_info in enumerate(view_mutation_metadata.input_info):
|
||||
if input_info.mutates_data:
|
||||
# Only buffers can be mutated, not parameters
|
||||
assert idx >= len(parameters)
|
||||
if trace_joint:
|
||||
# Only buffers can be mutated, not parameters
|
||||
assert idx >= len(parameters)
|
||||
mutations.append(names[idx + num_tokens])
|
||||
|
||||
assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices
|
||||
@ -911,12 +914,16 @@ class GraphSignature:
|
||||
|
||||
user_inputs_to_mutate = {}
|
||||
buffers_to_mutate = {}
|
||||
parameters_to_mutate = {}
|
||||
for output_name, mutation_name in outputs_to_mutations.items():
|
||||
if mutation_name in user_inputs:
|
||||
user_inputs_to_mutate[output_name] = mutation_name
|
||||
else:
|
||||
assert mutation_name in buffers
|
||||
buffers_to_mutate[output_name] = mutation_name
|
||||
assert mutation_name in buffers or mutation_name in parameters
|
||||
if mutation_name in buffers:
|
||||
buffers_to_mutate[output_name] = mutation_name
|
||||
else:
|
||||
parameters_to_mutate[output_name] = mutation_name
|
||||
|
||||
start, stop = stop, stop + num_user_outputs
|
||||
user_outputs = graph_outputs[start:stop]
|
||||
@ -937,6 +944,7 @@ class GraphSignature:
|
||||
inputs_to_parameters=inputs_to_parameters, # type: ignore[arg-type]
|
||||
user_inputs_to_mutate=user_inputs_to_mutate,
|
||||
buffers_to_mutate=buffers_to_mutate, # type: ignore[arg-type]
|
||||
parameters_to_mutate=parameters_to_mutate, # type: ignore[arg-type]
|
||||
in_spec=in_spec,
|
||||
out_spec=out_spec,
|
||||
backward_signature=backward_signature,
|
||||
@ -983,6 +991,9 @@ class AOTConfig:
|
||||
ignore_shape_env: bool = False
|
||||
precompile_backend_id: Optional[str] = None
|
||||
force_non_lazy_backward_lowering: bool = False
|
||||
# This config makes sure to check certain things like
|
||||
# mutating input with req_grad in export joint tracing.
|
||||
export_trace_joint: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.pre_dispatch:
|
||||
|
@ -672,6 +672,7 @@ fw_metadata={str(fw_metadata)}"""
|
||||
]
|
||||
)
|
||||
!= 0
|
||||
and aot_config.export_trace_joint
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"""\
|
||||
@ -1448,6 +1449,7 @@ We require the output marked as the loss (at index {output_loss_index}) to be a
|
||||
no_tangents=True,
|
||||
pre_dispatch=pre_dispatch,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
trace_joint=trace_joint,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
@ -1550,6 +1552,7 @@ def aot_export_joint_simple(
|
||||
func,
|
||||
args,
|
||||
decompositions=decompositions,
|
||||
trace_joint=trace_joint,
|
||||
)
|
||||
in_spec, _kw_in_spec = in_spec.children_specs
|
||||
# At this point, we can just directly return the (joint or inference graph) that we traced.
|
||||
@ -1631,6 +1634,8 @@ def _aot_export_function(
|
||||
# If None, `dynamic_shapes` will be inferred from inputs, but the inferred result might be wrong.
|
||||
dynamic_shapes: Optional[bool] = None,
|
||||
keep_input_mutations: bool = False,
|
||||
# Under export, configures whether we are getting inference or training IR
|
||||
trace_joint: bool = False,
|
||||
kwargs=None,
|
||||
) -> tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
|
||||
kwargs = kwargs or {}
|
||||
@ -1675,6 +1680,7 @@ def _aot_export_function(
|
||||
is_export=True,
|
||||
no_tangents=no_tangents,
|
||||
pre_dispatch=pre_dispatch,
|
||||
export_trace_joint=trace_joint,
|
||||
)
|
||||
if fake_mode is None:
|
||||
fake_mode, shape_env = construct_fake_mode(flat_args, aot_config)
|
||||
|
65
torch/csrc/utils/generated_serialization_types.h
generated
65
torch/csrc/utils/generated_serialization_types.h
generated
@ -1,5 +1,5 @@
|
||||
// @generated by update_schema.py
|
||||
// checksum<<afe0cc0f99e72d00aa05f1a94da938ecb619aabc5d131d3ade489b57799f1e5a>>
|
||||
// checksum<<face83b52f81c45eeaeccc97cee19e146b3f7416ed91e015b4510ada7549a72f>>
|
||||
// clang-format off
|
||||
|
||||
#pragma once
|
||||
@ -158,6 +158,7 @@ class Node;
|
||||
class OptionalTensorArgument;
|
||||
class OutputSpec;
|
||||
class OutputTokenSpec;
|
||||
class ParameterMutationSpec;
|
||||
class RangeConstraint;
|
||||
class SchemaVersion;
|
||||
class SymBool;
|
||||
@ -2494,6 +2495,33 @@ class BufferMutationSpec {
|
||||
friend void from_json(const nlohmann::json& nlohmann_json_j, BufferMutationSpec& nlohmann_json_t);
|
||||
};
|
||||
|
||||
class ParameterMutationSpec {
|
||||
private:
|
||||
TensorArgument arg;
|
||||
std::string parameter_name;
|
||||
|
||||
public:
|
||||
|
||||
const TensorArgument& get_arg() const {
|
||||
return arg;
|
||||
}
|
||||
|
||||
void set_arg(TensorArgument def) {
|
||||
arg = std::move(def);
|
||||
}
|
||||
|
||||
const std::string& get_parameter_name() const {
|
||||
return parameter_name;
|
||||
}
|
||||
|
||||
void set_parameter_name(std::string def) {
|
||||
parameter_name = std::move(def);
|
||||
}
|
||||
|
||||
friend void to_json(nlohmann::json& nlohmann_json_j, const ParameterMutationSpec& nlohmann_json_t);
|
||||
friend void from_json(const nlohmann::json& nlohmann_json_j, ParameterMutationSpec& nlohmann_json_t);
|
||||
};
|
||||
|
||||
class GradientToParameterSpec {
|
||||
private:
|
||||
TensorArgument arg;
|
||||
@ -2598,11 +2626,11 @@ class OutputSpec {
|
||||
|
||||
public:
|
||||
enum class Tag {
|
||||
USER_OUTPUT, LOSS_OUTPUT, BUFFER_MUTATION, GRADIENT_TO_PARAMETER, GRADIENT_TO_USER_INPUT, USER_INPUT_MUTATION, TOKEN
|
||||
USER_OUTPUT, LOSS_OUTPUT, BUFFER_MUTATION, GRADIENT_TO_PARAMETER, GRADIENT_TO_USER_INPUT, USER_INPUT_MUTATION, TOKEN, PARAMETER_MUTATION
|
||||
};
|
||||
|
||||
private:
|
||||
std::variant<Void, UserOutputSpec, LossOutputSpec, BufferMutationSpec, GradientToParameterSpec, GradientToUserInputSpec, UserInputMutationSpec, OutputTokenSpec> variant_;
|
||||
std::variant<Void, UserOutputSpec, LossOutputSpec, BufferMutationSpec, GradientToParameterSpec, GradientToUserInputSpec, UserInputMutationSpec, OutputTokenSpec, ParameterMutationSpec> variant_;
|
||||
Tag tag_;
|
||||
|
||||
public:
|
||||
@ -2673,6 +2701,15 @@ class OutputSpec {
|
||||
tag_ = Tag::TOKEN;
|
||||
}
|
||||
|
||||
const ParameterMutationSpec& get_parameter_mutation() const {
|
||||
return std::get<8>(variant_);
|
||||
}
|
||||
|
||||
void set_parameter_mutation(ParameterMutationSpec def) {
|
||||
variant_.emplace<8>(std::move(def));
|
||||
tag_ = Tag::PARAMETER_MUTATION;
|
||||
}
|
||||
|
||||
friend void to_json(nlohmann::json& nlohmann_json_j, const OutputSpec& nlohmann_json_t) {
|
||||
|
||||
if (nlohmann_json_t.tag_ == Tag::USER_OUTPUT) {
|
||||
@ -2703,6 +2740,10 @@ class OutputSpec {
|
||||
nlohmann_json_j["token"] = nlohmann_json_t.get_token();
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_t.tag_ == Tag::PARAMETER_MUTATION) {
|
||||
nlohmann_json_j["parameter_mutation"] = nlohmann_json_t.get_parameter_mutation();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
friend void from_json(const nlohmann::json& nlohmann_json_j, OutputSpec& nlohmann_json_t) {
|
||||
@ -2742,6 +2783,11 @@ class OutputSpec {
|
||||
nlohmann_json_t.tag_ = Tag::TOKEN;
|
||||
return;
|
||||
}
|
||||
if (nlohmann_json_j.contains("parameter_mutation")) {
|
||||
nlohmann_json_t.variant_.emplace<8>(nlohmann_json_j.at("parameter_mutation").template get<ParameterMutationSpec>());
|
||||
nlohmann_json_t.tag_ = Tag::PARAMETER_MUTATION;
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -2754,6 +2800,7 @@ inline std::string_view printEnum(const OutputSpec::Tag& e) {
|
||||
case OutputSpec::Tag::GRADIENT_TO_USER_INPUT: return "GRADIENT_TO_USER_INPUT";
|
||||
case OutputSpec::Tag::USER_INPUT_MUTATION: return "USER_INPUT_MUTATION";
|
||||
case OutputSpec::Tag::TOKEN: return "TOKEN";
|
||||
case OutputSpec::Tag::PARAMETER_MUTATION: return "PARAMETER_MUTATION";
|
||||
default:
|
||||
throw std::runtime_error("Unknown enum value");
|
||||
}
|
||||
@ -2767,6 +2814,7 @@ inline void parseEnum(std::string_view s, OutputSpec::Tag& t) {
|
||||
if (s == "GRADIENT_TO_USER_INPUT") { t = OutputSpec::Tag::GRADIENT_TO_USER_INPUT; return; }
|
||||
if (s == "USER_INPUT_MUTATION") { t = OutputSpec::Tag::USER_INPUT_MUTATION; return; }
|
||||
if (s == "TOKEN") { t = OutputSpec::Tag::TOKEN; return; }
|
||||
if (s == "PARAMETER_MUTATION") { t = OutputSpec::Tag::PARAMETER_MUTATION; return; }
|
||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
||||
}
|
||||
|
||||
@ -3575,6 +3623,17 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, OutputTokenSpec& nl
|
||||
nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg);
|
||||
}
|
||||
|
||||
inline void to_json(nlohmann::json& nlohmann_json_j, const ParameterMutationSpec& nlohmann_json_t) {
|
||||
nlohmann_json_j["arg"] = nlohmann_json_t.arg;
|
||||
nlohmann_json_j["parameter_name"] = nlohmann_json_t.parameter_name;
|
||||
}
|
||||
|
||||
inline void from_json(const nlohmann::json& nlohmann_json_j, ParameterMutationSpec& nlohmann_json_t) {
|
||||
ParameterMutationSpec nlohmann_json_default_obj;
|
||||
nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg);
|
||||
nlohmann_json_t.parameter_name = nlohmann_json_j.value("parameter_name", nlohmann_json_default_obj.parameter_name);
|
||||
}
|
||||
|
||||
inline void to_json(nlohmann::json& nlohmann_json_j, const RangeConstraint& nlohmann_json_t) {
|
||||
nlohmann_json_j["min_val"] = nlohmann_json_t.min_val;
|
||||
nlohmann_json_j["max_val"] = nlohmann_json_t.max_val;
|
||||
|
@ -1674,7 +1674,24 @@ def _export_to_aten_ir_make_fx(
|
||||
for k, (old_getattr, _) in tensor_type_to_old_getattribute.items():
|
||||
k.__getattribute__ = old_getattr # type: ignore[method-assign, attr-defined]
|
||||
|
||||
with ctx, override_getattribute_for_subclasses(flat_args):
|
||||
@contextmanager
|
||||
def _maybe_restore_grad_state():
|
||||
"""
|
||||
When pre-dispatch export accidentally change grad state, we restore it back.
|
||||
This can happen when we are calling torch._C._set_grad_enabled directly in the
|
||||
forward.
|
||||
"""
|
||||
old_state = torch.is_grad_enabled()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._C._set_grad_enabled(old_state)
|
||||
|
||||
with (
|
||||
ctx,
|
||||
override_getattribute_for_subclasses(flat_args),
|
||||
_maybe_restore_grad_state(),
|
||||
):
|
||||
gm = make_fx(
|
||||
wrapped_fn,
|
||||
record_module_stack=True,
|
||||
@ -1738,6 +1755,7 @@ def _export_to_aten_ir_make_fx(
|
||||
zip(input_names[param_len : param_len + buffer_len], named_buffers)
|
||||
),
|
||||
buffers_to_mutate={},
|
||||
parameters_to_mutate={},
|
||||
user_inputs_to_mutate={},
|
||||
in_spec=in_spec,
|
||||
out_spec=out_spec.spec,
|
||||
@ -1900,6 +1918,9 @@ def _non_strict_export(
|
||||
_strip_root, sig.inputs_to_parameters
|
||||
)
|
||||
sig.buffers_to_mutate = pytree.tree_map(_strip_root, sig.buffers_to_mutate)
|
||||
sig.parameters_to_mutate = pytree.tree_map(
|
||||
_strip_root, sig.parameters_to_mutate
|
||||
)
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if "nn_module_stack" in node.meta:
|
||||
|
@ -447,7 +447,11 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.fx.Grap
|
||||
(
|
||||
out_spec.target
|
||||
if out_spec.kind
|
||||
in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
|
||||
in (
|
||||
OutputKind.BUFFER_MUTATION,
|
||||
OutputKind.USER_INPUT_MUTATION,
|
||||
OutputKind.PARAMETER_MUTATION,
|
||||
)
|
||||
else None
|
||||
)
|
||||
for out_spec in ep.graph_signature.output_specs
|
||||
|
@ -792,9 +792,9 @@ def _remove_unneccessary_copy_op_pass(
|
||||
if node.op == "output":
|
||||
args, _ = pytree.tree_flatten(node.args)
|
||||
for out in args:
|
||||
if (
|
||||
isinstance(out, torch.fx.Node)
|
||||
and out.name in new_graph_signature.buffers_to_mutate
|
||||
if isinstance(out, torch.fx.Node) and (
|
||||
out.name in new_graph_signature.buffers_to_mutate
|
||||
or out.name in new_graph_signature.parameters_to_mutate
|
||||
):
|
||||
if (
|
||||
out.op == "call_function"
|
||||
|
@ -121,6 +121,7 @@ class OutputKind(Enum):
|
||||
USER_OUTPUT = auto()
|
||||
LOSS_OUTPUT = auto()
|
||||
BUFFER_MUTATION = auto()
|
||||
PARAMETER_MUTATION = auto()
|
||||
GRADIENT_TO_PARAMETER = auto()
|
||||
GRADIENT_TO_USER_INPUT = auto()
|
||||
USER_INPUT_MUTATION = auto()
|
||||
@ -406,6 +407,16 @@ class ExportGraphSignature:
|
||||
and isinstance(s.target, str)
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters_to_mutate(self) -> Mapping[str, str]:
|
||||
return _immutable_dict(
|
||||
(s.arg.name, s.target)
|
||||
for s in self.output_specs
|
||||
if s.kind == OutputKind.PARAMETER_MUTATION
|
||||
and isinstance(s.arg, TensorArgument)
|
||||
and isinstance(s.target, str)
|
||||
)
|
||||
|
||||
@property
|
||||
def user_inputs_to_mutate(self) -> Mapping[str, str]:
|
||||
return _immutable_dict(
|
||||
@ -601,6 +612,7 @@ def _convert_to_export_graph_signature(
|
||||
inputs_to_buffers = graph_signature.inputs_to_buffers
|
||||
user_outputs = set(graph_signature.user_outputs)
|
||||
buffer_mutations = graph_signature.buffers_to_mutate
|
||||
parameter_mutations = graph_signature.parameters_to_mutate
|
||||
user_input_mutations = graph_signature.user_inputs_to_mutate
|
||||
grad_params = (
|
||||
graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr]
|
||||
@ -662,13 +674,21 @@ def _convert_to_export_graph_signature(
|
||||
if not isinstance(o, TensorArgument):
|
||||
return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
|
||||
name = o.name
|
||||
if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
|
||||
if idx < len(buffer_mutations) + len(parameter_mutations) + len(
|
||||
user_input_mutations
|
||||
) + len(output_tokens):
|
||||
if name in buffer_mutations:
|
||||
return OutputSpec(
|
||||
kind=OutputKind.BUFFER_MUTATION,
|
||||
arg=o,
|
||||
target=buffer_mutations[name], # type: ignore[index]
|
||||
)
|
||||
elif name in parameter_mutations:
|
||||
return OutputSpec(
|
||||
kind=OutputKind.PARAMETER_MUTATION,
|
||||
arg=o,
|
||||
target=parameter_mutations[name], # type: ignore[index]
|
||||
)
|
||||
elif name in user_input_mutations:
|
||||
return OutputSpec(
|
||||
kind=OutputKind.USER_INPUT_MUTATION,
|
||||
|
@ -1430,9 +1430,11 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
|
||||
torch.amp.autocast_mode._exit_autocast,
|
||||
]:
|
||||
node.meta["val"] = None
|
||||
# For autocast, the python APIs run so we don't have to run them again
|
||||
# here.
|
||||
if func is torch._C._set_grad_enabled:
|
||||
func(*args, **kwargs)
|
||||
return node
|
||||
# Don't actually run the function! We just want to trace the calls
|
||||
# into a graph. We don't actually want to change global autograd state.
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user