Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-22 06:11:27 +08:00)
There are five parts in this PR. They are hard to break into smaller PRs because they are highly coupled.

1. **Whenever we call create_graph_input, we try to bind the symbols in the graph input.** We've enforced the invariant that all create_graph_input calls must provide an example value, so we can intercept at the create_graph_input calls. (This PR only handles free symbols in tensors.)
2. **We cache the bound_symbols** to avoid lifting the same symbol repeatedly.
3. For lifted symbols, we reuse **lifted_freevars**, i.e. the mapping from a symbol proxy in the parent graph to the lifted placeholder (ph) in the current subgraph. This is the same mapping we already use to handle lifted tensors. As a result, all higher-order ops (hops) that support lifted tensors should be able to handle lifted symints automatically (at least in the Dynamo part).
4. For **unbacked symbols** created during tracing, we also need to bind these symbols to their proxies. This supports the test cases where we want to lift unbacked symbols as inputs. We need the proxy of the unbacked symbol in the parent graph in order to properly create the args to the hop.
5. We update all the tests now that free symbols are lifted in subgraphs, and we support the lifted symbols in the existing higher-order ops.

**The interaction of nested tracers:** The previous design for lifting tensor closures works like this. Suppose we're in nested tracers. Whenever we see a new proxy that was not created by the current tracer, we recursively look for the proxy in the parent tracers until we find the tracer that created it (either as a placeholder or as an intermediate result). More detail is in Note [Nested SubgraphTracer and free_variable handling].

Given this design, the plan for lifting the free symbols is as follows. Whenever we lift a free tensor to be an input of the current subgraph, we look at the symbols in it and bind those symbols at the same time.
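The lifting scheme above can be illustrated with a small, self-contained toy model. It is a sketch under simplifying assumptions, not Dynamo's implementation:

- The `Proxy` and `ToyTracer` classes are hypothetical.
- The `lift`, `_bind_symbols`, and `signature` methods are hypothetical.
- Representing a tensor's example value as a tuple of symbol names is also a simplification.

Only the roles of `create_graph_input`, `lifted_freevars`, and `bound_symbols` follow the description, and only tensor closures are modeled. The sketch replays the `x: [s1, s2]` walkthrough that follows. Lifting `x` into the innermost tracer first lifts it into tracer 1, and each tracer binds `s1` and `s2` as soon as the tensor placeholder exists.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

# Hypothetical toy model: a tensor's example value is just its shape, where
# string entries such as "s1" are free symbols and ints are static sizes.
Shape = Tuple[Union[str, int], ...]


@dataclass(eq=False)  # eq=False keeps identity-based hashing for dict keys
class Proxy:
    name: str
    tracer: "ToyTracer"
    example: Optional[Shape] = None  # None for symbol placeholders


class ToyTracer:
    def __init__(self, name: str, parent: Optional["ToyTracer"] = None) -> None:
        self.name = name
        self.parent = parent
        self.placeholders: List[Proxy] = []
        # proxy in the parent tracer -> placeholder lifted into this tracer
        self.lifted_freevars: Dict[Proxy, Proxy] = {}
        # symbol -> its proxy in this tracer (the bound_symbols cache)
        self.bound_symbols: Dict[str, Proxy] = {}

    def create_graph_input(self, name: str, example: Optional[Shape]) -> Proxy:
        ph = Proxy(name, self, example)
        self.placeholders.append(ph)
        self._bind_symbols(ph)  # bind free symbols as soon as the input exists
        return ph

    def _bind_symbols(self, ph: Proxy) -> None:
        for sym in ph.example or ():
            if not isinstance(sym, str) or sym in self.bound_symbols:
                continue  # static size, or symbol already bound here
            sym_ph = Proxy(sym, self)
            # Symbol placeholders go in front of the tensor that introduced them.
            self.placeholders.insert(self.placeholders.index(ph), sym_ph)
            self.bound_symbols[sym] = sym_ph
            if self.parent is not None:
                # The parent lifted this tensor first, so it has already bound
                # sym; reuse lifted_freevars: parent's proxy -> our placeholder.
                self.lifted_freevars[self.parent.bound_symbols[sym]] = sym_ph

    def lift(self, proxy: Proxy) -> Proxy:
        """Return a proxy usable in this tracer, lifting it through ancestors."""
        if proxy.tracer is self:
            return proxy
        if self.parent is None:
            raise ValueError(f"{proxy.name} was not created by any enclosing tracer")
        parent_proxy = self.parent.lift(proxy)  # recurse up to the creator first
        if parent_proxy not in self.lifted_freevars:
            ph = self.create_graph_input(parent_proxy.name, parent_proxy.example)
            self.lifted_freevars[parent_proxy] = ph
        return self.lifted_freevars[parent_proxy]

    def signature(self) -> str:
        return f"def {self.name}({', '.join(p.name for p in self.placeholders)})"


# Replay the walkthrough: x: [s1, s2] is a top-level input, and tracer 2
# (true_gm_inner) uses x inside two nested conds.
tracer0 = ToyTracer("gm")
x = tracer0.create_graph_input("x", ("s1", "s2"))
tracer1 = ToyTracer("true_gm", parent=tracer0)
tracer2 = ToyTracer("true_gm_inner", parent=tracer1)
ph2 = tracer2.lift(x)
for t in (tracer0, tracer1, tracer2):
    print(t.signature())
# def gm(s1, s2, x)
# def true_gm(s1, s2, x)
# def true_gm_inner(s1, s2, x)
```

In this sketch, `lifted_freevars` is keyed by the parent's proxy, so tracer 2 records `ph1 -> ph2`. A second lookup of `x` therefore hits the cache at every level instead of creating new placeholders.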
For example, suppose we have the following function:

```python
def f(x: [s1, s2]):
    def true_f():
        def true_f_inner():
            return x.sin()
```

Here is what happens, in time order:

1. We create subtracer 1 and start to speculate the outer cond's true_f.
2. We create another subtracer 2 and start to speculate the inner cond's true_f_inner.
3. Dynamo realizes the tensor input x by calling wrap_tensor at the top level (tracer 0) to create graph input x. We bind the symbols s1 and s2 after the placeholder for x is created. The graph now looks like:

   ```python
   def gm(s1, s2, x):
       ...
   ```

4. When seeing TensorVariable.call_method of x, tracer 2 wants to create call_function(sin, proxy_of_x), but it finds that proxy_of_x was not created by the current tracer. So it recursively looks up its parent tracer 1. Tracer 1 doesn't track proxy_of_x either, so the lookup reaches the root tracer 0, which created proxy_of_x and tracks it as a placeholder. Tracer 1 then calls create_graph_input to lift the closure to its input ph1 and adds the key-value pair (proxy_of_x: ph1) to its **lifted_freevars**. The graph now looks like:

   ```python
   def gm(s1, s2, x):
       def true_gm(x):
           ...
   ```

5. Since there are free symbols inside this new tensor input, tracer 1 also binds the symbols (maybe_bind_symbol), which calls create_graph_input for s1 and s2. The graph now looks like:

   ```python
   def gm(s1, s2, x):
       def true_gm(s1, s2, x):
           ...
   ```

6. Then control goes back to tracer 2, which calls create_graph_input for x and gets ph2. Tracer 2's **lifted_freevars** records (ph1: ph2), and tracer 2 also binds the symbols in this new tensor input. The graph now looks like:

   ```python
   def gm(s1, s2, x):
       def true_gm(s1, s2, x):
           def true_gm_inner(s1, s2, x):
               ...
   ```

7. Finally, tracer 2 creates the sin call_function node.

**This PR also handles the following cases:**

- What if we lift two tensors that share the same symbol, e.g. x1 [s1, s2] and x2 [s2, s3]? Each subtracer maintains bound_symbols as a cache that maps a symbol.expr to its proxy in the current tracer. When we see x1, we track s1 and s2 as inputs and bind s1 to ph1 and s2 to ph2. When we then try to bind the symbols of x2, s2 is already tracked, so no graph input is created for it.
- What if a subgraph closes over a symint? For example:

  ```python
  def f(x):
      def true_f():
          c = x.size(0)
          def true_fn_inner():
              return c
  ```

  When we speculate true_fn_inner, we find that proxy_of_c is not tracked by tracer 2, so it recursively looks up its parent. At this point, x and its symbols have already been lifted as inputs of true_f, as a result of lifting x while tracing true_f in tracer 1. Specifically, the graph looks like:

  ```python
  def gm(s1, s2, x):
      def true_gm(s1, s2, x):
          def true_gm_inner():
              ...
  ```

  Tracer 2 finds that s1 has been tracked as a placeholder in tracer 1. The lookup stops there, and tracer 2 calls create_graph_input on s1. The graph now looks like:

  ```python
  def gm(s1, s2, x):
      def true_gm(s1, s2, x):
          def true_gm_inner(s1):
              return s1
  ```

- What if a subgraph closes over an unbacked symint? For example:

  ```python
  def f(x):
      def true_f():
          c = x.item()
          def true_f_inner():
              return c
  ```

  When x.item() is called, proxy_of_c and its SymNodeVariable are created for tracer 1, and we also call track_unbacked_symbols to record this relationship. Later, tracer 2 finds that proxy_of_c was not created by the current tracer and recursively looks up its parent tracer. There it finds that the expression u0 has been tracked in tracer 1 as a result of track_unbacked_symbols. So it stops the recursion and calls create_graph_input for u0 in tracer 2. The graph looks like:

  ```python
  def f(x):
      def true_f(s1, s2, x):
          c = x.item()
          def true_gm_inner(u0):
              return u0
          cond(pred, true_gm_inner, false_gm_inner, (c,))
  ```

- What if a subgraph closes over a tensor with an unbacked symint shape?

  ```python
  def f(x):
      def true_f():
          c = x.item()
          r = torch.randn((c,))
          def true_f_inner():
              return r + 1
  ```

  This is the same as closing over tensors with backed shapes. We first lift r and then bind u0 in it. That recursively calls bind_symint for u0 in the parent, which finds that u0 is tracked in the parent tracer as a result of the .item() call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138363
Approved by: https://github.com/zou3519
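The edge cases in the commit message all follow one rule. Before a tracer creates a graph input for a symbol, it checks its own `bound_symbols` cache. On a miss, it first asks its parent to bind the symbol, and the recursion stops at the first tracer that already tracks it.

The toy sketch below models only that rule:

- `SymTracer`, `Node`, `lift_tensor`, `bind_symbol`, `track_unbacked_symbol`, and `names` are hypothetical names chosen for illustration.
- Dynamo's own helpers, such as `maybe_bind_symbol` and `track_unbacked_symbols`, are not reproduced here.
- Lifting the tensors themselves through `lifted_freevars` is omitted.

```python
from typing import Dict, List, Optional, Tuple, Union

# Hypothetical toy model: shape entries that are strings ("s1", "u0") are
# symbols; ints are static sizes. Only symbol binding is modeled here.
Shape = Tuple[Union[str, int], ...]


class Node:
    """Stands in for an fx proxy (a placeholder or an intermediate node)."""

    def __init__(self, name: str) -> None:
        self.name = name


class SymTracer:
    def __init__(self, parent: Optional["SymTracer"] = None) -> None:
        self.parent = parent
        self.inputs: List[Node] = []
        # symbol -> its proxy in this tracer (the bound_symbols cache)
        self.bound_symbols: Dict[str, Node] = {}

    def lift_tensor(self, name: str, shape: Shape) -> Node:
        ph = Node(name)
        self.inputs.append(ph)
        for dim in shape:
            if isinstance(dim, str):
                self.bind_symbol(dim, before=ph)
        return ph

    def bind_symbol(self, sym: str, before: Optional[Node] = None) -> Node:
        if sym in self.bound_symbols:
            return self.bound_symbols[sym]  # cache hit: no new graph input
        if self.parent is not None:
            # Make sure the parent can see the symbol first. This recurses until
            # some ancestor already tracks it; each tracer on the way binds it.
            self.parent.bind_symbol(sym)
        ph = Node(sym)
        idx = self.inputs.index(before) if before is not None else len(self.inputs)
        self.inputs.insert(idx, ph)
        self.bound_symbols[sym] = ph
        return ph

    def track_unbacked_symbol(self, sym: str) -> Node:
        # e.g. the result of x.item(): tracked in the cache, but not a graph input
        node = Node(sym)
        self.bound_symbols[sym] = node
        return node


def names(tracer: SymTracer) -> List[str]:
    return [n.name for n in tracer.inputs]


# Two lifted tensors share s2: it becomes an input only once.
root = SymTracer()
root.lift_tensor("x1", ("s1", "s2"))
root.lift_tensor("x2", ("s2", "s3"))
sub = SymTracer(parent=root)
sub.lift_tensor("x1", ("s1", "s2"))
sub.lift_tensor("x2", ("s2", "s3"))
print(names(sub))  # ['s1', 's2', 'x1', 's3', 'x2']

# Closing over c = x.size(0) (i.e. s1) after x was lifted into tracer 1.
outer = SymTracer()
outer.lift_tensor("x", ("s1", "s2"))
tracer1 = SymTracer(parent=outer)
tracer1.lift_tensor("x", ("s1", "s2"))
tracer2 = SymTracer(parent=tracer1)
tracer2.bind_symbol("s1")
print(names(tracer2))  # ['s1']

# Closing over c = x.item() (u0), or over r = torch.randn((c,)).
tracer1.track_unbacked_symbol("u0")
inner_c = SymTracer(parent=tracer1)
inner_c.bind_symbol("u0")
inner_r = SymTracer(parent=tracer1)
inner_r.lift_tensor("r", ("u0",))
print(names(inner_c), names(inner_r))  # ['u0'] ['u0', 'r']
```

In this sketch, an unbacked symbol such as `u0` is recorded in tracer 1's cache without becoming one of its inputs. The recursion therefore stops at tracer 1, and only the inner tracers gain a `u0` placeholder.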
492 lines
14 KiB
Python
import contextlib
import dis
import functools
import logging
import os.path
import random
import re
import sys
import types
import unittest
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    overload,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)
from unittest.mock import patch

import torch
from torch import fx
from torch._dynamo.backends.debugging import aot_eager
from torch._dynamo.output_graph import OutputGraph

from . import config, eval_frame, optimize_assert, reset
from .bytecode_transformation import (
    create_instruction,
    debug_checks,
    is_generator,
    transform_code_object,
)
from .guards import CheckFunctionManager, CompileId, GuardedCode
from .utils import same


np: Optional[types.ModuleType] = None
try:
    import numpy as np
except ModuleNotFoundError:
    np = None


unsupported = eval_frame.unsupported
three = 3

log = logging.getLogger(__name__)


def clone_me(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    if x is None:
        return None
    return x.detach().clone().requires_grad_(x.requires_grad)


def remove_optimized_module_prefix(name: str) -> str:
    return re.sub(r"^_orig_mod[.]", "", name)


def collect_results(
    model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> List[Any]:
    results = []
    results.append(prediction)
    results.append(loss)
    # if isinstance(loss, torch.Tensor) and loss.item() > 1:
    #     log.warning(
    #         f"High loss value alert - {loss:.2f}. Can result in unstable gradients."
    #     )

    grads = {}
    params = {}
    for name, param in model.named_parameters():
        if isinstance(model, eval_frame.OptimizedModule):
            name = remove_optimized_module_prefix(name)
        param_copy = param
        grad = param.grad
        # Treat None and zero grad as same
        if param.grad is None:
            grad = torch.zeros_like(param)
        grads[name + ".grad"] = grad
        params[name] = param_copy
    results.append(grads)
    results.append(params)
    buffers = {}
    for name, buffer in model.named_buffers():
        if isinstance(model, eval_frame.OptimizedModule):
            name = remove_optimized_module_prefix(name)
        buffers[name] = buffer
    results.append(buffers)
    for example in example_inputs:
        if isinstance(example, (tuple, list)):
            for inp in example:
                if isinstance(inp, torch.Tensor):
                    results.append(inp.grad)
        else:
            if isinstance(example, torch.Tensor):
                results.append(example.grad)
    return results


def requires_bwd_pass(out: Any) -> bool:
    if isinstance(out, torch.Tensor):
        return out.requires_grad
    elif isinstance(out, (list, tuple)):
        return any(requires_bwd_pass(x) for x in out)
    elif out is None:
        return False
    elif isinstance(out, int):
        return False
    raise NotImplementedError("Don't know how to reduce", type(out))


@overload
def reduce_to_scalar_loss(out: torch.Tensor) -> torch.Tensor:
    ...


@overload
def reduce_to_scalar_loss(
    out: Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]
) -> float:
    ...


def reduce_to_scalar_loss(out: Any) -> Union[torch.Tensor, float]:
    """Reduce the output of a model to get scalar loss"""
    if isinstance(out, torch.Tensor):
        # Mean does not work on integer tensors
        return out.sum() / out.numel()
    elif isinstance(out, (list, tuple)):
        return sum(reduce_to_scalar_loss(x) for x in out) / len(out)
    elif type(out).__name__ in (
        "MaskedLMOutput",
        "Seq2SeqLMOutput",
        "CausalLMOutputWithCrossAttentions",
    ):
        return reduce_to_scalar_loss(out.logits)
    elif type(out).__name__ == "SquashedNormal":
        return out.mean.sum()
    elif isinstance(out, dict):
        return sum(reduce_to_scalar_loss(value) for value in out.values()) / len(
            out.keys()
        )
    raise NotImplementedError("Don't know how to reduce", type(out))


def debug_dir() -> str:
    path = os.path.join(os.path.dirname(__file__), "../debug")
    if not os.path.exists(path):
        os.mkdir(path)
    return path


def debug_dump(name: str, code: types.CodeType, extra: str = "") -> None:
    with open(os.path.join(debug_dir(), name), "w") as fd:
        fd.write(
            f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n"
        )


def debug_insert_nops(
    frame: types.FrameType, cache_size: int, hooks: Any, _: Any, *, skip: int = 0
) -> Optional[GuardedCode]:
    """used to debug jump updates"""

    def insert_nops(instructions: List[Any], code_options: Any) -> None:
        instructions.insert(0, create_instruction("NOP"))
        instructions.insert(0, create_instruction("NOP"))

    if is_generator(frame.f_code):
        return None

    debug_checks(frame.f_code)
    code = transform_code_object(frame.f_code, insert_nops)
    graph = OutputGraph(
        code_options={},
        compiler_fn=None,
        root_tx=None,
        export=False,
        export_constraints=None,
        frame_state={"_id": 0},
        # TODO: shouldn't this be f_locals/f_globals from frame?
        local_scope=locals(),
        global_scope=globals(),
        f_code=frame.f_code,
        torch_function_mode_stack=[],
    )

    return GuardedCode(code, CheckFunctionManager(graph).guard_manager, CompileId(0, 0))  # type: ignore[arg-type]


class CompileCounter:
    def __init__(self) -> None:
        self.frame_count = 0
        self.op_count = 0

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ) -> Callable[..., Any]:
        self.frame_count += 1
        for node in gm.graph.nodes:
            if "call" in node.op:
                self.op_count += 1
        return gm.forward

    def clear(self) -> None:
        self.frame_count = 0
        self.op_count = 0


class CompileCounterWithBackend:
    def __init__(self, backend: str) -> None:
        self.frame_count = 0
        self.op_count = 0
        self.backend = backend
        self.graphs: List[torch.fx.GraphModule] = []

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ) -> Callable[..., Any]:
        from .backends.registry import lookup_backend

        self.frame_count += 1
        for node in gm.graph.nodes:
            if "call" in node.op:
                self.op_count += 1
        self.graphs.append(gm)
        return lookup_backend(self.backend)(gm, example_inputs)


# Equivalent to backend="eager", but also records graphs that
# we can assert on
class EagerAndRecordGraphs:
    def __init__(self) -> None:
        self.graphs: List[torch.fx.GraphModule] = []

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ) -> Callable[..., Any]:
        self.graphs.append(gm)
        return gm.forward


class AotEagerAndRecordGraphs:
    def __init__(self) -> None:
        self.graphs: List[torch.fx.GraphModule] = []
        self.fw_graphs: List[torch.fx.GraphModule] = []
        self.bw_graphs: List[torch.fx.GraphModule] = []

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ) -> Callable[..., Any]:
        self.graphs.append(gm)

        def fw_compiler(
            gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
        ) -> Callable[..., Any]:
            self.fw_graphs.append(gm)
            return gm.forward

        def bw_compiler(
            gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
        ) -> Callable[..., Any]:
            self.bw_graphs.append(gm)
            return gm.forward

        return aot_eager(
            gm,
            example_inputs,
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
        )


def strip_comment(code: str) -> str:
    return re.sub(r"(?m)^ *#.*\n?", "", code)


def remove_trailing_space(code: str) -> str:
    return "\n".join([line.rstrip() for line in code.split("\n")])


def normalize_gm(gm_str: str) -> str:
    # strip comments as comments have path to files which may differ from
    # system to system.
    return remove_trailing_space(strip_comment(gm_str))


def empty_line_normalizer(code: str) -> str:
    """
    Normalize code: remove empty lines.
    """
    normal_code = re.sub(r"[\r\n]+", "\n", code)
    return normal_code


def standard_test(
    self: Any,
    fn: Callable[..., Any],
    nargs: int,
    expected_ops: Optional[int] = None,
    expected_ops_dynamic: Optional[int] = None,
    expected_frame_count: int = 1,
) -> None:
    if not config.assume_static_by_default and expected_ops_dynamic is not None:
        expected_ops = expected_ops_dynamic

    actual = CompileCounter()

    args1 = [torch.randn(10, 10) for _ in range(nargs)]
    args2 = [torch.randn(10, 10) for _ in range(nargs)]
    correct1 = fn(*args1)
    correct2 = fn(*args2)
    reset()
    opt_fn = optimize_assert(actual)(fn)
    val1a = opt_fn(*args1)
    val2a = opt_fn(*args2)
    val1b = opt_fn(*args1)
    val2b = opt_fn(*args2)
    reset()
    self.assertTrue(same(val1a, correct1))
    self.assertTrue(same(val1b, correct1))
    self.assertTrue(same(val2a, correct2))
    self.assertTrue(same(val2b, correct2))
    self.assertEqual(actual.frame_count, expected_frame_count)
    if expected_ops is not None:
        self.assertEqual(actual.op_count, expected_ops)


def dummy_fx_compile(
    gm: fx.GraphModule, example_inputs: List[torch.Tensor]
) -> Callable[..., Any]:
    return gm.forward


def format_speedup(
    speedup: float,
    pvalue: float,
    is_correct: bool = True,
    pvalue_threshold: float = 0.1,
) -> str:
    if not is_correct:
        return "ERROR"
    if pvalue > pvalue_threshold:
        return f"{speedup:.3f}x SAME"
    return f"{speedup:.3f}x p={pvalue:.2f}"


def rand_strided(
    size: Sequence[int],
    stride: Sequence[int],
    dtype: torch.dtype = torch.float32,
    device: Union[str, torch.device] = "cpu",
    extra_size: int = 0,
) -> torch.Tensor:
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(size, stride))
        + 1
        + extra_size
    )
    if dtype.is_floating_point:
        if dtype.itemsize == 1:
            """
            normal distribution kernel is not implemented for fp8..
            Workaround that by creating a fp16 tensor and then cast.
            """
            buffer = torch.randn(needed_size, dtype=torch.float16, device=device).to(
                dtype=dtype
            )
        else:
            buffer = torch.randn(needed_size, dtype=dtype, device=device)
    else:
        buffer = torch.zeros(size=[needed_size], dtype=dtype, device=device)
    return torch.as_strided(buffer, size, stride)


_T = TypeVar("_T")


def check_dynamic_shape_capture() -> bool:
    # This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
    return not config.assume_static_by_default


def _make_fn_with_patches(fn: Callable[..., _T], *patches: Any) -> Callable[..., _T]:
    @functools.wraps(fn)
    def _fn(*args: Any, **kwargs: Any) -> _T:
        with contextlib.ExitStack() as stack:
            for module, attr, val in patches:
                stack.enter_context(patch.object(module, attr, val))

            return fn(*args, **kwargs)

    return _fn


def make_test_cls_with_patches(
    cls: type,
    cls_prefix: str,
    fn_suffix: str,
    *patches: Any,
    xfail_prop: Optional[str] = None,
    decorator: Callable[[Callable[..., Any]], Callable[..., Any]] = lambda x: x,
) -> type:
    DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
    DummyTestClass.__qualname__ = DummyTestClass.__name__

    for name in dir(cls):
        if name.startswith("test_"):
            fn = getattr(cls, name)
            if not callable(fn):
                setattr(DummyTestClass, name, getattr(cls, name))
                continue
            new_name = f"{name}{fn_suffix}"
            new_fn = _make_fn_with_patches(fn, *patches)
            new_fn.__name__ = new_name
            if xfail_prop is not None and hasattr(fn, xfail_prop):
                new_fn = unittest.expectedFailure(new_fn)
            setattr(DummyTestClass, new_name, decorator(new_fn))
        # NB: Doesn't handle slots correctly, but whatever
        elif not hasattr(DummyTestClass, name):
            setattr(DummyTestClass, name, getattr(cls, name))

    return DummyTestClass


# test Python 3.11+ specific features
def skipIfNotPy311(fn: Callable[..., Any]) -> Callable[..., Any]:
    if sys.version_info >= (3, 11):
        return fn
    return unittest.skip(fn)


def skipIfNotPy312(fn: Callable[..., Any]) -> Callable[..., Any]:
    if sys.version_info >= (3, 12):
        return fn
    return unittest.skip("Requires Python 3.12+")(fn)


def xfailIfPy312(fn: Callable[..., Any]) -> Callable[..., Any]:
    if sys.version_info >= (3, 12):
        return unittest.expectedFailure(fn)
    return fn


def skipIfPy312(fn: Callable[..., Any]) -> Callable[..., Any]:
    if sys.version_info >= (3, 12):
        return unittest.skip("Not supported in Python 3.12+")(fn)
    return fn


def requiresPy310(fn: Callable[..., Any]) -> Callable[..., Any]:
    if sys.version_info >= (3, 10):
        return fn
    else:
        return unittest.skip("Requires Python 3.10+")(fn)


# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
# and test/dynamo/test_dynamic_shapes.py
def expectedFailureDynamic(fn: Callable[..., Any]) -> Callable[..., Any]:
    fn._expected_failure_dynamic = True  # type: ignore[attr-defined]
    return fn


# Controls tests generated in test/inductor/test_torchinductor_codegen_dynamic_shapes.py
def expectedFailureCodegenDynamic(fn: Callable[..., Any]) -> Callable[..., Any]:
    fn._expected_failure_codegen_dynamic = True  # type: ignore[attr-defined]
    return fn


# Controls test generated in test/inductor/test_cpp_wrapper.py
def expectedFailureDynamicWrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
    fn._expected_failure_dynamic_wrapper = True  # type: ignore[attr-defined]
    return fn


def reset_rng_state(use_xla: bool = False) -> None:
    torch.manual_seed(1337)
    random.seed(1337)
    if np:
        np.random.seed(1337)
    if use_xla:
        import torch_xla.core.xla_model as xm

        xm.set_rng_state(1337, str(xm.xla_device()))