# mypy: ignore-errors

import copy
import math
import os
import sys
from dataclasses import dataclass
from functools import partial, wraps
from typing import Callable, List

import torch
import torch.fx as fx
from torch.hub import tqdm
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreWriter

from .compile_utils import get_outputs, get_placeholders


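# Sentinel stored in node.meta["concrete_value"] when a node's result was not a single
# tensor (e.g. a tuple), so the node cannot be swapped for a placeholder directly.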
is_tuple = object()


@dataclass
class LoadTensorMeta:
    size: List[int]
    stride: List[int]
    dtype: torch.dtype
    device: torch.device


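# Interpreter that runs the graph eagerly and records each node's concrete value in
# node.meta["concrete_value"]: either kept inline, or (when a ContentStoreWriter is
# supplied) offloaded to disk and described by a LoadTensorMeta so large intermediates
# do not stay live in memory.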
class ConcreteProp(torch.fx.Interpreter):
    def __init__(self, mod, *, writer=None, skip_offload=False):
        super().__init__(mod)
        self.writer = writer
        self.skip_offload = skip_offload
        self.seen_storages = set()

    def run_node(self, n):
        self.pbar.update(1)
        r = super().run_node(n)
        name = n.name

        if isinstance(r, torch.Tensor):
            if self.writer is None:
                n.meta["concrete_value"] = r
            else:
                if StorageWeakRef(r.untyped_storage()) in self.seen_storages:
                    # Refuse to offload tensors which alias other live
                    # tensors, because this will violate operator contracts
                    n.meta["concrete_value"] = None
                else:
                    if not self.skip_offload:
                        self.writer.write_tensor(os.path.join("eager", name), r)
                    n.meta["concrete_value"] = LoadTensorMeta(
                        r.size(), r.stride(), r.dtype, r.device
                    )
                    self.seen_storages.add(StorageWeakRef(r.untyped_storage()))
        else:
            n.meta["concrete_value"] = is_tuple

        return r

    def propagate(self, *args):
        with tqdm(
            desc="Saving intermediates for delta debugging",
            total=len(self.module.graph.nodes),
            disable=self.writer is None,
        ) as pbar:
            self.pbar = pbar
            r = super().run(*args)
            if not self.skip_offload:
                pbar.set_description(
                    "Saved! To skip next time, run with --skip-saving-eager-intermediates"
                )
            return r


def is_load_tensor_node(node):
    return (
        node.op == "call_function"
        and node.target is torch.ops.debugprims.load_tensor.default
    )


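# Replace a node with either a placeholder (fed by its recorded concrete value) or a
# debugprims.load_tensor call that reloads the offloaded intermediate from disk.
# Returns True if the node was converted.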
# inplace modifies node/inps
def _convert_node_to_placeholder(graph, node, inps):
    if node.op == "output" or node.op == "placeholder":
        return False

    if is_load_tensor_node(node):
        return False

    concrete_val = node.meta.get("concrete_value", None)

    if isinstance(concrete_val, torch.Tensor):
        node.op = "placeholder"
        node.target = node.name
        node.args = ()
        node.kwargs = {}

        inps.append(concrete_val)
        return True

    elif concrete_val is None:
        return False

    elif concrete_val is is_tuple:
        r = False
        for tuple_user in list(node.users):
            r = _convert_node_to_placeholder(graph, tuple_user, inps) or r
        # NB: We must not erase the node at this point, because
        # we are iterating over the nodes and this would change
        # the iteration order
        # graph.erase_node(node)
        return r

    elif isinstance(concrete_val, LoadTensorMeta):
        node.op = "call_function"
        node.target = torch.ops.debugprims.load_tensor.default
        node.args = (
            os.path.join("eager", node.name),
            concrete_val.size,
            concrete_val.stride,
        )
        node.kwargs = {
            "device": concrete_val.device,
            "dtype": concrete_val.dtype,
        }
        return True

    return False


def create_minified_hlo_graph(minified_fx_graph, inputs):
    """
    Takes a minified FX graph as primary input and ports it to HLO via StableHLO.
    Provides the minified HLO graph as output and archives it to a local directory.
    """
    hlo_dir = f"{os.getcwd()}/hlo_files"
    os.makedirs(hlo_dir, exist_ok=True)

    from torch_xla.stablehlo import save_torch_model_as_stablehlo

    save_torch_model_as_stablehlo(minified_fx_graph, inputs, hlo_dir)


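# Print the current repro as a runnable snippet: input metadata plus the generated FX code.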
def dump_state(fx_g, inps):
    print(
        f"""
# Working Repro with {len(fx_g.graph.nodes)} nodes
inps = {[(i.shape, i.dtype, i.device.type) for i in inps]}
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device=device) for (shape, dtype, device) in inps]
{fx_g.code}
"""
    )


def is_power_of_two(n):
    if n == 0:
        return False
    return (n & (n - 1)) == 0


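# Snapshot of a minification candidate: the FX graph plus the inputs feeding its
# placeholders (__post_init__ asserts the two stay in sync).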
@dataclass
class ReproState:
    graph: fx.Graph
    inps: List[torch.Tensor]

    def __post_init__(self):
        ph_nodes = get_placeholders(self.graph)
        assert len(ph_nodes) == len(self.inps)


def minifier(
    fail_f: fx.GraphModule,
    inps,
    module_fails,
    dump_state: Callable = dump_state,
    *,
    save_dir=None,
    offload_to_disk=False,
    skip_offload=False,
    skip_sanity=False,
    max_granularity=None,
):
    """
    Minimizes an FX graph with the given inputs, such that the resulting FX graph still returns True for module_fails.

    Uses 2 main strategies:
    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If that fails,
        tries replacing a quarter of the graph, etc.

    >>> # xdoctest: +SKIP(failing)
    >>> failing_function = fx.symbolic_trace(f)
    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.
    """
    assert isinstance(inps, (tuple, list))

    failing_graph = fail_f.graph
    cur_size = len(failing_graph.nodes)

    if max_granularity is not None and not is_power_of_two(max_granularity):
        raise RuntimeError(f"max_granularity {max_granularity} not power of two")

    num_queries = 0

    def deepcopy_fx_graph(fx_graph):
        return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph

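    # Every candidate is checked on a deep copy so module_fails cannot corrupt the graph
    # still being minified; num_queries tracks how many checks were made overall.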
    def graph_fails(graph, inps):
        nonlocal num_queries
        graph = copy.deepcopy(graph)
        num_queries += 1
        mod = fx.GraphModule(fail_f, graph)
        mod.graph.lint()
        return module_fails(mod, inps)

    writer = None
    if offload_to_disk:
        writer = ContentStoreWriter(save_dir)

    ConcreteProp(fail_f, writer=writer, skip_offload=skip_offload).propagate(*inps)
    if not skip_sanity and not graph_fails(failing_graph, inps):
        raise RuntimeError("Input graph did not fail the tester")
    print(f"Started off with {cur_size} nodes", file=sys.stderr)

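    # Wrap a strategy so each attempt is logged, progress (fewer nodes/outputs or more
    # inputs) is verified, and the candidate is re-checked against module_fails before
    # it is accepted.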
    def _register_strategy(strategy: Callable, name: str):
        @wraps(strategy)
        def new_func(old_state: ReproState, granularity=1):
            print(file=sys.stderr)
            print(
                f"Strategy: {name} (G: {granularity}) "
                f"({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)",
                file=sys.stderr,
            )
            new_state = strategy(
                deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity
            )
            if new_state is not None:
                new_nodes = len(new_state.graph.nodes)
                old_nodes = len(old_state.graph.nodes)
                new_inps = len(new_state.inps)
                old_inps = len(old_state.inps)
                new_outs = len(get_outputs(new_state.graph))
                old_outs = len(get_outputs(old_state.graph))
                progress_made = False
                if new_nodes < old_nodes:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes",
                        file=sys.stderr,
                    )
                if new_inps > old_inps:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_inps} to {new_inps} inputs",
                        file=sys.stderr,
                    )
                if new_outs < old_outs:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_outs} to {new_outs} outputs",
                        file=sys.stderr,
                    )

                if not progress_made:
                    raise RuntimeError("Success raised but no progress made?")

                if not graph_fails(new_state.graph, new_state.inps):
                    print(
                        "WARNING: Something went wrong, not applying this minification",
                        file=sys.stderr,
                    )
                    return None
                return new_state
            else:
                print(f"FAIL: {name}", file=sys.stderr)
            return None

        return new_func

    def register_strategy(name: str):
        return partial(_register_strategy, name=name)

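    # Strategy 1: try making an intermediate node the graph's output, truncating the
    # suffix after it; only indices selected by the current granularity are tried.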
    @register_strategy("Truncate suffix")
    def remove_suffix(cur_graph, cur_inps, granularity):
        tested = set()
        new_graph = fx.Graph()
        env = {}
        for idx, node in enumerate(cur_graph.nodes):
            new_node = new_graph.node_copy(node, lambda x: env[x])
            if node.op not in ["placeholder", "output"]:
                # If idx is divisible by (granularity * 2), it would have been checked already.
                if (
                    idx % granularity == 0
                    and (idx % (granularity * 2) != 0)
                    and idx not in tested
                ):
                    output_node = new_graph.output((new_node,))
                    if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails(
                        new_graph, cur_inps
                    ):
                        return ReproState(new_graph, cur_inps)
                    else:
                        tested.add(idx)
                        new_graph.erase_node(output_node)
            env[node] = new_node
        return None

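    # Drop a granularity-sized slice of the output tuple at a time and keep the first
    # version whose remaining outputs still reproduce the failure.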
    @register_strategy("Remove outputs")
    def remove_outputs(cur_graph, cur_inps, granularity):
        granularity = max(1, granularity // 2)
        for idx, node in enumerate(cur_graph.nodes):
            node.idx = idx
            if node.op == "output":
                output = node
                break

        if isinstance(output.args[0], fx.Node):
            return None

        output_args = sorted(
            output.args[0], key=lambda x: x.idx if isinstance(x, fx.Node) else int(1e9)
        )
        if len(output_args) == 1:
            return None

        for idx in range(0, len(output_args), granularity):
            output.args = (output_args[:idx] + output_args[idx + granularity :],)
            if graph_fails(cur_graph, cur_inps):
                return ReproState(cur_graph, cur_inps)
        return None

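    # Erase placeholders with no users and drop the corresponding inputs; "unchecked"
    # because the caller decides whether to re-run module_fails afterwards.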
    def remove_unused_inputs_unchecked(cur_state: ReproState):
        cur_graph = cur_state.graph
        cur_inps = cur_state.inps
        ph_nodes = get_placeholders(cur_graph)
        assert len(ph_nodes) == len(cur_inps)

        new_inps = []
        for idx in range(len(ph_nodes)):
            if len(ph_nodes[idx].users) == 0:
                cur_graph.erase_node(ph_nodes[idx])
            else:
                new_inps.append(cur_inps[idx])
        if len(new_inps) < len(cur_inps):
            return ReproState(cur_graph, new_inps)
        return None

    def remove_unused_inputs_checked(cur_state: ReproState):
        new_state = remove_unused_inputs_unchecked(cur_state)
        if new_state is not None and graph_fails(new_state.graph, new_state.inps):
            return new_state
        return None

    def _remove_unused_wrapper(cur_graph, cur_inps, granularity):
        return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps))

    remove_unused_inputs = register_strategy("Remove unused inputs")(
        _remove_unused_wrapper
    )

    @register_strategy("Eliminate dead code")
    def eliminate_dead_code(cur_graph, cur_inps, granularity):
        if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps):
            return ReproState(cur_graph, cur_inps)
        return None

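    # Rebuild the graph with all placeholders at the front; leading load_tensor nodes are
    # promoted to real inputs (their tensors are materialized and appended to inps).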
    def _consolidate_placeholders(cur_graph, inps):
        new_graph = fx.Graph()
        env = {}
        seen_non_placeholder = False

        # Move all placeholders to the front; also, if any load_tensor
        # is at the front, convert it into an input (because it can be live
        # all the time)
        for node in cur_graph.nodes:
            if node.op == "placeholder":
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
            elif not seen_non_placeholder and is_load_tensor_node(node):
                new_node = new_graph.placeholder(node.name)
                env[node] = new_node
                inps.append(
                    torch.ops.debugprims.load_tensor.default(*node.args, **node.kwargs)
                )
            else:
                seen_non_placeholder = True

        # Move everyone else
        for node in cur_graph.nodes:
            if node not in env:
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
        return new_graph

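    # Strategy 2: for each granularity-sized window of nodes, try replacing the whole
    # window with placeholders/load_tensor calls and keep the result if it still fails.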
    @register_strategy("Delta Debugging")
    def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity):
        num_nodes = len(cur_graph.nodes)
        for start_range in range(0, num_nodes, granularity):
            is_removing = False
            new_graph = deepcopy_fx_graph(cur_graph)
            new_inps = cur_inps[:]
            end_range = min(num_nodes, start_range + granularity)
            for idx in range(start_range, end_range):
                new_node = list(new_graph.nodes)[idx]
                if _convert_node_to_placeholder(new_graph, new_node, new_inps):
                    is_removing = True
            if not is_removing:
                continue
            new_graph.eliminate_dead_code()
            new_graph = _consolidate_placeholders(new_graph, new_inps)
            new_state = remove_unused_inputs_unchecked(ReproState(new_graph, new_inps))
            if new_state is None:
                new_state = ReproState(new_graph, new_inps)
            if graph_fails(new_state.graph, new_state.inps):
                return ReproState(new_state.graph, new_state.inps)

        return None

    @register_strategy("Consolidate Inputs")
    def consolidate_inputs(cur_graph, cur_inps, granularity):
        old_len = len(cur_inps)
        cur_graph = _consolidate_placeholders(cur_graph, cur_inps)
        if len(cur_inps) > old_len and graph_fails(cur_graph, cur_inps):
            return ReproState(cur_graph, cur_inps)
        return None

    failing_state = ReproState(failing_graph, inps)

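    # Run the applicable strategies at one granularity and return the first state that
    # still fails; non-granular cleanups (DCE, unused inputs, consolidation) only run
    # when requested.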
    def try_granularity(failing_state, granularity, use_non_granular):
        print(f"Trying granularity {granularity}", file=sys.stderr)

        strategies = []
        num_nodes = len(failing_state.graph.nodes)
        num_outputs = len(get_outputs(failing_state.graph))
        if num_outputs > num_nodes // 2:
            strategies += [remove_outputs]

        if use_non_granular:
            strategies += [
                eliminate_dead_code,
                remove_unused_inputs,
                consolidate_inputs,
            ]

        strategies += [remove_suffix, delta_debugging]

        for strategy in strategies:
            new_state = strategy(failing_state, granularity)
            if new_state is not None:
                return new_state
        return None

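    # Main loop: start at the largest power-of-two granularity <= the node count (capped
    # at max_granularity), keep halving until no strategy makes progress, and dump a
    # repro at each step.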
    while True:
        dump_state(fx.GraphModule(fail_f, failing_state.graph), failing_state.inps)
        granularity = int(2 ** (math.floor(math.log2(len(failing_state.graph.nodes)))))
        if max_granularity is not None:
            granularity = min(max_granularity, granularity)
        new_state = try_granularity(failing_state, granularity, use_non_granular=True)
        if new_state is not None:
            failing_state = new_state
            continue

        granularity //= 2
        has_progress = False
        while granularity >= 1:
            new_state = try_granularity(
                failing_state, granularity, use_non_granular=False
            )
            if new_state is not None:
                failing_state = new_state
                has_progress = True
                break
            granularity //= 2
        if has_progress:
            continue

        new_state = remove_outputs(failing_state, 1)
        if new_state is not None:
            failing_state = new_state
            continue

        break

    if not graph_fails(failing_state.graph, failing_state.inps):
        raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing")

    print(f"Made {num_queries} queries", file=sys.stderr)
    failing_fx = fx.GraphModule(fail_f, failing_state.graph)

    # If XLA debugging environment is enabled, create minified HLO graph as well
    if "XLA_HLO_DEBUG" in os.environ:
        create_minified_hlo_graph(failing_fx, failing_state.inps)

    dump_state(failing_fx, failing_state.inps)
    print("Wrote minimal repro out to repro.py", file=sys.stderr)
    return failing_fx, failing_state.inps