Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Summary: This is a prototype for running extern fallback kernels with a host-side proxy executor.

Sample of the generated cpp wrapper call:

```
at::Tensor buf0;  // output buffer
void* tensor_args_var_0[] = {&arg0_1, &arg0_1, &arg1_1, &arg0_1, &arg1_1, &buf0};
int64_t int_args_var_1[] = {81, 81, 7, 7, 7, 81};
proxy_executor->call_function("buf0", int_args_var_1, tensor_args_var_0);
```

- In the current implementation, the proxy executor interprets the raw pointers according to the op's schema. This assumes that the custom op MUST have a valid schema registered with the dispatcher. (I would like to validate this assumption.)
- I am using the callBoxed() API of the custom kernels. This is inevitable, as we wish to have a single call_function API for all possible custom kernels.
- These are all the input argument types supported so far:

```
union Argument {
  # Bool value does not matter
  1: bool asNone;
  2: TensorArgument asTensor;
  3: list<TensorArgument> asTensors;
  5: i64 asInt;
  7: list<i64> asInts;
  8: double asFloat;
  9: list<double> asFloats;
  10: string asString;
  10.5: list<string> asStrings;
  11: SymIntArgument asSymInt;
  12: list<SymIntArgument> asSymInts;
  13: ScalarType asScalarType;
  14: MemoryFormat asMemoryFormat;
  15: Layout asLayout;
  16: Device asDevice;
  17: bool asBool;
  18: list<bool> asBools;
}
```

- We need a policy for handling unpopulated arguments that have default values. Here are the options; the choice has BC implications (option 3 is sketched below, after this summary):
  1. Require the exported FX graph to explicitly populate default values if the user doesn't specify them.
  2. Require the cpp wrapper to explicitly populate default values if the FX graph doesn't specify them.
  3. Have the proxy executor look up default values from the op schema.

For fixing T162112344

Test Plan:
- frontend: buck2 run mode/dev-sand mode/inplace -c fbcode.enable_gpu_sections=True sigmoid/frontend:export_main
- test: buck2 run mode/dev-sand //deeplearning/aot_inductor/test:test_custom_ops
- backend: buck2 run mode/dev-nosan //deeplearning/aot_inductor/fb:main
- buck2 test 'fbcode//mode/opt' fbcode//caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark -- --exact 'caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark - test_aot_inductor_benchmark_cmf30x (caffe2.torch.fb.model_transform.experimental.benchmark.test.test_aot_inductor_benchmark.AOTInductorBenchmark)'

Reviewed By: suo

Differential Revision: D48747417

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108350
Approved by: https://github.com/izaitsevfb
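To make option 3 concrete, here is a minimal host-side sketch of filling unspecified arguments from the op schema before dispatching. It is a Python illustration rather than the actual C++ proxy executor: the helper name `fill_defaults_from_schema` is made up, it ignores kwarg-only arguments, and it assumes (per the note above) that the op has a valid schema registered with the dispatcher and that `torch._C.Argument` exposes the schema default via `default_value`.

```python
import torch


def fill_defaults_from_schema(op: torch._ops.OpOverload, provided: dict):
    """Build a positional argument list, filling unspecified values from schema defaults."""
    args = []
    for arg in op._schema.arguments:
        if arg.name in provided:
            args.append(provided[arg.name])
        else:
            # schema default; expected to be None when the argument has no default
            args.append(arg.default_value)
    return args


# aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
op = torch.ops.aten.clamp.default
call_args = fill_defaults_from_schema(op, {"self": torch.randn(4), "min": 0.0})
out = op(*call_args)
```

A real executor would presumably do this lookup once per serialized extern kernel node and dispatch through callBoxed() rather than the Python OpOverload call, but the default-filling logic would be the same.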
1299 lines · 44 KiB · Python
import contextlib
|
|
import dataclasses
|
|
import functools
|
|
import itertools
|
|
import logging
|
|
import sys
|
|
import warnings
|
|
|
|
from functools import wraps
|
|
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Sequence, Union
|
|
from unittest import mock
|
|
|
|
from functorch.compile import min_cut_rematerialization_partition
|
|
|
|
import torch._functorch.config as functorch_config
|
|
|
|
import torch.fx
|
|
import torch.utils._pytree as pytree
|
|
from torch._dynamo import (
|
|
compiled_autograd,
|
|
logging as dynamo_logging,
|
|
utils as dynamo_utils,
|
|
)
|
|
from torch._dynamo.utils import detect_fake_mode
|
|
from torch._functorch.aot_autograd import make_boxed_func
|
|
from torch._inductor.codecache import code_hash, CompiledFxGraph
|
|
|
|
from torch._inductor.debug import save_args_for_compile_fx_inner
|
|
from torch._ops import OpOverload
|
|
from torch._subclasses.fake_tensor import FakeTensor
|
|
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
|
|
|
|
from .._dynamo.backends.common import aot_autograd
|
|
from ..fx.graph import _PyTreeCodeGen
|
|
from . import config, metrics
|
|
from .debug import DebugContext
|
|
from .decomposition import select_decomp_table
|
|
from .fx_passes.joint_graph import joint_graph_passes
|
|
from .fx_passes.post_grad import post_grad_passes, view_to_reshape
|
|
from .fx_passes.pre_grad import pre_grad_passes
|
|
from .graph import GraphLowering
|
|
from .ir import ExternKernelNode
|
|
from .pattern_matcher import clone_graph
|
|
from .utils import get_dtype_size, has_incompatible_cudagraph_ops
|
|
from .virtualized import V
|
|
|
|
if config.is_fbcode():
|
|
from torch._inductor.fb.utils import time_and_log # type: ignore[import]
|
|
else:
|
|
# no-op decorator
|
|
def time_and_log(attr: str):
|
|
def wrap(old_func):
|
|
@wraps(old_func)
|
|
def newFunction(*args, **kwargs):
|
|
return old_func(*args, **kwargs)
|
|
|
|
return newFunction
|
|
|
|
return wrap
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
|
|
ALIGNMENT = 16
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class BoxedBool:
|
|
value: bool
|
|
|
|
def __bool__(self):
|
|
return self.value
|
|
|
|
@staticmethod
|
|
def disable(obj):
|
|
if isinstance(obj, BoxedBool):
|
|
obj.value = False
|
|
return obj
|
|
return False
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class BoxedDeviceIndex:
|
|
value: Optional[int]
|
|
|
|
def set(self, device_idx):
|
|
assert device_idx is None or isinstance(device_idx, int)
|
|
self.value = device_idx
|
|
|
|
|
|
# copy_ fails when trying to write to tensors with internal memory overlap.
# For expanded dimensions (a dimension that was expanded from size 1 to a
# larger size, i.e. has stride 0), we can select a single element from that
# dimension and write to it, which effectively writes to every value along
# that dimension of the input tensor.
def get_expanded_dims(t):
    if not isinstance(t, torch.Tensor):
        return None
    return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]
|
|
|
|
|
|
def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
|
|
for expanded_dim in expanded_dims:
|
|
t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
|
|
return t
|
|
|
|
|
|
def complex_memory_overlap(t: torch.Tensor) -> bool:
|
|
# if torch._debug_has_internal_overlap thinks this tensor potentially has
|
|
# memory overlap internally, let's dig deeper to find out whether it's true.
|
|
t = index_expanded_dims(t, get_expanded_dims(t))
|
|
if torch._debug_has_internal_overlap(t) != 0:
|
|
strides = t.stride()
|
|
sizes = t.shape
|
|
indices = list(range(len(strides)))
|
|
indices = [x for _, x in sorted(zip(strides, indices))]
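        # After sorting dims by increasing stride, a non-overlapping layout requires
        # each stride to be at least the previous dim's stride * size; otherwise
        # distinct indices can map to the same memory location.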
for i in range(len(strides)):
|
|
prev_stride = 1 if i == 0 else strides[indices[i - 1]]
|
|
prev_size = 1 if i == 0 else sizes[indices[i - 1]]
|
|
if strides[indices[i]] < prev_stride * prev_size:
|
|
return True
|
|
return False
|
|
|
|
|
|
@functools.lru_cache(None)
|
|
def _step_logger():
|
|
return dynamo_logging.get_step_logger(log)
|
|
|
|
|
|
@functools.lru_cache(None)
|
|
def _warn_tf32_disabled():
|
|
if (
|
|
torch.cuda.is_available()
|
|
and not torch.backends.cuda.matmul.allow_tf32
|
|
and torch.cuda.get_device_capability() >= (8, 0)
|
|
):
|
|
warnings.warn(
|
|
"TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
|
|
"Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
|
|
)
|
|
|
|
|
|
def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
|
|
aten = torch.ops.aten
|
|
tf32_ops = {
|
|
aten.mm.default,
|
|
aten.addmm.default,
|
|
aten.bmm.default,
|
|
aten.baddbmm.default,
|
|
}
|
|
for node in gm.graph.nodes:
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target in tf32_ops
|
|
and isinstance(node.meta.get("val", None), torch.Tensor)
|
|
and node.meta["val"].dtype == torch.float32
|
|
and node.meta["val"].device.type == "cuda"
|
|
):
|
|
return True
|
|
return False
|
|
|
|
|
|
@DebugContext.wrap
|
|
def count_bytes_inner(
|
|
gm: torch.fx.GraphModule,
|
|
example_inputs: List[torch.Tensor],
|
|
num_fixed: int = 0,
|
|
**kwargs,
|
|
):
|
|
shape_env = _shape_env_from_inputs(example_inputs)
|
|
fake_mode = fake_tensor_prop(gm, example_inputs)
|
|
|
|
with V.set_fake_mode(fake_mode):
|
|
post_grad_passes(gm, False)
|
|
|
|
graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
|
|
with V.set_graph_handler(graph), V.set_real_inputs(example_inputs): # type: ignore[call-arg]
|
|
graph.run(*example_inputs)
|
|
num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
|
|
metrics.num_bytes_accessed += num_bytes
|
|
metrics.nodes_num_elem += nodes_num_elem
|
|
metrics.node_runtimes += node_runtimes
|
|
return make_boxed_func(gm.forward)
|
|
|
|
|
|
def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
|
|
@functools.wraps(inner_compile)
|
|
def wrapper(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], **kwargs):
|
|
"""
|
|
Compile into cpp wrapper:
|
|
For CPU, this is currently done in one pass.
|
|
For GPU, this is done in two passes: JIT-compile the model with python wrapper code
|
|
and run it to generate autotuned kernel binaries in the first pass; and then generate
|
|
cpp wrapper code and compile it to a dynamic library in the second pass.
|
|
"""
|
|
devices = (
|
|
{t.device.type for t in gm.parameters()}
|
|
| {t.device.type for t in gm.buffers()}
|
|
| {t.device.type for t in example_inputs if isinstance(t, torch.Tensor)}
|
|
)
|
|
|
|
if "cuda" not in devices:
|
|
kwargs_patched = {**kwargs, "cpp_wrapper": True}
|
|
return inner_compile(gm, example_inputs, **kwargs_patched)
|
|
else:
|
|
with config.patch( # type: ignore[attr-defined]
|
|
{
|
|
"triton.store_cubin": True,
|
|
}
|
|
):
|
|
# first pass with regular python wrapper code
|
|
kwargs_patched = {
|
|
**kwargs,
|
|
"cpp_wrapper": False,
|
|
}
|
|
# clone_graph(gm) makes sure no graph modification from the first pass will
|
|
# leak to the second pass. It does increase memory pressure, but the problem
|
|
# can be alleviated once we have parameters as FakeTensor.
|
|
|
|
compiled = inner_compile(
|
|
clone_graph(gm), example_inputs, **kwargs_patched
|
|
)
|
|
|
|
def materialize(x):
|
|
if isinstance(x, (torch.SymInt, torch.SymFloat)):
|
|
# Need concrete value to run dynamic shapes and tune the result
|
|
return x.node.hint
|
|
else:
|
|
assert not isinstance(x, FakeTensor)
|
|
return x
|
|
|
|
tracing_context = torch._guards.TracingContext.get()
|
|
if tracing_context:
|
|
if tracing_context.output_strides:
|
|
tracing_context.output_strides.clear()
|
|
|
|
params_flat = [
|
|
param
|
|
for param in tracing_context.params_flat # type: ignore[union-attr]
|
|
if param is not None
|
|
]
|
|
real_inputs = [
|
|
materialize(x) for x in (params_flat + V.real_inputs)
|
|
]
|
|
else:
|
|
real_inputs = [materialize(x) for x in V.real_inputs]
|
|
|
|
with torch.utils._python_dispatch._disable_current_modes():
|
|
compiled(real_inputs)
|
|
|
|
del real_inputs
|
|
|
|
# second pass
|
|
kwargs_patched = {**kwargs, "cpp_wrapper": True}
|
|
return inner_compile(gm, example_inputs, **kwargs_patched)
|
|
|
|
return wrapper
|
|
|
|
|
|
def fake_tensor_prop(
|
|
gm: torch.fx.GraphModule,
|
|
example_inputs: List[torch.Tensor],
|
|
force_allow_non_fake_inputs: bool = False,
|
|
):
|
|
"""
|
|
If we can not detect fake mode from the context of inputs, create one.
|
|
|
|
The created fake mode will be returned.
|
|
"""
|
|
fake_mode = detect_fake_mode(example_inputs)
|
|
if not fake_mode:
|
|
fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
|
|
FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
|
|
else:
|
|
ctx = (
|
|
contextlib.nullcontext()
|
|
if not force_allow_non_fake_inputs
|
|
else mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
|
|
)
|
|
with ctx: # type: ignore[attr-defined]
|
|
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
|
|
*example_inputs
|
|
)
|
|
|
|
return fake_mode
|
|
|
|
|
|
@DebugContext.wrap
|
|
@torch.utils._python_dispatch._disable_current_modes()
|
|
@time_and_log(attr="compilation time (in seconds)")
|
|
def compile_fx_inner(
|
|
gm: torch.fx.GraphModule,
|
|
example_inputs: List[torch.Tensor],
|
|
cudagraphs: Optional[BoxedBool] = None,
|
|
num_fixed: int = 0,
|
|
is_backward: bool = False,
|
|
graph_id: Optional[int] = None,
|
|
cpp_wrapper: bool = False,
|
|
aot_mode: bool = False,
|
|
is_inference: bool = False,
|
|
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
|
|
user_visible_outputs: FrozenSet[str] = frozenset(),
|
|
layout_opt: Optional[bool] = None,
|
|
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
|
|
):
|
|
"""
|
|
Inductor API that compiles a single graph.
|
|
|
|
If you change the argument list for this funtion, make sure you
|
|
also update the call to save_args_for_compile_fx_inner below accordingly.
|
|
"""
|
|
if dynamo_utils.count_calls(gm.graph) == 0:
|
|
return make_boxed_func(gm.forward)
|
|
|
|
if config.save_args:
|
|
save_args_for_compile_fx_inner(
|
|
gm,
|
|
example_inputs,
|
|
cudagraphs=cudagraphs,
|
|
num_fixed=num_fixed,
|
|
is_backward=is_backward,
|
|
graph_id=graph_id,
|
|
cpp_wrapper=cpp_wrapper,
|
|
aot_mode=aot_mode,
|
|
is_inference=is_inference,
|
|
boxed_forward_device_index=boxed_forward_device_index,
|
|
user_visible_outputs=user_visible_outputs,
|
|
layout_opt=layout_opt,
|
|
)
|
|
|
|
if cudagraphs is None:
|
|
cudagraphs = BoxedBool(config.triton.cudagraphs)
|
|
|
|
# Inputs to fx_codegen_and_compile
|
|
graph_args = [gm, example_inputs]
|
|
graph_kwargs = {
|
|
"cudagraphs": cudagraphs,
|
|
"num_fixed": num_fixed,
|
|
"is_backward": is_backward,
|
|
"graph_id": graph_id,
|
|
"cpp_wrapper": cpp_wrapper,
|
|
"aot_mode": aot_mode,
|
|
"is_inference": is_inference,
|
|
"user_visible_outputs": user_visible_outputs,
|
|
"layout_opt": layout_opt,
|
|
"extern_node_serializer": extern_node_serializer,
|
|
}
|
|
|
|
compiled_graph: CompiledFxGraph = fx_codegen_and_compile(
|
|
*graph_args, **graph_kwargs # type: ignore[arg-type]
|
|
)
|
|
|
|
if aot_mode:
|
|
return compiled_graph
|
|
|
|
if cudagraphs:
|
|
        # the graph outputs are contained in the first (and only) argument of the output node
|
|
output = list(gm.graph.nodes)[-1]
|
|
assert len(output.args) == 1
|
|
stack_traces = [
|
|
(arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
|
|
for arg in output.args[0]
|
|
]
|
|
|
|
complex_memory_overlap_inputs = any(
|
|
complex_memory_overlap(t)
|
|
for t in example_inputs
|
|
if isinstance(t, torch.Tensor)
|
|
)
|
|
|
|
        # doesn't work for non-trees because the warmup run would apply the mutation twice
|
|
if config.triton.cudagraph_trees:
|
|
            # check whether mutation happens only on parameters/static inputs
|
|
has_mutation = not all(
|
|
idx < num_fixed for idx in compiled_graph.mutated_input_idxs
|
|
)
|
|
else:
|
|
has_mutation = len(compiled_graph.mutated_inputs) != 0
|
|
|
|
cudagraph_tests = [
|
|
(set(compiled_graph.device_types) == {"cuda"}, "non-cuda device in graph"),
|
|
(not has_mutation, "mutated inputs"),
|
|
(not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
|
|
(not complex_memory_overlap_inputs, "complex memory overlap"),
|
|
(
|
|
all(
|
|
isinstance(t, (torch.Tensor, torch.SymInt)) for t in example_inputs
|
|
),
|
|
"non-Tensor inputs",
|
|
),
|
|
(
|
|
(
|
|
len(compiled_graph.device_idxs) == 1
|
|
or not config.triton.cudagraph_trees
|
|
),
|
|
"multiple device indices without cudagraph_trees",
|
|
),
|
|
]
|
|
cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
|
|
|
|
if not cudagraph_fail_reasons:
|
|
if not config.triton.cudagraph_trees:
|
|
# Force specialize all inputs so that CUDA graphs will work
|
|
for t in example_inputs:
|
|
if isinstance(t, torch.SymInt):
|
|
int(t) # guard
|
|
|
|
if (
|
|
boxed_forward_device_index is not None
|
|
and not is_inference
|
|
and not is_backward
|
|
):
|
|
boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))
|
|
|
|
compiled_graph.current_callable = cudagraphify(
|
|
compiled_graph.get_current_callable(),
|
|
example_inputs,
|
|
static_input_idxs=range(num_fixed),
|
|
device_index=next(iter(compiled_graph.device_idxs)),
|
|
stack_traces=stack_traces,
|
|
is_backward=is_backward,
|
|
is_inference=is_inference,
|
|
)
|
|
else:
|
|
BoxedBool.disable(cudagraphs)
|
|
|
|
            # See [Backward Generation Handling]
            # If we cudagraph'd the forward and set the device, we need to let the cudagraph
            # manager know we are running the backward even if we will not run it in cudagraphs
|
|
if is_backward and config.triton.cudagraph_trees:
|
|
assert boxed_forward_device_index is not None
|
|
assert boxed_forward_device_index.value is not None
|
|
compiled_graph_callable = compiled_graph.get_current_callable()
|
|
|
|
manager = torch._inductor.cudagraph_trees.get_manager(
|
|
boxed_forward_device_index.value, create_if_none_exists=False
|
|
)
|
|
# should already exist from forward
|
|
assert manager is not None
|
|
|
|
def compiled_artifact(new_inputs):
|
|
manager.set_to_running_backward()
|
|
return compiled_graph_callable(new_inputs)
|
|
|
|
compiled_graph.current_callable = compiled_artifact
|
|
|
|
if len(set(compiled_graph.device_types)) > 1:
|
|
perf_hint_log.warning("skipping cudagraphs due to multiple devices")
|
|
elif set(compiled_graph.device_types) == {"cuda"}:
|
|
if has_mutation:
|
|
perf_hint_log.warning("skipping cudagraphs due to input mutation")
|
|
elif complex_memory_overlap_inputs:
|
|
perf_hint_log.warning(
|
|
"skipping cudagraphs due to complex input striding"
|
|
)
|
|
elif (
|
|
len(compiled_graph.device_idxs) > 1
|
|
and config.triton.cudagraph_trees
|
|
):
|
|
perf_hint_log.warning(
|
|
"skipping cudagraphs due to multiple device indexes"
|
|
)
|
|
else:
|
|
perf_hint_log.warning("skipping cudagraphs for unknown reason")
|
|
else:
|
|
perf_hint_log.warning("skipping cudagraphs for unknown reason")
|
|
|
|
# cudagraphs does its own aligning of inputs
|
|
if not cudagraphs:
|
|
new_callable = align_inputs(
|
|
compiled_graph.get_current_callable(), example_inputs, range(num_fixed)
|
|
)
|
|
if new_callable is not compiled_graph.get_current_callable():
|
|
compiled_graph.current_callable = new_callable
|
|
|
|
_step_logger()(
|
|
logging.INFO,
|
|
"torchinductor done compiling "
|
|
f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
|
|
f"graph {graph_id}",
|
|
)
|
|
|
|
# aot autograd needs to know to pass in inputs as a list
|
|
compiled_graph._boxed_call = True
|
|
return compiled_graph
|
|
|
|
|
|
def fx_codegen_and_compile(
|
|
gm: torch.fx.GraphModule,
|
|
example_inputs: List[torch.Tensor],
|
|
cudagraphs: Optional[BoxedBool] = None,
|
|
num_fixed: int = 0,
|
|
is_backward: bool = False,
|
|
graph_id: Optional[int] = None,
|
|
cpp_wrapper: bool = False,
|
|
aot_mode: bool = False,
|
|
is_inference: bool = False,
|
|
user_visible_outputs: FrozenSet[str] = frozenset(),
|
|
layout_opt: Optional[bool] = None,
|
|
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
|
|
) -> CompiledFxGraph:
|
|
if is_tf32_warning_applicable(gm):
|
|
_warn_tf32_disabled()
|
|
|
|
    # lift the maximum depth of the Python interpreter stack
    # to accommodate large/deep models
|
|
sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))
|
|
|
|
_step_logger()(
|
|
logging.INFO,
|
|
"torchinductor compiling "
|
|
f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
|
|
f"graph {graph_id}",
|
|
)
|
|
V.debug.fx_graph(gm, example_inputs)
|
|
|
|
shape_env = _shape_env_from_inputs(example_inputs)
|
|
|
|
    # Convert view to reshape in the graph. This is necessary primarily for
    # layout optimization. Do it unconditionally for uniformity.
    #
    # It's needed because when we do layout optimization, a contiguous tensor
    # in eager mode may become a channels-last tensor. A view op that could
    # previously be applied to the contiguous tensor may no longer be applicable
    # to the channels-last tensor. An error like
    #   RuntimeError: view size is not compatible with input tensor's size and stride
    #   (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
    # would be raised.
    #
    # Replace view ops with reshape ops in this case.
    # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this.
    #
    # Also this has to be done before the FakeTensorProp below to avoid the failing
    # .view() call.
    view_to_reshape(gm)
|
|
|
|
fake_mode = fake_tensor_prop(gm, example_inputs)
|
|
|
|
    # Pattern matcher passes might not preserve striding information
    # on node.meta["val"]. If in the future we rely on these being
    # correct, we will need to fix this.
|
|
|
|
with V.set_fake_mode(fake_mode): # type: ignore[call-arg]
|
|
# has some issues with memory in training
|
|
post_grad_passes(gm, is_inference=is_inference)
|
|
V.debug.fx_graph_transformed(gm, example_inputs)
|
|
|
|
with V.set_fake_mode(fake_mode): # type: ignore[call-arg]
|
|
graph = GraphLowering(
|
|
gm,
|
|
shape_env=shape_env,
|
|
num_static_inputs=num_fixed,
|
|
graph_id=graph_id,
|
|
cpp_wrapper=cpp_wrapper,
|
|
aot_mode=aot_mode,
|
|
user_visible_outputs=user_visible_outputs,
|
|
extern_node_serializer=extern_node_serializer,
|
|
)
|
|
with V.set_graph_handler(graph): # type: ignore[call-arg]
|
|
graph.run(*example_inputs)
|
|
context = torch._guards.TracingContext.get()
|
|
if context is not None and context.output_strides is not None:
|
|
# Return the output strides to the caller via TracingContext
|
|
assert len(context.output_strides) == 0
|
|
assert graph.graph_outputs is not None
|
|
for out in graph.graph_outputs:
|
|
if hasattr(out, "layout"):
|
|
context.output_strides.append(
|
|
tuple( # type: ignore[arg-type]
|
|
V.graph.sizevars.size_hint(s) for s in out.layout.stride
|
|
)
|
|
)
|
|
else:
|
|
context.output_strides.append(None)
|
|
compiled_fn = graph.compile_to_fn()
|
|
|
|
if graph.disable_cudagraphs:
|
|
BoxedBool.disable(cudagraphs)
|
|
|
|
compiled_graph = CompiledFxGraph(
|
|
compiled_artifact=compiled_fn,
|
|
cache_key=graph.cache_key,
|
|
artifact_path=graph.cache_path,
|
|
cache_linemap=graph.cache_linemap,
|
|
device_types=graph.device_types,
|
|
device_idxs=graph.device_idxs,
|
|
mutated_inputs=graph.mutated_inputs,
|
|
mutated_input_idxs=set(graph.mutated_input_idxs),
|
|
)
|
|
return compiled_graph
|
|
|
|
|
|
def clone_preserve_strides(x: torch.Tensor):
|
|
needed_size = (
|
|
sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
|
|
)
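    # e.g. size (2, 3) with stride (3, 1) needs (2 - 1) * 3 + (3 - 1) * 1 + 1 = 6 elements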
buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
|
|
return torch.as_strided(buffer, x.size(), x.stride())
|
|
|
|
|
|
def copy_misaligned_inputs(
|
|
new_inputs: List[torch.Tensor], check_inputs_idxs: Sequence[int]
|
|
) -> None:
|
|
for i in check_inputs_idxs:
|
|
if new_inputs[i].data_ptr() % ALIGNMENT:
|
|
new_inputs[i] = clone_preserve_strides(new_inputs[i])
|
|
|
|
|
|
def get_input_idxs_to_check(
|
|
inputs: Union[List[torch.Tensor], Sequence[int]],
|
|
static_input_idxs: Sequence[int],
|
|
) -> Sequence[int]:
|
|
def is_aligned(storage_offset, dtype):
|
|
return (storage_offset * get_dtype_size(dtype)) % ALIGNMENT == 0
|
|
|
|
ids_to_check = []
|
|
for i, input in enumerate(inputs):
|
|
if (
|
|
isinstance(input, torch.Tensor)
|
|
and (
|
|
i not in static_input_idxs
|
|
or not is_aligned(input.storage_offset(), input.dtype)
|
|
)
|
|
and input.device.type == "cuda"
|
|
):
|
|
ids_to_check.append(i)
|
|
return ids_to_check
|
|
|
|
|
|
def align_inputs_from_check_idxs(
|
|
model: Callable[[List[torch.Tensor]], Any], inputs_to_check: Sequence[int]
|
|
):
|
|
if len(inputs_to_check) == 0:
|
|
return model
|
|
|
|
def run(new_inputs):
|
|
copy_misaligned_inputs(new_inputs, inputs_to_check)
|
|
return model(new_inputs)
|
|
|
|
return run
|
|
|
|
|
|
def align_inputs(
|
|
model: Callable[[List[torch.Tensor]], Any],
|
|
inputs: List[torch.Tensor],
|
|
static_input_idxs: Sequence[int] = (),
|
|
):
|
|
inputs_to_check = get_input_idxs_to_check(inputs, static_input_idxs)
|
|
return align_inputs_from_check_idxs(model, inputs_to_check)
|
|
|
|
|
|
@dynamo_utils.dynamo_timed
|
|
def cudagraphify(
|
|
model: torch.fx.GraphModule,
|
|
inputs: List[torch.Tensor],
|
|
static_input_idxs: Sequence[int] = (),
|
|
*,
|
|
device_index: int,
|
|
stack_traces: List[Optional[str]],
|
|
is_backward: bool,
|
|
is_inference: bool,
|
|
):
|
|
from torch._inductor.cudagraph_trees import (
|
|
cudagraphify_impl as new_cudagraphify_impl,
|
|
)
|
|
|
|
cudagraphify_fn: Callable[..., Any]
|
|
if config.triton.cudagraph_trees:
|
|
cudagraphify_fn = functools.partial(
|
|
new_cudagraphify_impl,
|
|
device_index=device_index,
|
|
stack_traces=stack_traces,
|
|
is_backward=is_backward,
|
|
is_inference=is_inference,
|
|
)
|
|
else:
|
|
cudagraphify_fn = cudagraphify_impl
|
|
|
|
# if using fake tensors, defer cudagraphs until we get real inputs at runtime
|
|
if not any(isinstance(inp, FakeTensor) for inp in inputs):
|
|
return cudagraphify_fn(model, inputs, static_input_idxs)
|
|
|
|
compiled_fn = None
|
|
|
|
def run(new_inputs):
|
|
nonlocal compiled_fn
|
|
if compiled_fn is None:
|
|
with dynamo_utils.preserve_rng_state():
|
|
compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
|
|
return compiled_fn(new_inputs)
|
|
|
|
return run
|
|
|
|
|
|
def remove_unaligned_input_idxs(
|
|
inputs: Union[List[torch.Tensor], Sequence[int]],
|
|
static_input_idxs: Sequence[int],
|
|
):
|
|
"""
|
|
We require all inputs to be aligned, so introduce a copy for any
|
|
that aren't.
|
|
"""
|
|
aligned_static_input_idxs = []
|
|
for idx, input in zip(static_input_idxs, inputs):
|
|
if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0:
|
|
aligned_static_input_idxs.append(idx)
|
|
if len(aligned_static_input_idxs) != len(static_input_idxs):
|
|
return aligned_static_input_idxs
|
|
return static_input_idxs
|
|
|
|
|
|
def static_input(x: torch.Tensor):
|
|
"""
|
|
Copy and input while preserving strides
|
|
"""
|
|
# TODO(jansel): figure out why this version doesn't work:
|
|
# return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
|
|
needed_size = (
|
|
sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
|
|
)
|
|
buffer = torch.empty(needed_size, dtype=x.dtype, device=x.device)
|
|
return torch.as_strided(buffer, x.size(), x.stride())
|
|
|
|
|
|
def index_expanded_dims_and_copy_(
|
|
dst: torch.Tensor,
|
|
src: torch.Tensor,
|
|
expanded_dims: List[int],
|
|
):
|
|
"Index into expanded dimensions of both dst and src then copy_"
|
|
dst = index_expanded_dims(dst, expanded_dims)
|
|
src = index_expanded_dims(src, expanded_dims)
|
|
dst.copy_(src)
|
|
|
|
|
|
def cudagraphify_impl(
|
|
model: torch.fx.GraphModule,
|
|
inputs: List[torch.Tensor],
|
|
static_input_idxs: Sequence[int] = (),
|
|
):
|
|
"""
|
|
Assumes inputs[static_input_idxs[i]] are always the same memory address
|
|
"""
|
|
check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
|
|
static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
|
|
copy_misaligned_inputs(inputs, check_input_idxs)
|
|
|
|
assert isinstance(inputs, list)
|
|
|
|
inps_expanded_dims = [
|
|
get_expanded_dims(x) if idx not in static_input_idxs else []
|
|
for idx, x in enumerate(inputs)
|
|
]
|
|
|
|
    # allocate static tensor inputs: static inputs are reused (detached), while
    # non-static tensor inputs get freshly allocated buffers with matching strides
|
|
static_inputs = [
|
|
x
|
|
if not isinstance(x, torch.Tensor)
|
|
else static_input(x)
|
|
if idx not in static_input_idxs
|
|
else x.detach()
|
|
for idx, x in enumerate(inputs)
|
|
]
|
|
|
|
# copy over input values for fresh allocations
|
|
for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)):
|
|
if isinstance(x, torch.Tensor) and idx not in static_input_idxs:
|
|
index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims)
|
|
|
|
# warmup
|
|
torch.cuda.synchronize()
|
|
stream = torch.cuda.Stream()
|
|
stream.wait_stream(torch.cuda.current_stream())
|
|
    # pass a copy of static_inputs because the model clears the list it is given
|
|
with torch.cuda.stream(stream):
|
|
model(list(static_inputs))
|
|
stream.synchronize()
|
|
torch.cuda.current_stream().wait_stream(stream)
|
|
torch.cuda.synchronize()
|
|
|
|
# record
|
|
graph = torch.cuda.CUDAGraph()
|
|
with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
|
|
static_outputs = model(list(static_inputs))
|
|
if not isinstance(static_outputs, (list, tuple)):
|
|
static_outputs = (static_outputs,)
|
|
|
|
if config.size_asserts:
|
|
|
|
def run(new_inputs):
|
|
assert len(static_inputs) == len(new_inputs)
|
|
for idx, (dst, src, expanded_dims) in enumerate(
|
|
zip(static_inputs, new_inputs, inps_expanded_dims)
|
|
):
|
|
if not isinstance(dst, torch.Tensor):
|
|
pass
|
|
elif idx in static_input_idxs:
|
|
assert dst.data_ptr() == src.data_ptr()
|
|
else:
|
|
# TODO - could make one single op of multiple slices
|
|
# and avoid dispatch.
|
|
# Could also pre-index the `dst` tensors
|
|
index_expanded_dims_and_copy_(dst, src, expanded_dims)
|
|
new_inputs.clear()
|
|
graph.replay()
|
|
return static_outputs
|
|
|
|
else:
|
|
copy_indices = [
|
|
idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
|
|
]
|
|
|
|
def run(new_inputs):
|
|
for idx in copy_indices:
|
|
expanded_dims = inps_expanded_dims[idx]
|
|
index_expanded_dims_and_copy_(
|
|
static_inputs[idx], new_inputs[idx], expanded_dims
|
|
)
|
|
new_inputs.clear()
|
|
graph.replay()
|
|
return static_outputs
|
|
|
|
return align_inputs_from_check_idxs(run, check_input_idxs)
|
|
|
|
|
|
def count_tangents(fx_g: torch.fx.GraphModule):
|
|
"""
|
|
Infers which inputs are static for a backwards graph
|
|
"""
|
|
|
|
def is_saved_tensor(x):
|
|
return (
|
|
"tangents" not in x.name
|
|
and "bwd_seed" not in x.name
|
|
and "bwd_base_offset" not in x.name
|
|
)
|
|
|
|
arg_count = 0
|
|
static_arg_idxs = []
|
|
for n in fx_g.graph.nodes:
|
|
if n.op == "placeholder":
|
|
if is_saved_tensor(n):
|
|
static_arg_idxs.append(arg_count)
|
|
arg_count += 1
|
|
|
|
assert static_arg_idxs == list(range(len(static_arg_idxs)))
|
|
return len(static_arg_idxs)
|
|
|
|
|
|
_in_aot_compilation = BoxedBool(False)
|
|
|
|
|
|
def compile_fx_aot(
|
|
model_: torch.fx.GraphModule,
|
|
example_inputs_: List[torch.Tensor],
|
|
inner_compile: Callable[..., Any] = compile_fx_inner,
|
|
config_patches: Optional[Dict[str, Any]] = None,
|
|
):
|
|
config_patches = (
|
|
{"cpp_wrapper": True}
|
|
if config_patches is None
|
|
else {**config_patches, "cpp_wrapper": True}
|
|
)
|
|
if (
|
|
"aot_inductor_output_path" not in config_patches
|
|
and not config.aot_inductor_output_path
|
|
):
|
|
config_patches = {
|
|
**config_patches,
|
|
"aot_inductor_output_path": code_hash(model_.code),
|
|
}
|
|
|
|
extern_node_serializer = config_patches.pop("extern_node_serializer", None)
|
|
with mock.patch.object(_in_aot_compilation, "value", True):
|
|
return compile_fx(
|
|
model_,
|
|
example_inputs_,
|
|
inner_compile=functools.partial(
|
|
inner_compile,
|
|
aot_mode=True,
|
|
extern_node_serializer=extern_node_serializer,
|
|
),
|
|
config_patches=config_patches,
|
|
)
|
|
|
|
|
|
_graph_counter = itertools.count(0)
|
|
|
|
|
|
def fw_compiler_freezing(
|
|
aot_autograd_model: torch.fx.GraphModule,
|
|
aot_example_inputs: List[torch.Tensor],
|
|
dynamo_model: torch.fx.GraphModule,
|
|
num_example_inputs: int,
|
|
inner_compile: Callable[..., Any],
|
|
cudagraphs: BoxedBool,
|
|
graph_id: int,
|
|
forward_device: BoxedDeviceIndex,
|
|
):
|
|
from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze
|
|
|
|
# partition_fn won't be called
|
|
joint_graph_passes(aot_autograd_model)
|
|
|
|
layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model)
|
|
if layout_opt:
|
|
# make sure meta['val'] is properly setup
|
|
fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
|
|
convert_conv_weights_to_channels_last(aot_autograd_model)
|
|
|
|
opt_model, preserved_arg_indices = freeze(
|
|
dynamo_model,
|
|
aot_autograd_model,
|
|
aot_example_inputs, # type: ignore[arg-type]
|
|
)
|
|
|
|
aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
|
|
num_fixed = len(preserved_arg_indices) - num_example_inputs
|
|
|
|
fake_mode = detect_fake_mode(aot_example_inputs)
|
|
|
|
# for freezing, all graph outputs should be user visible
|
|
*_, model_outputs_node = opt_model.graph.nodes
|
|
model_outputs = model_outputs_node.args[0]
|
|
user_visible_outputs = [
|
|
n.name for n in model_outputs if isinstance(n, torch.fx.Node)
|
|
]
|
|
|
|
# constant params will be real tensors, not fake
|
|
tracing_context = torch._guards.TracingContext.get()
|
|
assert tracing_context is not None
|
|
params_flat = tracing_context.params_flat
|
|
assert params_flat is not None
|
|
for i in range(len(params_flat)):
|
|
if i not in preserved_arg_indices:
|
|
params_flat[i] = None
|
|
|
|
with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
|
|
optimized_function = inner_compile(
|
|
opt_model,
|
|
aot_example_inputs,
|
|
num_fixed=num_fixed,
|
|
cudagraphs=cudagraphs,
|
|
graph_id=graph_id,
|
|
is_inference=True,
|
|
boxed_forward_device_index=forward_device,
|
|
layout_opt=layout_opt,
|
|
user_visible_outputs=user_visible_outputs,
|
|
)
|
|
|
|
# aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper
|
|
# that drops constant-ified params
|
|
if _in_aot_compilation:
|
|
return optimized_function
|
|
|
|
def wrapper(args):
|
|
args_new = [args[i] for i in preserved_arg_indices]
|
|
args.clear()
|
|
return optimized_function(args_new)
|
|
|
|
wrapper._boxed_call = True # type: ignore[attr-defined]
|
|
|
|
return wrapper
|
|
|
|
|
|
def compile_fx(
|
|
model_: torch.fx.GraphModule,
|
|
example_inputs_: List[torch.Tensor],
|
|
inner_compile: Callable[..., Any] = compile_fx_inner,
|
|
config_patches: Optional[Dict[str, Any]] = None,
|
|
decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
|
|
):
|
|
"""Main entrypoint to a compile given FX graph"""
|
|
if config_patches:
|
|
with config.patch(config_patches): # type: ignore[attr-defined]
|
|
return compile_fx(
|
|
model_,
|
|
example_inputs_,
|
|
# need extra layer of patching as backwards is compiled out of scope
|
|
inner_compile=config.patch(config_patches)(inner_compile), # type: ignore[attr-defined]
|
|
decompositions=decompositions,
|
|
)
|
|
|
|
if config.cpp_wrapper:
|
|
with config.patch( # type: ignore[attr-defined]
|
|
{
|
|
"cpp_wrapper": False,
|
|
"triton.autotune_cublasLt": False,
|
|
"triton.cudagraphs": False,
|
|
# CudaWrapperCodeGen relies on kernel name to find the autotuned cubin file
|
|
"triton.unique_kernel_names": True,
|
|
}
|
|
), V.set_real_inputs(
|
|
example_inputs_
|
|
): # type: ignore[call-arg]
|
|
return compile_fx(
|
|
model_,
|
|
example_inputs_,
|
|
inner_compile=inner_compile_with_cpp_wrapper(inner_compile),
|
|
decompositions=decompositions,
|
|
)
|
|
|
|
recursive_compile_fx = functools.partial(
|
|
compile_fx,
|
|
inner_compile=inner_compile,
|
|
decompositions=decompositions,
|
|
)
|
|
|
|
if not graph_returns_tuple(model_):
|
|
return make_graph_return_tuple(
|
|
model_,
|
|
example_inputs_,
|
|
recursive_compile_fx,
|
|
)
|
|
|
|
if isinstance(model_, torch.fx.GraphModule):
|
|
if isinstance(model_.graph._codegen, _PyTreeCodeGen):
|
|
# this graph is the result of dynamo.export()
|
|
return handle_dynamo_export_graph(
|
|
model_,
|
|
example_inputs_,
|
|
recursive_compile_fx,
|
|
)
|
|
|
|
        # Since handle_dynamo_export_graph will trigger compile_fx again,
        # move these passes after handle_dynamo_export_graph to avoid repeated calls.
|
|
model_ = pre_grad_passes(model_, example_inputs_)
|
|
|
|
if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
|
|
return flatten_graph_inputs(
|
|
model_,
|
|
example_inputs_,
|
|
recursive_compile_fx,
|
|
)
|
|
|
|
assert not config._raise_error_for_testing
|
|
num_example_inputs = len(example_inputs_)
|
|
cudagraphs = BoxedBool(config.triton.cudagraphs)
|
|
forward_device = BoxedDeviceIndex(None)
|
|
|
|
graph_id = next(_graph_counter)
|
|
|
|
decompositions = (
|
|
decompositions if decompositions is not None else select_decomp_table()
|
|
)
|
|
|
|
@dynamo_utils.dynamo_timed
|
|
def fw_compiler_base(
|
|
model: torch.fx.GraphModule,
|
|
example_inputs: List[torch.Tensor],
|
|
is_inference: bool,
|
|
):
|
|
if is_inference:
|
|
# partition_fn won't be called
|
|
joint_graph_passes(model)
|
|
|
|
num_rng_seed_offset_inputs = 2 if functorch_config.functionalize_rng_ops else 0
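        # the extra inputs prepended by AOTAutograd (flattened parameters/buffers)
        # are treated as fixed (static) inputs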
fixed = len(example_inputs) - num_example_inputs - num_rng_seed_offset_inputs
|
|
user_visible_outputs = set()
|
|
|
|
if config.keep_output_stride:
|
|
*_, model_outputs_node = model.graph.nodes
|
|
assert model_outputs_node.op == "output"
|
|
model_outputs, _ = pytree.tree_flatten(model_outputs_node.args)
|
|
num_model_outputs = len(model_outputs)
|
|
|
|
context = torch._guards.TracingContext.get()
|
|
if context is not None and context.fw_metadata:
|
|
original_output_start_index = context.fw_metadata.num_mutated_inputs
|
|
else:
|
|
original_output_start_index = 0
|
|
|
|
if isinstance(model_, torch.fx.GraphModule):
|
|
*_, orig_model_outputs_node = model_.graph.nodes
|
|
assert orig_model_outputs_node.op == "output"
|
|
orig_model_outputs, _ = pytree.tree_flatten(
|
|
orig_model_outputs_node.args
|
|
)
|
|
num_orig_model_outputs = len(orig_model_outputs)
|
|
else:
|
|
num_orig_model_outputs = num_model_outputs
|
|
|
|
assert num_orig_model_outputs <= num_model_outputs
|
|
|
|
            # We make the following assumption:
            # For inference
            #   len(orig_model_outputs) == len(model_outputs)
            # For training
            #   len(orig_model_outputs) <= len(model_outputs)
            # During training, model_outputs usually starts with the original
            # module's outputs followed by saved activations.
            # But this may not hold if the model has in-place updated tensors:
            # AOTAutograd will return those tensors before the original
            # module's outputs.
            # To be safe, we use the original_output_start_index field
            # set by AOTAutograd to decide where the original module outputs start.
|
|
|
|
user_visible_outputs = {
|
|
n.name
|
|
for n in model_outputs[
|
|
original_output_start_index : original_output_start_index
|
|
+ num_orig_model_outputs
|
|
]
|
|
if isinstance(n, torch.fx.Node)
|
|
}
|
|
|
|
return inner_compile(
|
|
model,
|
|
example_inputs,
|
|
num_fixed=fixed,
|
|
cudagraphs=cudagraphs,
|
|
graph_id=graph_id,
|
|
is_inference=is_inference,
|
|
boxed_forward_device_index=forward_device,
|
|
user_visible_outputs=user_visible_outputs,
|
|
)
|
|
|
|
fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
|
|
|
|
if config.freezing and not torch.is_grad_enabled():
|
|
inference_compiler = functools.partial(
|
|
fw_compiler_freezing,
|
|
dynamo_model=model_,
|
|
num_example_inputs=num_example_inputs,
|
|
inner_compile=inner_compile,
|
|
cudagraphs=cudagraphs,
|
|
graph_id=graph_id,
|
|
forward_device=forward_device,
|
|
)
|
|
else:
|
|
inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
|
|
|
|
def partition_fn(graph, joint_inputs, **kwargs):
|
|
joint_graph_passes(graph)
|
|
return min_cut_rematerialization_partition(
|
|
graph, joint_inputs, **kwargs, compiler="inductor"
|
|
)
|
|
|
|
@dynamo_utils.dynamo_timed
|
|
def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
|
fixed = count_tangents(model)
|
|
return inner_compile(
|
|
model,
|
|
example_inputs,
|
|
num_fixed=fixed,
|
|
cudagraphs=cudagraphs,
|
|
is_backward=True,
|
|
graph_id=graph_id,
|
|
boxed_forward_device_index=forward_device,
|
|
)
|
|
|
|
# TODO: can add logging before/after the call to create_aot_dispatcher_function
|
|
# in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
|
|
# once torchdynamo is merged into pytorch
|
|
fake_mode = detect_fake_mode(example_inputs_) or torch._subclasses.FakeTensorMode(
|
|
allow_non_fake_inputs=True
|
|
)
|
|
tracing_context = (
|
|
torch._guards.TracingContext.get() or torch._guards.TracingContext(fake_mode)
|
|
)
|
|
|
|
with V.set_fake_mode(fake_mode), torch._guards.tracing( # type: ignore[call-arg]
|
|
tracing_context
|
|
), compiled_autograd.disable():
|
|
return aot_autograd(
|
|
fw_compiler=fw_compiler,
|
|
bw_compiler=bw_compiler,
|
|
inference_compiler=inference_compiler,
|
|
decompositions=decompositions,
|
|
partition_fn=partition_fn,
|
|
keep_inference_input_mutations=True,
|
|
)(model_, example_inputs_)
|
|
|
|
|
|
# pass config dict back to user
|
|
def get_patched_config_dict(config_patches=None):
|
|
with config.patch(config_patches): # type: ignore[attr-defined]
|
|
return config.get_config_copy() # type: ignore[attr-defined]
|
|
|
|
|
|
def _shape_env_from_inputs(inputs: List[torch.Tensor]):
|
|
shape_env = None
|
|
fake_mode = detect_fake_mode(inputs)
|
|
|
|
# TODO(voz): It would be nice to enable this assert, but there are lots of tests that
|
|
# pass in real inputs for now.
|
|
# if len(inputs) > 0:
|
|
# assert fake_mode is not None, breakpoint()
|
|
|
|
if fake_mode is not None:
|
|
return fake_mode.shape_env
|
|
|
|
# When there are no tensor inputs, get shape_env from the first SymInt.
|
|
for input in inputs:
|
|
if isinstance(input, torch.SymInt):
|
|
return input.node.shape_env
|
|
|
|
# TODO(voz): Should we always have one anyway?
|
|
return None
|
|
|
|
|
|
def output_node(gm: torch.fx.GraphModule):
|
|
"""Get the output node from an FX graph"""
|
|
last_node = next(iter(reversed(gm.graph.nodes)))
|
|
assert last_node.op == "output"
|
|
return last_node
|
|
|
|
|
|
def graph_returns_tuple(gm: torch.fx.GraphModule):
|
|
"""True if a FX graph returns a tuple"""
|
|
if not isinstance(gm, torch.fx.GraphModule):
|
|
return True # can't check this, assume true
|
|
(rv,) = output_node(gm).args
|
|
if isinstance(rv, (list, tuple)):
|
|
return True
|
|
if (
|
|
isinstance(rv, torch.fx.node.Node)
|
|
and hasattr(rv.target, "_schema")
|
|
and len(rv.target._schema.returns) > 1
|
|
and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns)
|
|
):
|
|
# for graphs whose result is one node with multiple outputs
|
|
return True
|
|
return False
|
|
|
|
|
|
def make_graph_return_tuple(
|
|
gm: torch.fx.GraphModule,
|
|
inputs: List[torch.Tensor],
|
|
compile_gm: Callable[..., Any],
|
|
):
|
|
"""
|
|
Mutate gm so it returns a tuple. This is only needed for graphs
|
|
not created by torchdynamo that return non-tuples.
|
|
"""
|
|
node = output_node(gm)
|
|
(rv,) = node.args
|
|
rv, spec = pytree.tree_flatten(rv)
|
|
with gm.graph.inserting_before(node):
|
|
gm.graph.output(rv)
|
|
gm.graph.erase_node(node)
|
|
assert graph_returns_tuple(gm)
|
|
|
|
compiled_fn = compile_gm(gm, inputs)
|
|
|
|
@functools.wraps(compiled_fn)
|
|
def wrapper(*args, **kwargs):
|
|
return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)
|
|
|
|
return wrapper
|
|
|
|
|
|
def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
|
|
"""
|
|
Mutate inputs so that they are flat and wrap gm such that it
|
|
accepts those inputs. This is only needed for graphs not created
|
|
by torchdynamo that take bumpy inputs.
|
|
"""
|
|
inputs, spec = pytree.tree_flatten(inputs)
|
|
|
|
class GmWrapper(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.gm = gm
|
|
|
|
def forward(self, *args):
|
|
args: List[Any] = list(args)
|
|
return self.gm(*pytree.tree_unflatten(args, spec))
|
|
|
|
compiled_fn = compile_gm(GmWrapper(), inputs)
|
|
|
|
@functools.wraps(compiled_fn)
|
|
def wrapper(*args):
|
|
# note this doesn't check the spec, assuming it is the same
|
|
return compiled_fn(*pytree.tree_flatten(args)[0])
|
|
|
|
return wrapper
|
|
|
|
|
|
def handle_dynamo_export_graph(
|
|
gm: torch.fx.GraphModule,
|
|
inputs: List[torch.Tensor],
|
|
compile_gm: Callable[..., Any],
|
|
):
|
|
"""
|
|
`torch._dynamo.export` embeds pytrees in the FX graph codegen object,
|
|
convert that to a normal FX graph so inductor can compile it.
|
|
"""
|
|
codegen = gm.graph._codegen
|
|
gm.graph._codegen = torch.fx.graph.CodeGen()
|
|
gm.recompile()
|
|
|
|
compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs))
|
|
|
|
@functools.wraps(compiled_fn)
|
|
def wrapper(*args):
|
|
return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args)))
|
|
|
|
return wrapper
|