# mypy: allow-untyped-defs
import contextlib
import functools
from collections.abc import Iterable, Sequence
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import Any, Callable, Optional, overload, TypeVar, Union

import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._dispatch.python import suspend_functionalization
from torch._guards import detect_fake_mode
from torch._higher_order_ops.schema import HopSchema
from torch._ops import HigherOrderOperator, OperatorBase, OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import (
    disable_functional_mode,
    FunctionalTensor,
)
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
    disable_proxy_modes_tracing,
    make_fx,
)
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.multiprocessing.reductions import StorageWeakRef


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad, this will either
    raise an error or return a DelayedError, depending on the value of delayed_error.

    Args:
        operator: The Operator to call with the *args and **kwargs
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If autograd is enabled, any of the arguments to the Operator
            require grad, and delayed_error is False
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result


def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    def inner(*args, **kwargs):
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner
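
# Illustrative usage sketch (``my_hop`` below is a hypothetical HigherOrderOperator
# instance, not part of this module): a HOP without autograd support typically registers
# this helper as its Autograd-key implementation, so that calling it on inputs that
# require grad either raises immediately or, with deferred_error=True, returns outputs
# carrying a DelayedError that only fires when they are backpropagated through:
#
#     my_hop.py_impl(torch._C.DispatchKey.Autograd)(
#         autograd_not_implemented(my_hop, deferred_error=True)
#     )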
|
|
|
|
|
|
def _maybe_run_with_interpreter(fn):
|
|
maybe_interpreted_fn = fn
|
|
if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
|
|
# Running graph with interpreter is needed for propagating the stack_trace
|
|
def graph_with_interpreter(*args):
|
|
with fx_traceback.preserve_node_meta():
|
|
return torch.fx.Interpreter(fn).run(*args)
|
|
|
|
maybe_interpreted_fn = graph_with_interpreter
|
|
return maybe_interpreted_fn
|
|
|
|
|
|
def _maybe_compile_and_run_fn(fn, *args):
|
|
if not torch.compiler.is_dynamo_compiling():
|
|
from torch._dynamo.backends.debugging import (
|
|
make_eager_backend_with_torch_function_mode,
|
|
)
|
|
|
|
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
|
|
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
|
|
if metadata_mode:
|
|
backend: Union[str, Callable[..., Any]] = (
|
|
make_eager_backend_with_torch_function_mode(metadata_mode)
|
|
)
|
|
else:
|
|
backend = "eager"
|
|
return torch.compile(fn, backend=backend, fullgraph=True)(*args)
|
|
else:
|
|
return fn(*args)
|
|
|
|
|
|
def reenter_make_fx(fn):
|
|
from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER
|
|
|
|
@functools.wraps(fn)
|
|
def wrapped(*args):
|
|
assert _CURRENT_MAKE_FX_TRACER is not None, (
|
|
"Cannot reenter make_fx when we're not under a make_fx tracing session"
|
|
)
|
|
gm = _CURRENT_MAKE_FX_TRACER.trace_subgraph(
|
|
_maybe_run_with_interpreter(fn), *args
|
|
)
|
|
return gm
|
|
|
|
return wrapped
|
|
|
|
|
|
def _maybe_reenter_make_fx(fn):
|
|
from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER
|
|
|
|
if _CURRENT_MAKE_FX_TRACER is not None:
|
|
return reenter_make_fx(fn)
|
|
else:
|
|
|
|
def _maybe_make_fx_with_fake_mode(fn):
|
|
@functools.wraps(fn)
|
|
def wrapped(*args):
|
|
from torch._guards import detect_fake_mode
|
|
|
|
fake_mode = detect_fake_mode(args)
|
|
if fake_mode is None:
|
|
                    # we create a fake_mode here to make sure we can
                    # trace the graph with data-dependent calls, e.g. .item()
                    return make_fx(fn, tracing_mode="fake")(*args)
                # Tracing with real tensors if all inputs have already been fakified
|
|
return make_fx(fn)(*args)
|
|
|
|
return wrapped
|
|
|
|
return _maybe_make_fx_with_fake_mode(fn)
|
|
|
|
|
|
def check_meta_consistency(
|
|
lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
|
|
rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
|
|
lhs_name: str,
|
|
rhs_name: str,
|
|
include_contiguity: bool = True,
|
|
) -> None:
|
|
def diff_meta_pairs(
|
|
lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
|
|
rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
|
|
) -> list[str]:
|
|
def diff_meta(
|
|
lhs: Union[torch.Tensor, torch.SymInt, int],
|
|
rhs: Union[torch.Tensor, torch.SymInt, int],
|
|
) -> str:
|
|
if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
|
|
return ", ".join(
|
|
diff_tensor_meta(
|
|
_extract_tensor_metadata(
|
|
lhs, include_contiguity=include_contiguity
|
|
),
|
|
_extract_tensor_metadata(
|
|
rhs, include_contiguity=include_contiguity
|
|
),
|
|
check_grad=False,
|
|
)
|
|
)
|
|
else:
|
|
|
|
def _both_int_types(lhs, rhs):
|
|
return isinstance(lhs, (int, torch.SymInt)) and isinstance(
|
|
rhs, (int, torch.SymInt)
|
|
)
|
|
|
|
def _both_tensor(lhs, rhs):
|
|
return isinstance(lhs, torch.Tensor) and isinstance(
|
|
rhs, torch.Tensor
|
|
)
|
|
|
|
if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs):
|
|
return f"type: {lhs} vs {rhs}"
|
|
|
|
return ""
|
|
|
|
# Manually check the device of lhs and rhs as this field is currently not part of TensorMetadata
|
|
def diff_device(
|
|
lhs: Union[torch.Tensor, torch.SymInt, int],
|
|
rhs: Union[torch.Tensor, torch.SymInt, int],
|
|
) -> str:
|
|
if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
|
|
if (
|
|
rhs.device.type == lhs.device.type
|
|
and rhs.device.index == lhs.device.index
|
|
):
|
|
return ""
|
|
else:
|
|
return "device"
|
|
return ""
|
|
|
|
if len(lhs_list) != len(rhs_list):
|
|
raise torch._dynamo.exc.UncapturedHigherOrderOpError(
|
|
f"Expected {lhs_name} and {rhs_name} to have same number of outputs but got lhs:{lhs_list} and rhs:{rhs_list}"
|
|
)
|
|
all_diffs = []
|
|
for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)):
|
|
if diff := diff_meta(lhs, rhs):
|
|
all_diffs.append(
|
|
f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
|
|
)
|
|
if diff := diff_device(lhs, rhs):
|
|
all_diffs.append(
|
|
f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
|
|
)
|
|
return all_diffs
|
|
|
|
if all_diffs := diff_meta_pairs(lhs_list, rhs_list):
|
|
diff_str = "\n".join(all_diffs)
|
|
raise torch._dynamo.exc.UncapturedHigherOrderOpError(
|
|
f"Expected {lhs_name} and {rhs_name} to have same metadata but found:\n{diff_str}"
|
|
)
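
# Illustrative example (hypothetical tensors): for two tensors that differ only in dtype,
# e.g. lhs = [torch.empty(2, 3, dtype=torch.float32)] and
# rhs = [torch.empty(2, 3, dtype=torch.float16)],
# check_meta_consistency(lhs, rhs, "carried_inputs", "body_output") raises an
# UncapturedHigherOrderOpError whose message points at the dtype mismatch of pair[0],
# while pairs matching in shape/dtype/stride (and device) pass silently.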
|
|
|
|
|
|
@contextmanager
|
|
def _set_compilation_env():
|
|
_old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
|
|
_old_allow_empty_graphs = torch._dynamo.config.allow_empty_graphs
|
|
_old_capture_scalar_outputs = torch._dynamo.config.capture_scalar_outputs
|
|
    # The issue is tracked in https://github.com/pytorch/pytorch/issues/144360: when dynamo finds
    # that the top-level frame produces no graph, the default behavior is to fall back to eager.
    # Then, when it encounters an inner function, it will try to trace that function again, which is unnecessary.
    # For while_loop, while inspecting the inner call, we trace into the python dispatcher
    # logic, which is not traceable as of today. So the proper fix can be either 1. allowing dispatch
    # logic to be dynamo traceable or 2. fixing https://github.com/pytorch/pytorch/issues/144360,
    # but that fix exposes some bugs in existing tests, so we keep a temporary flag to control
    # the behavior, which allows dynamo to store an empty graph for a frame without falling back to eager.
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
|
|
torch._dynamo.config.allow_empty_graphs = True
|
|
torch._dynamo.config.capture_scalar_outputs = True
|
|
yield
|
|
finally:
|
|
torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
|
|
torch._dynamo.config.allow_empty_graphs = _old_allow_empty_graphs
|
|
torch._dynamo.config.capture_scalar_outputs = _old_capture_scalar_outputs
|
|
|
|
|
|
# The invariant here is that we always trace the branch with fake tensor
|
|
def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
|
|
fake_mode_det = detect_fake_mode(inputs)
|
|
fake_mode: AbstractContextManager = nullcontext()
|
|
tracing_mode = "fake"
|
|
if fake_mode_det is not None:
|
|
fake_mode = fake_mode_det
|
|
tracing_mode = "real"
|
|
|
|
# Note: we need to turn off proxy tensor mode to avoid tracing infra
|
|
# code that happens in make_fx e.g. we now call as_strided when wrapping tensor
|
|
# as fake tensor.
|
|
with fake_mode, disable_proxy_modes_tracing():
|
|
gm = make_fx(
|
|
fn,
|
|
tracing_mode=tracing_mode,
|
|
pre_dispatch=pre_dispatch,
|
|
_error_on_data_dependent_ops=False,
|
|
)(*inputs)
|
|
if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None: # type: ignore[attr-defined]
|
|
insert_deferred_runtime_asserts(
|
|
gm,
|
|
fake_mode.shape_env, # type: ignore[attr-defined]
|
|
"hoo_maybe_fake_tracing",
|
|
export=True, # type: ignore[attr-defined]
|
|
)
|
|
return gm
|
|
|
|
|
|
def potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
|
|
try:
|
|
gm = _maybe_fake_tracing(gm, inputs, pre_dispatch)
|
|
except UnsupportedAliasMutationException:
|
|
# this can happen when nested cond_op is
|
|
# functionalized
|
|
return True
|
|
except Exception as e:
|
|
raise e
|
|
|
|
example_inputs = [
|
|
ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
|
|
]
|
|
(
|
|
inp_inp_alias_map,
|
|
inp_out_alias_map,
|
|
out_out_alias_map,
|
|
inp_mutation,
|
|
) = check_input_alias_and_mutation(gm, example_inputs)
|
|
return (inp_inp_alias_map, inp_out_alias_map, out_out_alias_map), inp_mutation
|
|
|
|
|
|
def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations):
|
|
if any(len(a) > 0 for a in aliases):
|
|
# TODO: Investigate here further which node is exactly aliasing
|
|
raise RuntimeError(
|
|
f"{name} where aliases appear. "
|
|
+ f"In particular, these inputs \
|
|
{set(el for el_map in aliases if len(el_map.keys()) > 0 for el in el_map.keys())} " # noqa: C401
|
|
+ "get aliased. Please ensure that this doesn't happen."
|
|
)
|
|
if len(input_mutations):
|
|
# TODO: Investigate here further which node is exactly mutating the inputs
|
|
raise RuntimeError(
|
|
f"{name} where the inputs are mutated. "
|
|
+ f"In particular, these nodes are mutating the inputs \
|
|
{set(el for el in input_mutations)}." # noqa: C401
|
|
+ "Please ensure that this doesn't happen."
|
|
)
|
|
|
|
|
|
def _has_potential_branch_input_mutation(gm, inputs, pre_dispatch=False):
|
|
(
|
|
(_, _, _),
|
|
inp_mutation,
|
|
) = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
|
|
|
|
return len(inp_mutation) > 0
|
|
|
|
|
|
def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
|
|
(
|
|
(
|
|
inp_inp_alias_map,
|
|
inp_out_alias_map,
|
|
out_out_alias_map,
|
|
),
|
|
inp_mutation,
|
|
) = potential_input_alias_or_mutation(gm, inputs, pre_dispatch)
|
|
return (
|
|
any(
|
|
(
|
|
len(inp_inp_alias_map) > 0,
|
|
len(inp_out_alias_map) > 0,
|
|
len(out_out_alias_map) > 0,
|
|
)
|
|
),
|
|
len(inp_mutation) > 0,
|
|
)
|
|
|
|
|
|
def _collect_fake_inputs(inputs):
|
|
from torch._subclasses.fake_tensor import FakeTensor
|
|
|
|
# Get the example values of the inputs.
|
|
inputs_fake: list[Union[FakeTensor, torch.Tensor, int]] = []
|
|
for inp in inputs:
|
|
if isinstance(inp, (torch.fx.proxy.Proxy, torch.fx.node.Node)):
|
|
inp = inp.node if isinstance(inp, torch.fx.proxy.Proxy) else inp
|
|
if hasattr(inp, "meta"):
|
|
val = inp.meta["example_value"]
|
|
if isinstance(val, torch.Tensor):
|
|
if torch._C._functorch.is_batchedtensor(
|
|
val
|
|
) or torch._C._functorch.is_functionaltensor(val):
|
|
# This case is for batched or functional tensors
|
|
# Unwrap the tensors
|
|
while torch._C._functorch.is_batchedtensor(
|
|
val
|
|
) or torch._C._functorch.is_functionaltensor(val):
|
|
val = torch._C._functorch.get_unwrapped(val)
|
|
assert isinstance(val, FakeTensor)
|
|
inputs_fake.append(val)
|
|
else:
|
|
# This is the standard case of a TensorVariable
|
|
assert isinstance(val, FakeTensor)
|
|
inputs_fake.append(val)
|
|
else:
|
|
# This case is for SymInts and other non-Tensor elements
|
|
assert not isinstance(val, torch.Tensor)
|
|
inputs_fake.append(val)
|
|
else:
|
|
# This case is for ints
|
|
assert isinstance(inp, int)
|
|
inputs_fake.append(inp)
|
|
|
|
return inputs_fake
|
|
|
|
|
|
def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch):
|
|
aliases, inp_mutation = has_potential_input_alias_or_mutation(
|
|
graph_module, inputs_fake, pre_dispatch=pre_dispatch
|
|
)
|
|
if aliases:
|
|
raise RuntimeError(f"{name} might be aliasing the input or the output!") # noqa: F541
|
|
if inp_mutation:
|
|
raise RuntimeError(f"{name} might be modifying the input!") # noqa: F541
|
|
|
|
|
|
def unique_graph_id(proxy_mode, prefix):
|
|
"""Returns a unique name and id for a graph to be added to a proxy_mode tracer"""
|
|
# There are probably better ways - I know that create_arg has some self incrementing name
|
|
# magic to it, but since we explicitly have to get the name for register_module,
|
|
# I was not sure how to do that. This kinda simulates it.
|
|
return unique_graph_name_with_root(proxy_mode.tracer.root, prefix)
|
|
|
|
|
|
def unique_graph_name_with_root(
|
|
root: torch.fx.GraphModule, prefix: str
|
|
) -> tuple[int, str]:
|
|
next_name = None
|
|
i = 0
|
|
while not next_name:
|
|
candidate = f"{prefix}_{i}"
|
|
if hasattr(root, candidate):
|
|
i += 1
|
|
else:
|
|
next_name = candidate
|
|
return i, next_name
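
# Example: if the tracer root already has attributes named "cond_graph_0" and
# "cond_graph_1", unique_graph_name_with_root(root, "cond_graph") returns (2, "cond_graph_2").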
|
|
|
|
|
|
def _from_fun(t):
|
|
from torch._functorch.aot_autograd import from_fun
|
|
|
|
if isinstance(t, torch.Tensor):
|
|
if t.dtype != torch.bool:
|
|
return torch.empty_strided(
|
|
t.size(),
|
|
t.stride(),
|
|
dtype=t.dtype,
|
|
requires_grad=t.requires_grad,
|
|
device=t.device,
|
|
)
|
|
else:
|
|
# clone of a functional tensor produces a functional tensor
|
|
# but we want to avoid it so we clone a non-functional version
|
|
maybe_unfunc_t = t
|
|
if isinstance(t, FunctionalTensor):
|
|
torch._sync(t)
|
|
maybe_unfunc_t = from_fun(t)
|
|
elif torch._is_functional_tensor(t):
|
|
# need to handle both types of functionalization here:
|
|
# these are the tensors that came from the user,
|
|
# which could be either FunctionalTensorWrapper or FunctionalTensor
|
|
torch._sync(t)
|
|
maybe_unfunc_t = torch._from_functional_tensor(t)
|
|
return maybe_unfunc_t.clone()
|
|
return t
|
|
|
|
|
|
def clone_outputs_aliasing_inputs(args):
|
|
input_storage = {
|
|
StorageWeakRef(arg._typed_storage())
|
|
for arg in args
|
|
if isinstance(arg, torch.Tensor)
|
|
}
|
|
|
|
def maybe_clone(t):
|
|
if (
|
|
isinstance(t, torch.Tensor)
|
|
and StorageWeakRef(t._typed_storage()) in input_storage
|
|
):
|
|
return t.clone()
|
|
return t
|
|
|
|
return maybe_clone
|
|
|
|
|
|
def prepare_fw_with_masks(fn):
|
|
def fw_with_masks(*args):
|
|
fw_out = fn(*args)
|
|
return fw_out, [
|
|
True if isinstance(ret, torch.Tensor) and ret.requires_grad else False
|
|
for ret in fw_out
|
|
]
|
|
|
|
return fw_with_masks
|
|
|
|
|
|
def prepare_fw_with_masks_all_requires_grad(fn):
|
|
def fw_with_masks(*args):
|
|
fw_out = fn(*args)
|
|
        # Note [force all outputs to be require grad]
        # Instead of using the original fn, we make all outputs of the original
        # fn require grad. This is consistent with the behavior of
        # autograd.Function, where if any one of the inputs requires grad,
        # all outputs will require grad. This also makes the downstream
        # requires_grad-ness reasoning much easier.
|
|
if pytree.tree_any_only(torch.Tensor, lambda t: t.requires_grad, args):
|
|
fw_out = pytree.tree_map_only(
|
|
torch.Tensor,
|
|
lambda x: x.requires_grad_(True) if x.dtype.is_floating_point else x,
|
|
fw_out,
|
|
)
|
|
return fw_out, pytree.tree_map_only(
|
|
torch.Tensor, lambda x: x.requires_grad, fw_out
|
|
)
|
|
|
|
return fw_with_masks
|
|
|
|
|
|
# This function replaces None gradients with all-zero gradients.
|
|
# `None` gradients are problematic for CUDA graphs. Those gradients are
|
|
# replaced with an all-zero tensor for better optimization
|
|
def unmask_none_gradients(grads, operands):
|
|
allowed_types = (torch.Tensor, int, torch.SymInt)
|
|
assert all(isinstance(o, allowed_types) for o in operands), (
|
|
f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}"
|
|
)
|
|
|
|
unmasked_grads = []
|
|
for g, o in zip(grads, operands):
|
|
if g is not None:
|
|
unmasked_grads.append(g)
|
|
else:
|
|
# In case the operand is an int or a torch.SymInt, return None
|
|
# This can happen for lifted_arguments. E.g., the shapes of a dynamic tensor are lifted and passed
|
|
# as additional arguments
|
|
unmasked_grads.append(
|
|
torch.zeros_like(o) if isinstance(o, torch.Tensor) else None
|
|
)
|
|
|
|
return unmasked_grads
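
# Example: unmask_none_gradients([g0, None, None], [t0, t1, 5]) returns
# [g0, torch.zeros_like(t1), None] -- tensor operands with a None gradient get an
# all-zero gradient, while int/SymInt operands keep None.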
|
|
|
|
|
|
def _maybe_fake_prop_ignore_unbacked(fn, args):
|
|
with ExitStack() as ctx_stack:
|
|
if (fake_mode := detect_fake_mode(args)) is not None:
|
|
ctx_stack.enter_context(fake_mode)
|
|
if fake_mode.shape_env is not None:
|
|
ctx_stack.enter_context(
|
|
fake_mode.shape_env.ignore_fresh_unbacked_symbols()
|
|
)
|
|
return fn(*args)
|
|
|
|
|
|
def redirect_to_mode(hop: OperatorBase, mode):
|
|
"""Utility for redispatching HOP to underlying mode
|
|
|
|
Args:
|
|
hop: The HOP to redispatch
|
|
mode: The mode to redispatch to
|
|
|
|
Returns:
|
|
A decorated function that implements the HOP for the given mode
|
|
"""
|
|
|
|
@hop.py_impl(mode)
|
|
def impl(mode, *args, **kwargs):
|
|
return mode.__torch_dispatch__(hop, [], args, kwargs)
|
|
|
|
return impl
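
# Example (hypothetical ``my_hop``): redirect_to_mode(my_hop, FakeTensorMode) registers a
# py_impl for my_hop under FakeTensorMode that simply forwards the call to
# mode.__torch_dispatch__; register_fake below uses exactly this to route HOPs through
# the fake tensor dispatcher.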
|
|
|
|
|
|
# TODO: The parameter use_output_and_grad_bw is required because some operations
|
|
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
|
|
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
|
|
from torch._functorch.aot_autograd import AOTConfig, create_joint
|
|
|
|
# Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
|
|
# between Autograd and Python key. Currently, we only suspend functionalization but more can be
|
|
# added when required. Will encounter two problems if we don't suspend functionalization:
|
|
#
|
|
# 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
|
|
# but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
|
|
    # However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
    # fetch the proxy for the inputs and to capture any operations on them.
|
|
#
|
|
# 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
|
|
# wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
|
|
# only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
|
|
# when creating the output node, it fails to associate the wrapped tensor with its proxy.
|
|
# Instead, it will create _tensor_constant as output.
|
|
|
|
dummy_aot_config = AOTConfig(
|
|
fw_compiler=None, # type: ignore[arg-type]
|
|
bw_compiler=None, # type: ignore[arg-type]
|
|
partition_fn=None, # type: ignore[arg-type]
|
|
decompositions={},
|
|
num_params_buffers=0,
|
|
aot_id=0,
|
|
keep_inference_input_mutations=False,
|
|
)
|
|
|
|
example_grad = [_from_fun(out) for out in fw_outputs]
|
|
num_grads = len(example_grad)
|
|
fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)
|
|
|
|
def joint_fn(*joint_operands_grads):
|
|
if use_output_and_grad_bw:
|
|
grads = joint_operands_grads[0]
|
|
inputs = joint_operands_grads[1][-1:]
|
|
else:
|
|
grads = joint_operands_grads[:num_grads]
|
|
inputs = joint_operands_grads[num_grads:]
|
|
|
|
joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
|
|
_, grads = joint(
|
|
list(inputs),
|
|
[grad for grad in grads if grad is not None and grad.requires_grad],
|
|
)
|
|
|
|
# Unmask None gradients to all-zero gradients
|
|
unmasked_grads = unmask_none_gradients(grads, inputs)
|
|
|
|
# In order to keep map functional for backward graph,
|
|
# we clone outputs that are aliasing inputs
|
|
maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)
|
|
|
|
return pytree.tree_map(maybe_clone, unmasked_grads)
|
|
|
|
if use_output_and_grad_bw:
|
|
example_xs_out = list(fw_inputs) + list(fw_outputs)
|
|
joint_graph = _maybe_reenter_make_fx(joint_fn)(
|
|
(list(example_grad), list(example_xs_out))
|
|
)
|
|
else:
|
|
example_xs_out = list(fw_inputs)
|
|
joint_graph = _maybe_reenter_make_fx(joint_fn)(
|
|
*(list(example_grad) + list(example_xs_out))
|
|
)
|
|
|
|
return fw_graph, joint_graph
|
|
|
|
|
|
def _unstack_pytree(xs):
|
|
flat_xs, inspec = pytree.tree_flatten(xs)
|
|
if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
|
|
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
|
|
|
|
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
|
|
raise RuntimeError(
|
|
f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
|
|
)
|
|
|
|
a = zip(*flat_xs)
|
|
|
|
pytrees = [pytree.tree_unflatten(tuple, inspec) for tuple in a]
|
|
return pytrees
|
|
|
|
|
|
def _stack_pytree(pytrees):
|
|
flat_out = []
|
|
out_spec = None
|
|
for pt in pytrees:
|
|
flat_pt, out_spec = pytree.tree_flatten(pt)
|
|
flat_out.append(flat_pt)
|
|
assert out_spec is not None
|
|
b = zip(*flat_out)
|
|
stacked_out = []
|
|
for leaves in b:
|
|
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
|
|
stacked_out.append(torch.stack(leaves))
|
|
elif all(leaf is None for leaf in leaves):
|
|
            # Backward graph can return None outputs when forward inputs don't require grad.
|
|
# When we eagerly execute backward graph, we need to call _stack_pytree on its output,
|
|
# therefore we need to deal with None output.
|
|
stacked_out.append(None) # type: ignore[arg-type]
|
|
else:
|
|
raise RuntimeError(f"Cannot stack {leaves}.")
|
|
return pytree.tree_unflatten(stacked_out, out_spec)
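
# Example: for xs = {"a": torch.randn(3, 2), "b": torch.randn(3, 4)}, _unstack_pytree(xs)
# yields a list of 3 dicts whose leaves have shapes (2,) and (4,), and
# _stack_pytree(_unstack_pytree(xs)) stacks them back into tensors of the original shapes.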
|
|
|
|
|
|
# We cannot call save_for_backward for symints. This helper function
|
|
# can be used to save symints as direct attributes of ctx in autograd.Function.
|
|
#
|
|
# For example, if args = (x, y, s0, z, s1),
|
|
# save_tensors_and_symints_for_backward will partition the args into two lists, and a bookkeeping list pos:
|
|
# partitioned_args[0] = (x, y, z)
|
|
# partitioned_args[1] = (s0, s1)
|
|
# pos = (0, 0, 1, 0, 1)
|
|
# The pos list keeps track of which partition each arg
# was placed into, in order to recover it in saved_tensors_and_symints.
#
# In saved_tensors_and_symints, we can recover the original args by
# iterating over the pos list and popping one item from the front of partitioned_args[pos[i]].
# We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists.
|
|
def save_tensors_and_symints_for_backward(ctx, args):
|
|
assert all(
|
|
isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
|
|
), args
|
|
partitioned_args: list[Any] = [[], []]
|
|
pos = []
|
|
for arg in args:
|
|
idx = 0 if isinstance(arg, torch.Tensor) else 1
|
|
partitioned_args[idx].append(arg)
|
|
pos.append(idx)
|
|
|
|
assert not hasattr(ctx, "sym_int_args"), "ctx already has sym_int_args attribute."
|
|
assert not hasattr(ctx, "pos"), "ctx already has pos attribute."
|
|
ctx.save_for_backward(*partitioned_args[0])
|
|
ctx.sym_int_args = partitioned_args[1]
|
|
ctx.pos = pos
|
|
|
|
|
|
def saved_tensors_and_symints(ctx):
|
|
args = []
|
|
t_idx = 0
|
|
s_idx = 0
|
|
saved_tensors = ctx.saved_tensors
|
|
for p in ctx.pos:
|
|
if p == 0:
|
|
args.append(saved_tensors[t_idx])
|
|
t_idx += 1
|
|
else:
|
|
args.append(ctx.sym_int_args[s_idx])
|
|
s_idx += 1
|
|
assert t_idx + s_idx == len(ctx.pos)
|
|
return tuple(args)
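
# Illustrative sketch (``_MyFn`` is hypothetical, not part of this module): the two
# helpers above are used as a matched pair inside an autograd.Function, e.g.
#
#     class _MyFn(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x, s0, y):
#             # x, y: Tensors; s0: an int/SymInt (e.g. a lifted shape symbol)
#             save_tensors_and_symints_for_backward(ctx, [x, s0, y])
#             return x * y
#
#         @staticmethod
#         def backward(ctx, grad_out):
#             x, s0, y = saved_tensors_and_symints(ctx)  # original ordering restored
#             return grad_out * y, None, grad_out * x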
|
|
|
|
|
|
def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]:
|
|
assert sum(chunk_sizes) == len(iterable), (
|
|
"the sum of all chunks needs to match the length of the iterable."
|
|
)
|
|
elements = []
|
|
idx = 0
|
|
for size in chunk_sizes:
|
|
elements.append(iterable[idx : idx + size])
|
|
idx += size
|
|
return elements
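
# Example: split_into_chunks([1, 2, 3, 4, 5], [2, 3]) == [[1, 2], [3, 4, 5]]; the chunks
# are consecutive slices, so concatenating them reproduces the original sequence.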
|
|
|
|
|
|
def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
|
|
"""
|
|
For a fn that accepts flat inputs and returns flat outputs:
|
|
fw_out = fn(*args),
|
|
this function returns:
|
|
grad_args = bw_fn(*args_and_grad_output)
|
|
with the following invariants:
|
|
    1. args + fw_out has a 1-1 correspondence to args_and_grad_output
    2. grad_args has a 1-1 correspondence to args
    3. for a tensor arg whose requires_grad is False, its corresponding grad in
       grad_args will be a zero tensor with the same shape.
|
|
"""
|
|
|
|
from torch._functorch.aot_autograd import AOTConfig, create_joint
|
|
from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad
|
|
|
|
dummy_aot_config = AOTConfig(
|
|
fw_compiler=None, # type: ignore[arg-type]
|
|
bw_compiler=None, # type: ignore[arg-type]
|
|
partition_fn=None, # type: ignore[arg-type]
|
|
decompositions={},
|
|
num_params_buffers=0,
|
|
aot_id=0,
|
|
keep_inference_input_mutations=False,
|
|
)
|
|
n_primals = len(args)
|
|
|
|
bw_fn = create_joint(
|
|
prepare_fw_with_masks_all_requires_grad(fn), aot_config=dummy_aot_config
|
|
)
|
|
|
|
def flat_fn(*args_and_grad_outs):
|
|
primals = args_and_grad_outs[:n_primals]
|
|
tangents = args_and_grad_outs[n_primals:]
|
|
grad_args = bw_fn(primals, tangents)[1]
|
|
assert len(args) == len(grad_args)
|
|
|
|
maybe_clone = clone_outputs_aliasing_inputs(args_and_grad_outs)
|
|
|
|
return [
|
|
(
|
|
torch.zeros_like(arg)
|
|
if isinstance(arg, torch.Tensor) and grad is None
|
|
else maybe_clone(grad)
|
|
)
|
|
for grad, arg in zip(grad_args, primals)
|
|
]
|
|
|
|
return flat_fn
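
# Illustrative sketch (hypothetical fn and inputs): for fn = lambda x, y: (x * y,) and
# bw = create_bw_fn(fn, (x, y)), the returned callable is invoked with the primals
# followed by one tangent per forward output, i.e. bw(x, y, grad_out), and yields one
# gradient per primal ([grad_out * y, grad_out * x] here), substituting zero tensors for
# tensor primals whose gradient would otherwise be None.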
|
|
|
|
|
|
def get_dummy_aot_autograd_config():
|
|
from torch._functorch.aot_autograd import AOTConfig
|
|
|
|
return AOTConfig(
|
|
fw_compiler=None, # type: ignore[arg-type]
|
|
bw_compiler=None, # type: ignore[arg-type]
|
|
partition_fn=None, # type: ignore[arg-type]
|
|
decompositions={},
|
|
num_params_buffers=0,
|
|
aot_id=0,
|
|
keep_inference_input_mutations=False,
|
|
)
|
|
|
|
|
|
# Slices off the first element of a given dimension
|
|
def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
|
|
return torch.select_copy(t, dim, 0)
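
# Example: for t of shape (L, B, D), first_slice_copy(t) is a copy of t[0] with shape
# (B, D), and first_slice_copy(t, dim=1) selects index 0 along dim 1, giving shape (L, D).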
|
|
|
|
|
|
# Returns a mask whether a list element is a tensor or not
|
|
def get_tensor_mask(tensor_list: Iterable[Any]) -> list[bool]:
|
|
return [True if isinstance(v, torch.Tensor) else False for v in tensor_list]
|
|
|
|
|
|
def mask_list(
|
|
mask: list[bool], inp: list[Any], other: Optional[list[Any]] = None
|
|
) -> list[Any]:
|
|
# Masks elements on an `inp` list.
|
|
# If other is None, then the elements of the `inp` list where the mask is False are removed
|
|
# If other is not None, then the elements of the `inp` list where the mask is False are
|
|
# replaced with the elements of the `other` list
|
|
assert len(mask) == len(inp), (
|
|
"The length of the mask needs to be identical to the length of the input"
|
|
)
|
|
if other is not None:
|
|
assert len(inp) == len(other), (
|
|
"If an input and an other list is provided, they need to have the same length"
|
|
)
|
|
return [i if m else o for m, i, o in zip(mask, inp, other)]
|
|
else:
|
|
return [i for m, i in zip(mask, inp) if m]
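
# Example: with mask = get_tensor_mask([t, 3, u]) == [True, False, True],
#     mask_list(mask, [t, 3, u])            -> [t, u]      (False entries removed)
#     mask_list(mask, [a, b, c], [x, y, z]) -> [a, y, c]   (False entries replaced)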
|
|
|
|
|
|
def first_slice_copy_with_grad(li: Iterable[Any]) -> list[Any]:
|
|
# First_slice_copy does not keep the original requires_grad flag,
|
|
# but we need it for materialize_as_graph
|
|
# in order to compute the correct gradients
|
|
# The reason why first_slice_copy doesn't keep requires_grad flag is
|
|
# because it's called in torch.autograd.Function.backward/forward.
|
|
slc = [first_slice_copy(x).requires_grad_(x.requires_grad) for x in li]
|
|
return slc
|
|
|
|
|
|
# Reports the difference between meta of two tensors in a string
|
|
def diff_tensor_meta(
|
|
meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
|
|
) -> list[str]:
|
|
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
|
|
|
|
pair_diffs = []
|
|
for meta_name in TensorMetadata._fields:
|
|
if not check_grad and meta_name == "requires_grad":
|
|
continue
|
|
val1 = getattr(meta1, meta_name)
|
|
val2 = getattr(meta2, meta_name)
|
|
try:
|
|
if val1 != val2:
|
|
pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
|
|
except GuardOnDataDependentSymNode as _:
|
|
pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
|
|
continue
|
|
return pair_diffs
|
|
|
|
|
|
# Note [lifted arg types in hop]
|
|
# For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
|
|
# This has implications for the types of lifted args for different dispatch keys:
|
|
# 1. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd need to support torch.Symint
|
|
# lifted args because it's on the path of torch.compile(dynamic=True).
|
|
# 2. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd, CompositeExplicitAutograd need
|
|
# to support int arguments. In the eager run case, we re-trace the subgraph in AutogradKey, so inner
|
|
# hops may receive int inputs from the shape of outer tensor inputs.
|
|
# However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs.
|
|
def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]):
|
|
allowed_types = (torch.Tensor, int, torch.SymInt)
|
|
assert all(
|
|
isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args
|
|
), (
|
|
f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
|
|
)
|
|
|
|
|
|
# TODO: Return a more detailed information as to which node
|
|
# causes a mutation or an alias. This may requires a per operator tensor version checking
|
|
def check_input_alias_and_mutation(
|
|
gm: torch.fx.GraphModule,
|
|
fake_args: list[FakeTensor],
|
|
) -> tuple[dict[int, int], dict[int, int], dict[int, int], list[int]]:
|
|
(
|
|
inp_inp_alias_map,
|
|
inp_out_alias_map,
|
|
out_out_alias_map,
|
|
mutated_inputs,
|
|
) = check_input_alias_and_mutation_return_outputs(gm, fake_args)[:-1]
|
|
return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs
|
|
|
|
|
|
def _tensor_storage(t) -> StorageWeakRef:
|
|
return StorageWeakRef(t._typed_storage())
|
|
|
|
|
|
def check_input_alias_and_mutation_return_outputs(
|
|
gm: torch.fx.GraphModule,
|
|
fake_args: Union[list[FakeTensor], tuple[FakeTensor, ...]],
|
|
) -> tuple[
|
|
dict[int, int],
|
|
dict[int, int],
|
|
dict[int, int],
|
|
list[int],
|
|
Union[tuple[Any, ...], list[Any]],
|
|
]:
|
|
# This function can be called under autograd, functional, proxy and fake tensor mode.
|
|
    # We need to return either a fake tensor or a real tensor, depending on the mode,
    # to detect the input mutation/aliasing.
|
|
with (
|
|
disable_proxy_modes_tracing(),
|
|
disable_functional_mode(),
|
|
suspend_functionalization(),
|
|
):
|
|
|
|
def _from_functional_tensor(t: torch.Tensor) -> torch.Tensor:
|
|
if isinstance(t, FunctionalTensor) or torch._is_functional_tensor(t):
|
|
return torch.empty_strided(
|
|
t.size(),
|
|
t.stride(),
|
|
dtype=t.dtype,
|
|
requires_grad=t.requires_grad,
|
|
device=t.device,
|
|
)
|
|
return t
|
|
|
|
fake_args = pytree.tree_map_only(
|
|
torch.Tensor, _from_functional_tensor, fake_args
|
|
)
|
|
        # We want to disable active functional, proxy and fake modes, if any,
        # to create an encapsulated environment for fake tensor prop.
|
|
with torch.utils._python_dispatch._disable_current_modes():
|
|
"""This function returns mutated inputs, inp-inp alias, inp-out alias, out-out alias
|
|
in the graph module gm. It checks whether input tensor versions have
|
|
            changed after running gm once to detect mutation, and checks tensor storage
|
|
to detect alias.
|
|
"""
|
|
|
|
def _tensor_version(t) -> Optional[int]:
|
|
if isinstance(t, torch.Tensor):
|
|
if not isinstance(t, FakeTensor):
|
|
raise RuntimeError("Only fake tensor is allowed")
|
|
return t._version
|
|
return None
|
|
|
|
def _get_shape_env(
|
|
fake_args,
|
|
) -> torch.fx.experimental.symbolic_shapes.ShapeEnv:
|
|
# detect_fake_mode requires there could be only one active fake mode. This
|
|
# restricts the usage of this function because the global TracingContext
|
|
# has a persistent fake mode but fake tensors can be created
|
|
# outside of the tracing context (e.g. in testing).
|
|
# Instead, we just look at fake_args fake tensor mode
|
|
for arg in fake_args:
|
|
if isinstance(arg, FakeTensor) and arg.fake_mode.shape_env is not None:
|
|
return arg.fake_mode.shape_env
|
|
return torch.fx.experimental.symbolic_shapes.ShapeEnv()
|
|
|
|
# Clone the fake args to avoid mutating the original fake args
|
|
with ExitStack() as ctx_stack:
|
|
# We need to reuse prev_fake_mode's shape env to resolve
|
|
# the runtime assertions for unbacked symbols.
|
|
new_fake_mode = torch._subclasses.FakeTensorMode(
|
|
shape_env=_get_shape_env(fake_args),
|
|
                    # In executorch, there's a scalar_to_tensor pass that turns scalar inputs into a tensor constant,
                    # e.g. in add(a, 1), the 1 is turned into a tensor, which becomes a constant tensor attribute in the graph.
|
|
# We allow non fake inputs for this purpose. This is fine for mutation detection purpose:
|
|
# inputs are all fake and all mutations/aliasing are still detected.
|
|
allow_non_fake_inputs=True,
|
|
)
|
|
# We need to temporarily turn inference_mode off because
|
|
# under inference mode, tensor version counter is not tracked.
|
|
no_inference_mode_ctx = torch.inference_mode(False)
|
|
ctx_stack.enter_context(new_fake_mode)
|
|
ctx_stack.enter_context(no_inference_mode_ctx)
|
|
if new_fake_mode.shape_env is not None:
|
|
ctx_stack.enter_context(
|
|
new_fake_mode.shape_env.ignore_fresh_unbacked_symbols()
|
|
)
|
|
|
|
# create new fake tensors in new fake mode to avoid mutating original tensors
|
|
cloned = [
|
|
torch.empty_strided(
|
|
arg.size(),
|
|
arg.stride(),
|
|
dtype=arg.dtype,
|
|
device=arg.device,
|
|
requires_grad=arg.requires_grad,
|
|
layout=arg.layout,
|
|
)
|
|
if isinstance(arg, torch.Tensor)
|
|
else arg
|
|
for arg in fake_args
|
|
]
|
|
before = [_tensor_version(arg) for arg in cloned]
|
|
outputs = gm(*cloned)
|
|
outputs = [outputs] if not isinstance(outputs, (list, tuple)) else outputs
|
|
after = [_tensor_version(arg) for arg in cloned]
|
|
mutated_inputs = [
|
|
i for i, (v1, v2) in enumerate(zip(before, after)) if v1 != v2
|
|
]
|
|
# We need to analyze the original fake_args to detect
|
|
# inp-inp alias.
|
|
inp_storage_map = {
|
|
_tensor_storage(inp): i
|
|
for i, inp in enumerate(fake_args)
|
|
if isinstance(inp, torch.Tensor)
|
|
}
|
|
inp_inp_alias_map = {
|
|
i: inp_storage_map[_tensor_storage(inp)]
|
|
for i, inp in enumerate(fake_args)
|
|
if isinstance(inp, torch.Tensor)
|
|
and inp_storage_map[_tensor_storage(inp)] != i
|
|
}
|
|
out_storage_map = {
|
|
_tensor_storage(out): i
|
|
for i, out in enumerate(outputs)
|
|
if isinstance(out, torch.Tensor)
|
|
}
|
|
out_out_alias_map = {
|
|
i: out_storage_map[_tensor_storage(out)]
|
|
for i, out in enumerate(outputs)
|
|
if isinstance(out, torch.Tensor)
|
|
and out_storage_map[_tensor_storage(out)] != i
|
|
}
|
|
inp_out_alias_map = {
|
|
i: out_storage_map[_tensor_storage(inp)]
|
|
for i, inp in enumerate(cloned)
|
|
if isinstance(inp, torch.Tensor) and _tensor_storage(inp) in out_storage_map
|
|
}
|
|
return (
|
|
inp_inp_alias_map,
|
|
inp_out_alias_map,
|
|
out_out_alias_map,
|
|
mutated_inputs,
|
|
outputs,
|
|
)
|
|
|
|
|
|
registered_hop_fake_fns: dict[torch._ops.OpOverload, Callable] = {}
|
|
|
|
|
|
F = TypeVar("F", bound=Callable)
|
|
|
|
|
|
@overload
|
|
def register_fake(hop, fn: None = None) -> Callable[[F], F]: ...
|
|
|
|
|
|
@overload
|
|
def register_fake(hop, fn: F) -> F: ...
|
|
|
|
|
|
def register_fake(hop, fn=None):
|
|
"""
|
|
Register a fake function for a HOP. This is conceptually equivalent of the
|
|
register_fake utility for the custom ops. The registered function is called
|
|
inside the fake_tensor _dispatch_impl.
|
|
"""
|
|
assert hop not in registered_hop_fake_fns
|
|
|
|
def register(func: F) -> F:
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
|
|
redirect_to_mode(hop, FakeTensorMode)
|
|
|
|
registered_hop_fake_fns[hop] = func
|
|
return func
|
|
|
|
if fn is None:
|
|
return register
|
|
return register(fn)
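
# Illustrative usage sketch (``my_hop`` and the fake function body are hypothetical):
#
#     @register_fake(my_hop)
#     def my_hop_fake(subgraph, *operands):
#         # compute output metadata on fake tensors, mirroring the real semantics
#         return subgraph(*operands)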
|
|
|
|
|
|
class FunctionalizeCtxWrapper:
|
|
"""
|
|
This is a dummy wrapper to facilitate fake tensor caching.
|
|
|
|
For AOT Dispatcher metadata collection pass, HOPs go from functionalization
|
|
key to fake tensor key. The functionalization key wraps the subgraphs in a
|
|
function, which changes from call to call even though the subgraph might
|
|
    still be the same.
|
|
|
|
To enable fake tensor caching, we just wrap the ctx and subgraph in this
|
|
class and then use the subgraph as the hash.
|
|
"""
|
|
|
|
# Prevents PYTORCH_TEST_WITH_DYNAMO=1 test failures
|
|
@torch._disable_dynamo
|
|
def __init__(self, ctx, subgraph):
|
|
self.ctx = ctx
|
|
self.subgraph = subgraph
|
|
|
|
def __hash__(self):
|
|
return id(self.subgraph)
|
|
|
|
def __repr__(self):
|
|
return f"FunctionalizeCtxWrapper on subgraph {self.subgraph})"
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
if isinstance(self.subgraph, torch.fx.GraphModule):
|
|
# Running graph with interpreter is needed for propagating the stack_trace
|
|
with fx_traceback.preserve_node_meta():
|
|
return self.ctx.functionalize(torch.fx.Interpreter(self.subgraph).run)(
|
|
*args, **kwargs
|
|
)
|
|
return self.ctx.functionalize(self.subgraph)(*args, **kwargs)
|
|
|
|
|
|
# A wrapper over HigherOrderOperator that also carries its schema
|
|
class HopInstance:
|
|
def __init__(self, op: HigherOrderOperator, schema: HopSchema):
|
|
assert isinstance(op, HigherOrderOperator), op
|
|
self._op = op
|
|
# Using "_" to be consistent with how we access _schema of OpOverload
|
|
self._schema = schema
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
return self._op(*args, **kwargs)
|
|
|
|
@staticmethod
|
|
def create(hop: HigherOrderOperator, *args, **kwargs):
|
|
return HopInstance(hop, hop.gen_schema(*args, **kwargs))
|
|
|
|
|
|
# This call_op can be used to call a HopInstance with
|
|
# flat args and kwargs. We need to make use of the hop's schema's tree_spec
|
|
# to unflatten the args and kwargs before calling the hop.
|
|
def call_op(op: Union[OpOverload, HopInstance], args, kwargs):
|
|
if isinstance(op, OpOverload):
|
|
return op(*args, **kwargs)
|
|
|
|
assert isinstance(op, HopInstance), op
|
|
schema = op._schema
|
|
bound_args = list(args)
|
|
bound_kwargs = {}
|
|
for arg in schema.arguments[len(bound_args) :]:
|
|
assert arg.name in kwargs, (arg.name, kwargs)
|
|
val = kwargs[arg.name]
|
|
if not arg.kwarg_only:
|
|
bound_args.append(val)
|
|
else:
|
|
bound_kwargs[arg.name] = val
|
|
|
|
if schema.tree_spec is not None:
|
|
assert len(bound_args) == len(schema.arguments) and len(bound_kwargs) == 0
|
|
args, kwargs = pytree.tree_unflatten(bound_args, schema.tree_spec)
|
|
return op(*args, **kwargs)
|
|
else:
|
|
assert len(bound_args) + len(bound_kwargs) == len(schema.arguments)
|
|
return op(*bound_args, **bound_kwargs)
|
|
|
|
|
|
def materialize_as_graph(
|
|
fn: Callable,
|
|
args: tuple[Any],
|
|
include_key_set: Optional[torch._C.DispatchKeySet] = None,
|
|
exclude_key_set: Optional[torch._C.DispatchKeySet] = None,
|
|
force_enable_grad=False,
|
|
) -> torch.fx.GraphModule:
|
|
if include_key_set is None:
|
|
include_key_set = torch._C._dispatch_tls_local_include_set()
|
|
if exclude_key_set is None:
|
|
exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
|
|
|
|
@torch._dynamo.disable(recursive=True, reason=None)
|
|
def _materialize_as_graph_inner():
|
|
with suspend_functionalization(), disable_functional_mode():
|
|
with disable_proxy_modes_tracing():
|
|
unfunc_t = [_from_fun(arg) for arg in args]
|
|
|
|
with contextlib.ExitStack() as stack:
|
|
stack.enter_context(
|
|
torch.utils._python_dispatch._disable_current_modes()
|
|
)
|
|
stack.enter_context(
|
|
torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set),
|
|
)
|
|
if force_enable_grad:
|
|
stack.enter_context(torch.enable_grad())
|
|
# fake_mode is needed because parent tracer's fake_mode might
|
|
# be None but the associated inputs have fake mode or there
|
|
                # is a global tracing context with fake mode. We need to
                # make sure the fake mode when tracing the subgraph is consistent.
|
|
if fake_mode := detect_fake_mode(unfunc_t):
|
|
stack.enter_context(fake_mode)
|
|
return _maybe_reenter_make_fx(fn)(*unfunc_t)
|
|
|
|
gm = _materialize_as_graph_inner()
|
|
assert gm is not None
|
|
return gm
|
|
|
|
|
|
def materialize_callable_in_args(op: HopInstance, args, kwargs):
|
|
schema = op._schema
|
|
hop = op._op
|
|
flat_args, flat_spec = pytree.tree_flatten((args, kwargs))
|
|
|
|
def wrapped_fn(*flat_args):
|
|
return call_op(op, args, kwargs)
|
|
|
|
    # We need to trace the higher order op in order to materialize the inputs that
    # are callables (e.g. after the functionalization key)
|
|
gm = reenter_make_fx(wrapped_fn)(*flat_args)
|
|
hop_node = gm.graph.find_nodes(op="call_function", target=hop)[0]
|
|
arg_proxies = pytree.tree_leaves((hop_node.args, hop_node.kwargs))
|
|
assert isinstance(schema, torch._C.FunctionSchema) and len(arg_proxies) == len(
|
|
schema.arguments
|
|
)
|
|
|
|
# call_op preserves ordering of proxies via schema
|
|
materialized_args = []
|
|
for i, (proxy, arg) in enumerate(zip(arg_proxies, schema.arguments)):
|
|
if (
|
|
isinstance(proxy, torch.fx.Node)
|
|
and proxy.op == "get_attr"
|
|
and isinstance(getattr(gm, proxy.target), torch.fx.GraphModule) # type: ignore[arg-type]
|
|
):
|
|
assert callable(flat_args[i]), (schema, args, kwargs)
|
|
materialized_args.append(getattr(gm, proxy.target)) # type: ignore[arg-type]
|
|
else:
|
|
materialized_args.append(flat_args[i])
|
|
|
|
return pytree.tree_unflatten(materialized_args, flat_spec)
|
|
|
|
|
|
def has_user_subclass(args, allowed_subclasses):
|
|
"""Check if any tensor arguments are user subclasses.
|
|
|
|
This is used to determine if tensor subclasses should get a chance to run
|
|
their own implementation first before falling back to the default implementation.
|
|
|
|
Args:
|
|
args: Arguments to check (will be flattened with pytree)
|
|
allowed_subclasses: Tuple of allowed subclass types
|
|
|
|
Returns:
|
|
True if user tensor subclasses are found, False otherwise
|
|
"""
|
|
flat_args, _ = pytree.tree_flatten(args)
|
|
|
|
val = any(
|
|
isinstance(a, torch.Tensor)
|
|
and type(a) is not torch.Tensor
|
|
and not isinstance(a, allowed_subclasses)
|
|
for a in flat_args
|
|
)
|
|
return val
|
|
|
|
|
|
def _has_gen_schema(op: HigherOrderOperator):
|
|
    # invoke_quant_packed takes an InvokeQuant argument for which we cannot gen_schema.
|
|
if op is torch.ops.higher_order.invoke_quant_packed:
|
|
return False
|
|
method = "gen_schema"
|
|
return hasattr(type(op), method) and getattr(type(op), method) is not getattr(
|
|
HigherOrderOperator, method
|
|
)
|
|
|
|
|
|
def filter_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]):
|
|
assert len(data) == len(masks)
|
|
return [item for item, keep in zip(data, masks) if keep]
|
|
|
|
|
|
def fill_none_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]):
|
|
data_iter = iter(data)
|
|
return [next(data_iter) if kept else None for kept in masks]
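
# Example: filter_with_masks([t0, None, t2], [True, False, True]) == [t0, t2] and
# fill_none_with_masks([t0, t2], [True, False, True]) == [t0, None, t2], i.e. the two
# helpers invert each other for a fixed mask.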
|