Compare commits

...

1 Commit

Author SHA1 Message Date
ba81f47b13 Initial commit 2024-12-17 17:04:35 -08:00
8 changed files with 78 additions and 53 deletions

View File

@@ -84,7 +84,7 @@ assume_static_by_default = True
 # with assume_static_by_default=True.
 # With this flag enabled, we always compile a frame as fully static for the first time, and, if we fail
 # any guards due to wobbles in shape, we recompile with *all* the wobbled shapes as being marked dynamic.
-automatic_dynamic_shapes = True
+automatic_dynamic_shapes = False

 # Valid options: "dynamic", "unbacked"
 automatic_dynamic_shapes_mark_as = "dynamic"
@@ -398,7 +398,7 @@ enable_cpp_framelocals_guard_eval = True
 # Whether to automatically find and replace identical graph
 # regions with a call to invoke_subgraph
-use_graph_deduplication = False
+use_graph_deduplication = True

 # Whether to track nodes for deduplication (testing only)
 # This flag is ignored if use_graph_deduplication is True
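
Both toggles in this file are plain module-level values in Dynamo's config, so the same experiment can be reproduced per run without patching the source tree. A minimal sketch, assuming only the two flags shown above:

    import torch
    import torch._dynamo.config as dynamo_config

    # Mirror this commit's two config edits at runtime.
    dynamo_config.automatic_dynamic_shapes = False
    dynamo_config.use_graph_deduplication = True

    @torch.compile
    def f(x):
        return x.sin() + x.cos()

    f(torch.randn(8))  # compiles with dedup on, automatic dynamic shapes off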

View File

@@ -3,7 +3,6 @@ import operator
 from typing import Any, Dict, Iterable, List, Set, Tuple

 import torch.fx
-from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
 from torch.utils._pytree import tree_flatten

 from .graph_region_tracker import Node, Region
@@ -76,6 +75,25 @@ when they are created in output_graph.
     return output_replacements


+def _flatten_args_kwargs(args) -> List[Node]:
+    fully_flattened = []
+
+    def flatten(args):
+        flattened, _ = tree_flatten(args)
+        for arg in flattened:
+            if isinstance(arg, slice):
+                start = arg.start
+                stop = arg.stop
+                step = arg.step
+                flatten((start, stop, step))
+            else:
+                fully_flattened.append(arg)
+
+    flatten(args)
+    return fully_flattened
+
+
 def _replace_region_with_subgraph(
     graph: torch.fx.Graph,
     region: Region,
@@ -89,18 +107,20 @@ def _replace_region_with_subgraph(
     sub_args = []
     for node_ind, arg_ind in node_ind_arg_ind:
         node = region[node_ind]
-        flattened_args_kwargs, _ = tree_flatten((node.args, node.kwargs))
+        flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
         sub_args.append(flattened_args_kwargs[arg_ind])

     invoke_args = (get_subgraph_node, subgraph_name, tuple(sub_args))
     fake_inputs = [node.meta["example_value"] for node in sub_args]

-    if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
-        log.debug(
-            "NYI: Failed to substitute region %s due to input alias or mutation",
-            region,
-        )
-        return
+    # print(fake_inputs)
+    # breakpoint()
+    # if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
+    #     log.debug(
+    #         "NYI: Failed to substitute region %s due to input alias or mutation",
+    #         region,
+    #     )
+    #     return

     latest_region_node = region[-1]
     with graph.inserting_after(latest_region_node):
@@ -127,12 +147,15 @@ def _get_external_inputs(
     external_node_to_indices = dict()
     region_unique = set(region)
     for node_ind, node in enumerate(region):
-        flattened_args_kwargs, _ = tree_flatten((node.args, node.kwargs))
+        # if node.name == "rope_cache_4":
+        #     breakpoint()
+        flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
         for arg_ind, in_node in enumerate(flattened_args_kwargs):
             if (
-                in_node not in region_unique
+                isinstance(in_node, Node)
+                and in_node not in region_unique
                 and in_node not in external_node_to_indices
-                and isinstance(in_node, Node)
             ):
                 external_node_to_indices[in_node] = (node_ind, arg_ind)
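
The new _flatten_args_kwargs helper exists because pytree flattening treats slice objects as leaves, so graph Nodes buried in slice bounds never surfaced when collecting a region's external inputs. A quick standalone illustration, with strings standing in for fx Nodes (purely illustrative data, and it assumes the function from the hunk above is in scope):

    from torch.utils._pytree import tree_flatten

    args = ("a", slice("b", "c", None), ["d"])

    flat, _ = tree_flatten(args)
    print(flat)  # ['a', slice('b', 'c', None), 'd'] -- the slice stays opaque

    print(_flatten_args_kwargs(args))
    # ['a', 'b', 'c', None, 'd'] -- slice bounds are recursed into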

View File

@@ -258,10 +258,13 @@ class GraphRegionTracker:
             region_group = []
             min_rank = math.inf
             for node in group:
-                min_rank = min(min_rank, topological_ranking[node])
-                region_group.append([node])
+                # some nodes aren't in the topo ranking?
+                if node in topological_ranking:
+                    min_rank = min(min_rank, topological_ranking[node])
+                    region_group.append([node])

-            region_groups_with_rank.append((region_group, min_rank))
+            if len(region_group) > 1:
+                region_groups_with_rank.append((region_group, min_rank))

         region_groups_with_rank.sort(key=lambda rg: -rg[1])
         region_groups = [rg for rg, _ in region_groups_with_rank]
@@ -273,6 +276,10 @@ class GraphRegionTracker:
         for region_group in region_groups:
             fully_expand_region_group(region_group, seen_nodes, self._is_identical)

+        for rg in region_groups:
+            for r in rg:
+                r.sort(key=lambda n: topological_ranking[n])
+
         return [
             region_group for region_group in region_groups if len(region_group[0]) > 1
         ]
@@ -305,7 +312,7 @@ def fully_expand_region_group(
         region_it.add_children(node)

     current_node = region_iters[0].next()
     assert current_node is not None
     # Loop incrementally adding new nodes to each region
     # regions are only expanded if the node to add is valid
     # for ALL regions
@@ -325,6 +332,8 @@ def fully_expand_region_group(
             debug_log("is_identical: %s", is_identical_fn(node, current_node))
             add_node &= (
                 node not in seen_nodes
+                and node.op != "placeholder"
+                # and node.target != "size"
                 and node not in nodes_to_add_set
                 and is_identical_fn(node, current_node)
            )
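
The net effect of these hunks: candidate groups are ranked by their minimum topological position, nodes missing from the ranking are tolerated, singleton groups are discarded early, and each surviving region is re-sorted into topological order. A toy sketch of that ranking-and-filtering shape, with plain strings standing in for fx Nodes (hypothetical data, not the tracker's):

    import math

    topological_ranking = {"add": 0, "mul": 1, "sin": 2, "cos": 3}
    groups = [["sin", "add"], ["cos"], ["mul", "ghost"]]  # "ghost" has no rank

    region_groups_with_rank = []
    for group in groups:
        region_group, min_rank = [], math.inf
        for node in group:
            if node in topological_ranking:  # tolerate unranked nodes
                min_rank = min(min_rank, topological_ranking[node])
                region_group.append([node])
        if len(region_group) > 1:  # drop singleton groups
            region_groups_with_rank.append((region_group, min_rank))

    region_groups_with_rank.sort(key=lambda rg: -rg[1])
    print(region_groups_with_rank)  # [([['sin'], ['add']], 0)]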

View File

@@ -99,7 +99,6 @@ from .utils import (
     get_static_address_type,
     graph_break_reasons,
     increment_op_count,
-    lazy_format_graph_code,
     LazyString,
     nn_module_proxy,
     same,
@@ -1350,12 +1349,6 @@ class OutputGraph:
         ] = self.dynamo_flat_name_to_original_fqn.copy()
         gm.meta["dynamo_compile_id"] = self.dynamo_compile_id

-        graph_code_log.debug(
-            "%s",
-            lazy_format_graph_code(
-                name, gm, include_stride=True, include_device=True, colored=True
-            ),
-        )
         torch._logging.trace_structured(
             "dynamo_output_graph",
             lambda: {"sizes": self.get_graph_sizes_structured()},
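
This hunk only deletes the debug dump of the compiled graph; the structured trace below it is untouched. If that visibility is still wanted while testing, the artifact logger can be switched on from outside the patched file, e.g. (assuming the standard "graph_code" artifact name):

    import torch._logging

    # Re-enable compiled-graph dumps without the deleted call,
    # equivalent to running with TORCH_LOGS="graph_code".
    torch._logging.set_logs(graph_code=True)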

View File

@@ -148,15 +148,15 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
             # Infer grad_outputs to be the same properties as the fw_outputs
             # if they're not passed in.
             grad_outputs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))
-            if any(
-                not isinstance(out, torch.Tensor)
-                for out in grad_outputs
-                if out is not None
-            ):
-                raise RuntimeError(
-                    "Expect outputs of invoke_subgraph to only contains tensors or None. "
-                    f"Got types {[type(out) for out in grad_outputs]}."
-                )
+            # if any(
+            #     not isinstance(out, torch.Tensor)
+            #     for out in grad_outputs
+            #     if out is not None
+            # ):
+            #     raise RuntimeError(
+            #         "Expect outputs of invoke_subgraph to only contains tensors or None. "
+            #         f"Got types {[type(out) for out in grad_outputs]}."
+            #     )

             # Trace the forward subgraph
             fw_graph = _maybe_reenter_make_fx(subgraph)(*fw_inputs)
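
For reference, the disabled guard flags any output that is neither a Tensor nor None; commenting it out lets subgraphs with such outputs proceed during this experiment instead of raising. The predicate restated standalone, on made-up sample data:

    import torch

    grad_outputs = (torch.randn(2), None, 3)  # the int is the offending case
    bad = any(
        not isinstance(out, torch.Tensor)
        for out in grad_outputs
        if out is not None
    )
    print(bad)  # True -- before this change, this would raise RuntimeError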

View File

@@ -91,7 +91,7 @@ autotune_remote_cache: Optional[bool] = autotune_remote_cache_default()
 bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default()

 # Force disabled all inductor level caching -- This will override any other caching flag
-force_disable_caches = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1"
+force_disable_caches = True

 # sleep in inductor for testing
 sleep_sec_TESTING_ONLY: Optional[int] = None
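
Hardcoding True here drops the environment-variable gate entirely. For a one-off run, the unpatched file already offers the same switch via the env var read on the replaced line, or the config value can be flipped at runtime; a sketch:

    # Same effect as the patch, without editing the source tree.
    # Shell alternative: TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python repro.py
    import torch._inductor.config as inductor_config

    inductor_config.force_disable_caches = True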

View File

@@ -3859,6 +3859,9 @@ class ConstantBuffer(InputBuffer):
 @ir_dataclass
 class NoneAsConstantBuffer(IRNode):
     def get_reads(self):
         return OrderedSet()

+    def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
+        return OrderedSet()
+
@@ -7137,21 +7140,23 @@ class InvokeSubgraph(ExternKernel):
             layout=MultiOutputLayout(device=device),
         )

-        outputs = [
-            MultiOutput(
-                FixedLayout(
-                    device=output.get_device(),
-                    dtype=output.get_dtype(),
-                    size=output.get_size(), # type: ignore[arg-type]
-                    stride=output.get_stride(),
-                    offset=output.get_layout().offset,
-                ),
-                invoke_subgraph,
-                [(list, i)],
-            )
-            for i, output in enumerate(outputs)
-        ]
+        def create_output(output: IRNode, ind: int):
+            if isinstance(output, NoneAsConstantBuffer):
+                return output
+            else:
+                return MultiOutput(
+                    FixedLayout(
+                        device=output.get_device(),
+                        dtype=output.get_dtype(),
+                        size=output.get_size(), # type: ignore[arg-type]
+                        stride=output.get_stride(),
+                        offset=output.get_layout().offset,
+                    ),
+                    invoke_subgraph,
+                    [(list, ind)],
+                )
+
+        outputs = [create_output(output, i) for i, output in enumerate(outputs)]
         invoke_subgraph.outputs = outputs
         return outputs
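
The rewrite exists because a subgraph output can now be a NoneAsConstantBuffer (hence the get_unbacked_symbol_uses addition above), which has no layout to wrap in a MultiOutput. A stripped-down sketch of the pass-through-or-wrap dispatch, using toy stand-ins rather than inductor's real IR types:

    def wrap_outputs(outputs, wrap):
        # None-like outputs pass through untouched; everything else is
        # wrapped with its position so callers can index the kernel result.
        return [out if out is None else wrap(out, i) for i, out in enumerate(outputs)]

    print(wrap_outputs(["t0", None, "t1"], lambda out, i: (out, i)))
    # [('t0', 0), None, ('t1', 2)]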

View File

@@ -15,7 +15,6 @@ from torch._prims_common import get_computation_dtype
 from torch._subclasses import fake_tensor # noqa: TCH001
 from torch._subclasses.fake_tensor import FakeTensor
 from torch._utils_internal import justknobs_check
-from torch.fx._utils import lazy_format_graph_code
 from torch.fx.experimental.symbolic_shapes import guard_scalar, ShapeEnv # noqa: TCH001
 from torch.fx.graph_module import GraphModule # noqa: TCH001
@@ -340,7 +339,3 @@ def tensorify_python_scalars(
             # are no longer needed and should be specialized. Restarting analysis is necessary
             # because we need to instruct Dynamo to NOT make these as inputs.
             raise TensorifyScalarRestartAnalysis
-
-    graph_code_log.debug(
-        "%s", lazy_format_graph_code("tensorify_python_scalars", gm, colored=True)
-    )