Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[dynamo][hops] Remove const outputs from the speculated subgraph (#161355)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161355 Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: 9480cdc0b6
Commit: 6b1900c22f
@@ -2608,25 +2608,17 @@ class GraphModule(torch.nn.Module):
f, default_args_generator((x,)), arg_count, expected_opcount=3
)

def test_fallback_on_python_primitives_output(self):
def test_support_float_in_output(self):
counters.clear()
cnt = CompileCounter()

@torch.compile(backend=cnt)
@torch.compile(backend=cnt, fullgraph=True)
def f(x):
return wrap(lambda x: [1, torch.sin(x), 2.0], x)

x = torch.randn(3)
result = f(x)
self.assertEqual(result, [1, torch.sin(x), 2.0])
self.assertEqual(cnt.frame_count, 0)
assert_dict_matches_regex(
self,
dict(counters["graph_break"]),
{
".*HigherOrderOperator body's output must consist of tensors or ints only but got": 1
},
)

def test_nested_tuple_output(self):
def f(x):
@@ -1498,7 +1498,7 @@ class GraphModule(torch.nn.Module):
subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_); subgraph_0 = l_x_ = None
getitem: "f32[8, 8]" = invoke_subgraph[0]
getitem_1: "f32[8, 8]" = invoke_subgraph[2]; invoke_subgraph = None
getitem_1: "f32[8, 8]" = invoke_subgraph[1]; invoke_subgraph = None

add: "f32[8, 8]" = getitem + getitem_1; getitem = getitem_1 = None
return (add,)
@@ -1507,7 +1507,7 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]"):
child: "f32[8, 8]" = l_x_ * 2
child_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
return (child, None, child_1)
return (child, child_1)
""",
)

@@ -1520,16 +1520,16 @@ class GraphModule(torch.nn.Module):

invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1); partitioned_fw_subgraph_0_0 = primals_1 = None
getitem: "f32[8, 8]" = invoke_subgraph_2[0]
getitem_2: "f32[8, 8]" = invoke_subgraph_2[2]; invoke_subgraph_2 = None
getitem_1: "f32[8, 8]" = invoke_subgraph_2[1]; invoke_subgraph_2 = None

add: "f32[8, 8]" = torch.ops.aten.add.Tensor(getitem, getitem_2); getitem = getitem_2 = None
add: "f32[8, 8]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return (add,)

class partitioned_fw_subgraph_0_0(torch.nn.Module):
def forward(self, primals_0: "f32[8, 8]"):
mul: "f32[8, 8]" = torch.ops.aten.mul.Tensor(primals_0, 2)
mul_1: "f32[8, 8]" = torch.ops.aten.mul.Tensor(primals_0, 3); primals_0 = None
return (mul, None, mul_1)
return (mul, mul_1)
""",
)

@@ -1541,8 +1541,8 @@ class GraphModule(torch.nn.Module):
partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0

invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', tangents_1, tangents_1); partitioned_bw_subgraph_0_0 = tangents_1 = None
getitem_3: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
return (getitem_3,)
getitem_2: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
return (getitem_2,)

class partitioned_bw_subgraph_0_0(torch.nn.Module):
def forward(self, tangents_0: "f32[8, 8]", tangents_1: "f32[8, 8]"):
@@ -1888,6 +1888,37 @@ class GraphModule(torch.nn.Module):
""",
)

def test_return_size(self):
def run(dynamic):
torch.compiler.reset()

@nested_compile_region
def gn(x):
y = x + 1
z = x.shape
return y, z

def fn(x):
z0 = gn(x)
z1 = gn(x)
return z0[0] + z1[0], z0[1]

x = torch.randn(8, 8, requires_grad=True)
x_clone = x.detach().clone().requires_grad_(True)
ref = fn(x)
opt_fn = torch.compile(
fn, backend="inductor", fullgraph=True, dynamic=dynamic
)
res = opt_fn(x_clone)
self.assertEqual(ref, res)

ref[0].sum().backward()
res[0].sum().backward()
self.assertEqual(x.grad, x_clone.grad)

run(dynamic=True)
run(dynamic=False)

def test_different_symint(self):
"""
Tests check that the same subgraph called with different symints use different graphs
@@ -248,3 +248,33 @@ def wrap_inline_with_error_on_graph_break(
return fn(*args, **kwargs)

return wrapper


def filter_out_const_values(tup: tuple[Any, ...], masks: list[bool]) -> tuple[Any, ...]:
"""
masks is a list of bools, where True means the corresponding element in tup
is a const value. Filter out the const values.
"""
out = []
for mask_idx, mask in enumerate(masks):
if not mask:
out.append(tup[mask_idx])
return tuple(out)


def insert_const_values_with_mask(
tup: tuple[Any, ...], masks: list[bool], values: tuple[Any, ...]
) -> tuple[Any, ...]:
"""
masks and values are of same length. For indices where the mask is True, use
the const_values to fill in.
"""
out = []
idx = 0
for mask_idx, mask in enumerate(masks):
if mask:
out.append(values[mask_idx])
else:
out.append(tup[idx])
idx += 1
return tuple(out)
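Example (not part of the diff): a minimal sketch of how the two helpers added above could be exercised, assuming this patch is applied so that both functions are importable from torch._dynamo.external_utils. Strings stand in for the tensor proxies that normally occupy the non-constant slots.

    from torch._dynamo.external_utils import (
        filter_out_const_values,
        insert_const_values_with_mask,
    )

    # A subgraph output with constants at indices 1 and 2 (strings stand in for proxies).
    out = ("proxy_a", 1, 2.0, "proxy_b")
    masks = [False, True, True, False]  # True marks a constant slot
    consts = (None, 1, 2.0, None)       # constants kept at their original indices

    filtered = filter_out_const_values(out, masks)
    assert filtered == ("proxy_a", "proxy_b")  # constants dropped from the graph output

    restored = insert_const_values_with_mask(filtered, masks, consts)
    assert restored == out                     # constants re-inserted for Dynamo tracing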
@@ -28,7 +28,7 @@ import types
import warnings
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING

import torch._C
import torch.fx
@@ -74,10 +74,28 @@ hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile")
@dataclass
class OutputSpec:
"""
The treespec of the output of the speculated subgraph and other metadata.
Contains the treespec of the output of the speculated subgraph, and the
information to mask out the constant values from the output during
flattening and inserting them back during unflattening. Cleaning up
constants from the graph makes the graph simpler for AOTDispatcher and
Inductor.
"""

treespec: pytree.TreeSpec
# list of True/False to identify the locations of const values in the
# subgraph output. True means that value at that index is a constant.
masks_to_filter_const_values: Optional[list[bool]] = None
# The actual constant values that were present in the subgraph output. Note
# that this is the same length as the mask, we just look at the indices
# where mask is True.
const_values: Optional[list[Any]] = None

def __post_init__(self):
if (
self.masks_to_filter_const_values is not None
or self.const_values is not None
):
assert len(self.masks_to_filter_const_values) == len(self.const_values)


def raise_hard_error_if_graph_break(reason):
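Example (not part of the diff): a hedged sketch of how an OutputSpec might be populated for a subgraph whose flattened output is (tensor, 1, 2.0). The import path is assumed from the diff context (torch/_dynamo/variables/higher_order_ops.py); only the fields shown in the dataclass above are used.

    import torch.utils._pytree as pytree
    from torch._dynamo.variables.higher_order_ops import OutputSpec  # assumed module path

    # A treespec with the same structure as the flattened output: a 3-element list.
    _, treespec = pytree.tree_flatten([0, 0, 0])

    spec = OutputSpec(
        treespec=treespec,
        masks_to_filter_const_values=[False, True, True],  # slots 1 and 2 hold constants
        const_values=[None, 1, 2.0],                        # same length as the mask
    )
    assert len(spec.masks_to_filter_const_values) == len(spec.const_values)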
@@ -242,6 +260,15 @@ def _call_function_and_unflatten_output(
example_value=flat_example_value,
)

if ret_spec.masks_to_filter_const_values:
from torch._dynamo.external_utils import insert_const_values_with_mask

# During flattening, we removed the constant values. To ensure Dynamo
# can trace correctly, insert back the constant values in the output.
flat_variable = _make_inlined(tx, insert_const_values_with_mask)(
flat_variable, ret_spec.masks_to_filter_const_values, ret_spec.const_values
)

# Transform variable back into a list (previously made into a tuple by
# speculate_subgraph function) so as to respect the pytree API typing.
flat_list_variable = BuiltinVariable(list).call_function(tx, [flat_variable], {})
@@ -646,6 +673,9 @@ def speculate_subgraph(
set_subgraph_inputs="automatic",
restore_side_effects=True,
should_flatten_outputs=False,
# if should_flatten_outputs is True, `remove_consts_from_outputs` remove the
# const outputs from the subgraph output.
remove_consts_from_outputs=True,
under_activation_checkpoint=False,
# TODO - supports input_mutation and aliasing should be False by default for strictness
supports_input_mutation=True,
@@ -736,15 +766,38 @@ def speculate_subgraph(
tx.output.side_effects = prev_side_effects

treespec = None
masks_to_filter_const_values = None
const_values = None
if should_flatten_outputs:
from torch._dynamo.external_utils import filter_out_const_values

# Flatten the speculated subgraph output.
output, treespec = _make_inlined(tx, pytree.tree_flatten)(
output
).unpack_var_sequence(tx)

# Actually, transform the list (returned by flatten) into a tuple
# for dynamo consistency.
output = BuiltinVariable(tuple).call_function(tx, [output], {})

if remove_consts_from_outputs:
# Filter out the constants and save them into a spec. Filtering
# out constants makes the graph simpler for the backends. We
# need to ensure that after unflattening the constants are
# inserted back at the right positions for the Dynamo tracing to
# continue. This is done by filter_const_spec
output_proxies = output.as_proxy()
masks_to_filter_const_values = pytree.tree_map(
lambda x: not isinstance(x, torch.fx.Proxy), output_proxies
)
const_values = pytree.tree_map(
lambda x: None if isinstance(x, torch.fx.Proxy) else x,
output_proxies,
)
output = _make_inlined(tx, filter_out_const_values)(
output, masks_to_filter_const_values
)

# Register output to graph
# Modeled off of compile_and_call_fx_graph
# TODO: support pytree output
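Example (not part of the diff): a hedged sketch of the mask and const-value computation in the block above, with torch.Tensor standing in for torch.fx.Proxy (building real proxies requires an active tracer, which would obscure the point).

    import torch
    import torch.utils._pytree as pytree

    # Stand-in for output.as_proxy(): one "proxy" (here a tensor) plus two constants.
    output_proxies = (torch.randn(2), 1, 2.0)

    masks = pytree.tree_map(lambda x: not isinstance(x, torch.Tensor), output_proxies)
    const_values = pytree.tree_map(
        lambda x: None if isinstance(x, torch.Tensor) else x, output_proxies
    )

    assert masks == (False, True, True)        # constants sit at indices 1 and 2
    assert const_values == (None, 1, 2.0)      # constants preserved at their indices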
@@ -753,7 +806,12 @@ def speculate_subgraph(
if always_restore:
# Nothing left to do here
return (
(output, OutputSpec(treespec)),
(
output,
OutputSpec(
treespec, masks_to_filter_const_values, const_values
),
),
tx.output.graph,
subtracer.lifted_freevars,
)
@@ -872,7 +930,12 @@ def speculate_subgraph(
)

return (
(output, OutputSpec(treespec)),
(
output,
OutputSpec(
treespec, masks_to_filter_const_values, const_values
),
),
graph,
lifted_freevars,
)
@@ -1070,6 +1133,8 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
"cond",
source_target=self.value,
should_flatten_outputs=True,
# TODO - removing consts from control flow ops need more work
remove_consts_from_outputs=False,
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
)
@@ -1381,6 +1446,8 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
source_target=self.value,
set_subgraph_inputs="flatten_manual",
should_flatten_outputs=True,
# TODO - removing consts from control flow ops need more work
remove_consts_from_outputs=False,
supports_input_mutation=False,
supports_aliasing=False,
)
@@ -1926,6 +1993,8 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
source_target=self.value,
set_subgraph_inputs="flatten_manual",
should_flatten_outputs=True,
# TODO - removing consts from control flow ops need more work
remove_consts_from_outputs=False,
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
)
@@ -2479,8 +2548,6 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
from torch._higher_order_ops.wrap import TagActivationCheckpoint
from torch.utils.checkpoint import noop_context_fn

from .builder import wrap_fx_proxy

context_fn = None
if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn:
ctx = kwargs.pop("context_fn")
@@ -2520,27 +2587,15 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):

_, checkpoint_kwargs = proxy_args_kwargs([], checkpoint_kwargs)

# Store the invocation as a call
variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(p_args),
kwargs=checkpoint_kwargs,
),
example_value=example_value,
return _call_function_and_unflatten_output(
tx,
self.value,
p_args,
checkpoint_kwargs,
example_value,
out_spec,
)

if out_spec is None:
return variable

# Transform variable back into a list (previously made into a tuple by
# speculate_subgraph function) so as to respect the pytree API typing.
variable = BuiltinVariable(list).call_function(tx, [variable], {})

return _make_inlined(tx, pytree.tree_unflatten)(variable, out_spec.treespec)


class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
def __init__(self, hop, source) -> None:
@@ -2552,8 +2607,6 @@ class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
from .builder import wrap_fx_proxy

func_var = args[0]

if isinstance(func_var, torch._dynamo.variables.UserFunctionVariable):
@@ -2571,7 +2624,7 @@ class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
_,
example_value,
_body_r,
treespec,
out_spec,
gmod,
_,
) = self.create_wrapped_node(
@@ -2587,27 +2640,15 @@ class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
gmod_meta_key = "_dynamo_bypassing_wrapper_fn"
gmod.meta[gmod_meta_key] = func

# Store the invocation as a call
variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=(gmod_meta_key,) + tuple(p_args),
kwargs={},
),
example_value=example_value,
return _call_function_and_unflatten_output(
tx,
self.value,
(gmod_meta_key,) + tuple(p_args),
{},
example_value,
out_spec,
)

if treespec is None:
return variable

# Transform variable back into a list (previously made into a tuple by
# speculate_subgraph function) so as to respect the pytree API typing.
variable = BuiltinVariable(list).call_function(tx, [variable], {})

return _make_inlined(tx, pytree.tree_unflatten)(variable, treespec.treespec)


class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
@@ -45,7 +45,7 @@ invoke_subgraph_counter = 0
@dataclass
class OutputMetadata:
num_fw_outs: Optional[int] = None
indexes_with_none: set[int] = field(default_factory=set)
indexes_with_symint: set[int] = field(default_factory=set)
indexes_with_no_grad: set[int] = field(default_factory=set)


@@ -258,8 +258,8 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):

output_metadata.num_fw_outs = num_fw_outs
for idx, fw_out in enumerate(fw_outs):
if fw_out is None:
output_metadata.indexes_with_none.add(idx)
if isinstance(fw_out, torch.SymInt):
output_metadata.indexes_with_symint.add(idx)
elif not fw_out.requires_grad:
output_metadata.indexes_with_no_grad.add(idx)

@@ -331,8 +331,8 @@ def get_output_metadata(subgraph, *operands):

output_metadata.num_fw_outs = num_fw_outs
for idx, fw_out in enumerate(fw_outs):
if fw_out is None:
output_metadata.indexes_with_none.add(idx)
if isinstance(fw_out, torch.SymInt):
output_metadata.indexes_with_symint.add(idx)
elif not fw_out.requires_grad:
output_metadata.indexes_with_no_grad.add(idx)
return output_metadata
@@ -428,10 +428,10 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
*operands,
)

# Check that None is at expected indexes.
# Check that int (coming from symint) is at expected indexes.
for idx, o in enumerate(out):
if o is None:
assert idx in output_metadata.indexes_with_none
if isinstance(o, int):
assert idx in output_metadata.indexes_with_symint

return out

@@ -452,7 +452,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
filtered_grad_outs = []
for idx, o in enumerate(grad_outs):
if o is None:
assert idx in output_metadata.indexes_with_none
assert idx in output_metadata.indexes_with_symint
elif idx in output_metadata.indexes_with_no_grad:
# Deliberately skip over the grad_outs which we know should be
# None because the corresponding fwd_out does not require_grad.