mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This hits multi-line logging strings Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/98700 Approved by: https://github.com/voznesenskym
4244 lines
135 KiB
Python
4244 lines
135 KiB
Python
import contextlib
|
|
import dataclasses
|
|
import functools
|
|
import itertools
|
|
import logging
|
|
import re
|
|
import textwrap
|
|
from contextlib import nullcontext
|
|
from enum import Enum
|
|
from functools import partial
|
|
from inspect import signature
|
|
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
|
|
from unittest.mock import patch
|
|
|
|
import sympy
|
|
from sympy import Expr, Integer
|
|
|
|
import torch._dynamo.config as dynamo_config
|
|
import torch._logging
|
|
|
|
import torch.fx
|
|
import torch.utils._pytree as pytree
|
|
from torch._dynamo.utils import identity
|
|
from torch._prims_common import (
|
|
compute_required_storage_length,
|
|
is_boolean_dtype,
|
|
is_float_dtype,
|
|
make_channels_last_strides_for,
|
|
make_contiguous_strides_for,
|
|
)
|
|
from torch.fx.experimental.symbolic_shapes import FloorDiv
|
|
|
|
from . import config, dependencies
|
|
from .codegen.common import index_prevent_reordering
|
|
from .cuda_properties import get_device_properties
|
|
from .dependencies import extract_read_writes, var_builder
|
|
from .utils import (
|
|
argsort,
|
|
cache_on_self,
|
|
convert_shape_to_inductor,
|
|
convert_shape_to_symint,
|
|
developer_warning,
|
|
sympy_dot,
|
|
sympy_product,
|
|
sympy_subs,
|
|
sympy_symbol,
|
|
)
|
|
from .virtualized import ops, V
|
|
|
|
log = logging.getLogger(__name__)
|
|
indent = functools.partial(textwrap.indent, prefix=" ")
|
|
aten = torch.ops.aten
|
|
|
|
""" [Note: Inductor IR]
|
|
|
|
Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each
|
|
lowering is registered to a particular aten operator, and expects inputs that
|
|
correspond to the aten schema. However, in place of torch Tensor inputs, lowerings
|
|
expect Inductor TensorBox inputs.
|
|
|
|
TensorBox IR represents torch tensors. Tensors are sometimes single objects owning
|
|
storage, and sometimes views of another Tensor's storage. Mutating tensor operations
|
|
(such as add_()) affect the underlying storage and any associated views. Other operations
|
|
(such as .t_()) update metadata about the current view but don't modify the underlying storage.
|
|
|
|
To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.
|
|
|
|
TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
|
|
output from an operation. But just as torch.Tensors take different forms, TensorBox IR can
|
|
reference View IR or directly reference StorageBox IRs.
|
|
|
|
Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
|
|
may take an existing TensorBox and point it to a new underlying View IR.
|
|
|
|
Tensors that directly own storage are represented as a chain of:
|
|
TensorBox -> StorageBox -> Buffer
|
|
where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.
|
|
|
|
If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
|
|
(leaving the old buffer unmodified and functionalizing the operation).
|
|
|
|
Tensors backed by views add one more indirection to the IR.
|
|
TensorBox -> View -> StorageBox -> Buffer
|
|
In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
|
|
"""
|
|
|
|
|
|
def validate_ir(node_or_nodes):
|
|
def _check_tensorbox(node):
|
|
# Could expand this to check deeper properties
|
|
# (e.g. TensorBox points to View or StorageBox)
|
|
assert isinstance(
|
|
node,
|
|
(
|
|
DynamicScalar,
|
|
TensorBox,
|
|
RandSeedBuffer,
|
|
sympy.Symbol,
|
|
sympy.core.relational.Relational,
|
|
Expr,
|
|
),
|
|
), f"Found {type(node)}, which is not a supported top level IR node. See [Note: Inductor IR]"
|
|
|
|
# Be picky about the accepted data structure (don't use pytree here)
|
|
if isinstance(node_or_nodes, (List, Tuple)):
|
|
for node in node_or_nodes:
|
|
_check_tensorbox(node)
|
|
else:
|
|
_check_tensorbox(node_or_nodes)
|
|
|
|
|
|
def inverse_reorder(order):
|
|
inv_order = dict(zip(order, range(len(order))))
|
|
|
|
def reindex(index):
|
|
assert len(index) == len(inv_order)
|
|
return [index[inv_order[i]] for i in range(len(index))]
|
|
|
|
return reindex
|
|
|
|
|
|
def same_reorder(order):
|
|
def reindex(index):
|
|
assert len(index) == len(order)
|
|
return [index[order[i]] for i in range(len(index))]
|
|
|
|
return reindex
|
|
|
|
|
|
def fuse_reindexing(reindex1, reindex2):
|
|
def reindex(index):
|
|
return reindex1(reindex2(index))
|
|
|
|
return reindex
|
|
|
|
|
|
def stride_order2fill_order(order):
|
|
"""
|
|
Convert stride order to fill order
|
|
For channel last format,
|
|
stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0]
|
|
"""
|
|
lookup = {pos: idx for idx, pos in enumerate(order)}
|
|
fill_order = [lookup[i] for i in range(len(order))]
|
|
return fill_order
|
|
|
|
|
|
def get_stride_order(seq):
|
|
"""
|
|
Convert strides to stride order
|
|
"""
|
|
sorted_idx = argsort(seq)
|
|
out = [None for _ in range(len(seq))]
|
|
for i, elem in enumerate(sorted_idx):
|
|
out[elem] = i
|
|
return out
|
|
|
|
|
|
def ir_node_to_tensor(x, guard_shape=True):
|
|
if x is None:
|
|
return None
|
|
if not guard_shape:
|
|
shape_fn = V.graph.sizevars.size_hint
|
|
else:
|
|
shape_fn = identity
|
|
size = [shape_fn(s) for s in x.get_size()]
|
|
if is_storage_and_layout(x):
|
|
stride = [shape_fn(s) for s in x.get_layout().stride]
|
|
else:
|
|
stride = make_contiguous_strides_for(size)
|
|
dtype = x.get_dtype()
|
|
device = x.get_device()
|
|
size = convert_shape_to_symint(size)
|
|
stride = convert_shape_to_symint(stride)
|
|
t = torch.empty_strided(
|
|
size=size, stride=stride, dtype=dtype, device=device
|
|
).zero_()
|
|
return t
|
|
|
|
|
|
class ModularIndexing(sympy.Function):
|
|
"""
|
|
ModularIndexing(a, b, c) => (a // b) % c
|
|
"""
|
|
|
|
nargs = (3,)
|
|
is_integer = True
|
|
|
|
@classmethod
|
|
def eval(cls, base, divisor, modulus):
|
|
if base == 0 or modulus == 1:
|
|
return sympy.Integer(0)
|
|
|
|
if (
|
|
isinstance(base, sympy.Integer)
|
|
and isinstance(divisor, sympy.Integer)
|
|
and isinstance(modulus, sympy.Integer)
|
|
):
|
|
return (base // divisor) % modulus
|
|
|
|
if divisor != 1:
|
|
gcd = sympy.gcd(base, divisor)
|
|
if gcd != 1:
|
|
return ModularIndexing(base / gcd, divisor / gcd, modulus)
|
|
|
|
if isinstance(base, sympy.Add):
|
|
new_terms = []
|
|
all_positive = True
|
|
for term in base.args:
|
|
if sympy.gcd(term, modulus * divisor) != modulus * divisor:
|
|
if (isinstance(term, sympy.Integer) and term < 0) or (
|
|
isinstance(term, sympy.Mul)
|
|
and isinstance(term.args[0], sympy.Integer)
|
|
and term.args[0] < 0
|
|
):
|
|
# workaround for https://github.com/openai/triton/issues/619,
|
|
# if there are negative terms, // produces wrong result
|
|
# TODO if https://github.com/openai/triton/issues/619 is fixed
|
|
# this optimization would become valid
|
|
all_positive = False
|
|
break
|
|
else:
|
|
new_terms.append(term)
|
|
|
|
if len(new_terms) != len(base.args) and all_positive:
|
|
return ModularIndexing(sum(new_terms), divisor, modulus)
|
|
|
|
if isinstance(base, FloorDiv):
|
|
return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
|
|
|
|
|
|
class CleanDiv(FloorDiv):
|
|
"""
|
|
Div where we can assume no rounding.
|
|
This is to enable future optimizations.
|
|
"""
|
|
|
|
pass
|
|
|
|
|
|
class CeilDiv(sympy.Function):
|
|
"""
|
|
Div used in indexing that rounds up.
|
|
"""
|
|
|
|
is_integer = True
|
|
|
|
def __new__(cls, base, divisor):
|
|
if sympy.gcd(base, divisor) == divisor:
|
|
return CleanDiv(base, divisor)
|
|
else:
|
|
return FloorDiv(base + (divisor - 1), divisor)
|
|
|
|
|
|
def get_device_type(x):
|
|
if getattr(x, "get_device", None):
|
|
return get_device_type(x.get_device())
|
|
if isinstance(x, torch.device):
|
|
return x.type
|
|
return None
|
|
|
|
|
|
def is_triton(x):
|
|
return get_device_type(x) == "cuda"
|
|
|
|
|
|
def is_cpu(x):
|
|
return get_device_type(x) == "cpu"
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class IRNode:
|
|
_current_origins: ClassVar[Set[Any]] = set()
|
|
|
|
@staticmethod
|
|
@contextlib.contextmanager
|
|
def current_origins(origins: Set[torch.fx.Node]):
|
|
old = IRNode._current_origins
|
|
IRNode._current_origins = old | origins
|
|
try:
|
|
yield
|
|
finally:
|
|
IRNode._current_origins = old
|
|
|
|
def __post_init__(self):
|
|
self.origins = set(self._current_origins)
|
|
|
|
def common_repr(self):
|
|
origins = f"origins={getattr(self, 'origins', '')}"
|
|
if len(origins) > 64:
|
|
# this can get *very* long
|
|
origins = f"{origins[:61]}..."
|
|
return [origins]
|
|
|
|
def str_helper(self, lines):
|
|
lines = lines + self.common_repr()
|
|
lines = indent(",\n".join(map(str, lines)))
|
|
return f"{type(self).__name__}(\n{lines}\n)"
|
|
|
|
def is_user_of(self, name):
|
|
return any(name == dep.name for dep in self.get_reads())
|
|
|
|
def get_numel(self):
|
|
return sympy_product(self.get_size())
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class Loops(IRNode):
|
|
device: torch.device
|
|
dtype: torch.dtype
|
|
inner_fn: Callable
|
|
ranges: List[Expr]
|
|
|
|
def __str__(self, names=("ranges",)):
|
|
return self.str_helper(
|
|
[
|
|
f"'{self.device.type}'",
|
|
str(self.dtype),
|
|
self.inner_fn_str(),
|
|
]
|
|
+ [f"{name}={getattr(self, name)}" for name in names]
|
|
)
|
|
|
|
__repr__ = __str__
|
|
|
|
def get_dtype(self):
|
|
return self.dtype
|
|
|
|
def get_device(self):
|
|
return self.device
|
|
|
|
def get_size(self):
|
|
return self.ranges
|
|
|
|
def is_extern(self):
|
|
return False
|
|
|
|
@classmethod
|
|
def create(cls, *args, **kwargs):
|
|
return TensorBox.create(cls(*args, **kwargs))
|
|
|
|
@staticmethod
|
|
def _index(ranges, prefix="i"):
|
|
return [
|
|
sympy.Integer(0) if s == 1 else sympy_symbol(f"{prefix}{n}")
|
|
for n, s in enumerate(ranges)
|
|
]
|
|
|
|
@cache_on_self
|
|
def inner_fn_str(self):
|
|
index = self._index(self.ranges)
|
|
return V.KernelFormatterHandler.ir_to_string(self.inner_fn, index)
|
|
|
|
def is_zero_elements(self):
|
|
return any(r == 0 for r in self.ranges)
|
|
|
|
@cache_on_self
|
|
def get_reads(self):
|
|
with patch.object(FlexibleLayout, "allow_indexing", True):
|
|
if self.get_reduction_type():
|
|
return extract_read_writes(
|
|
self.make_loader(),
|
|
self.get_size(),
|
|
self.get_reduction_size(),
|
|
).reads
|
|
else:
|
|
return extract_read_writes(
|
|
self.make_loader(),
|
|
self.get_size(),
|
|
).reads
|
|
|
|
|
|
class Pointwise(Loops):
|
|
def make_loader(self):
|
|
return self.inner_fn
|
|
|
|
def get_reduction_size(self):
|
|
return []
|
|
|
|
def get_reduction_type(self):
|
|
return None
|
|
|
|
def store_output(self, output_name, indexer, vars):
|
|
return ops.store(output_name, indexer(vars), self.inner_fn(vars))
|
|
|
|
def constant_to_device(self, device):
|
|
"""Move this to a given device. Requires that all reads are to constants."""
|
|
loader = self.make_loader()
|
|
loader = patch.object(ConstantBuffer, "override_device", device)(loader)
|
|
return Pointwise(device, self.dtype, loader, self.ranges)
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class Scatter(Pointwise):
|
|
output_indexer: Callable[[List[Expr]], Expr]
|
|
scatter_mode: Optional[str] = None
|
|
|
|
def constant_to_device(self, device):
|
|
"""Move this to a given device. Requires that all reads are to constants."""
|
|
loader = self.make_loader()
|
|
loader = patch.object(ConstantBuffer, "override_device", device)(loader)
|
|
return Scatter(
|
|
device,
|
|
self.dtype,
|
|
loader,
|
|
self.ranges,
|
|
self.output_indexer,
|
|
self.scatter_mode,
|
|
)
|
|
|
|
def store_output(self, output_name, indexer, vars):
|
|
return ops.store(
|
|
output_name,
|
|
indexer(self.output_indexer(vars)),
|
|
self.inner_fn(vars),
|
|
mode=self.scatter_mode,
|
|
)
|
|
|
|
|
|
class ReductionHint(Enum):
|
|
INNER = 0
|
|
OUTER = 1
|
|
OUTER_TINY = 2
|
|
DEFAULT = 3
|
|
|
|
|
|
class TileHint(Enum):
|
|
SQUARE = 0
|
|
DEFAULT = 1
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class Reduction(Loops):
|
|
reduction_ranges: List[Expr]
|
|
reduction_type: str
|
|
# self.dtype represents the dst dtype
|
|
src_dtype: torch.dtype
|
|
reduction_hint: ReductionHint
|
|
|
|
def __str__(self):
|
|
return Loops.__str__(
|
|
self, names=("ranges", "reduction_ranges", "reduction_type")
|
|
)
|
|
|
|
__repr__ = __str__
|
|
|
|
def get_reduction_size(self):
|
|
return self.reduction_ranges
|
|
|
|
def get_reduction_type(self):
|
|
return self.reduction_type
|
|
|
|
def store_reduction(self, output_name, indexer, vars, reduction_vars):
|
|
return ops.reduction(
|
|
output_name,
|
|
self.dtype,
|
|
self.src_dtype,
|
|
self.reduction_type,
|
|
indexer(vars),
|
|
self.inner_fn(vars, reduction_vars),
|
|
)
|
|
|
|
def index_length(self):
|
|
return len(self.ranges) + len(self.reduction_ranges)
|
|
|
|
@cache_on_self
|
|
def inner_fn_str(self):
|
|
index = self._index(self.ranges)
|
|
rindex = self._index(self.reduction_ranges, "r")
|
|
return V.KernelFormatterHandler.ir_to_string(
|
|
self.inner_fn,
|
|
index,
|
|
rindex,
|
|
)
|
|
|
|
def constant_to_device(self, device):
|
|
"""Move this to a given device. Requires that all reads are to constants."""
|
|
loader = self.make_loader()
|
|
loader = patch.object(ConstantBuffer, "override_device", device)(loader)
|
|
return Reduction(
|
|
device,
|
|
self.dtype,
|
|
loader,
|
|
self.ranges,
|
|
self.reduction_ranges,
|
|
self.reduction_type,
|
|
self.src_dtype,
|
|
ReductionHint.DEFAULT,
|
|
)
|
|
|
|
@staticmethod
|
|
def num_splits(
|
|
device,
|
|
dst_dtype,
|
|
src_dtype,
|
|
inner_fn,
|
|
ranges,
|
|
reduction_ranges,
|
|
reduction_type,
|
|
reduction_numel,
|
|
):
|
|
num_sm = get_device_properties(device).multi_processor_count
|
|
min_elements_per_thread = 32
|
|
max_elements_per_thread = 512
|
|
threads_per_sm = 2048
|
|
min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm
|
|
max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm
|
|
|
|
def inner_reduction_splits(reduction_numel_hint, numel_hint):
|
|
# do heuristics that's close to eager mode for split inner reduction
|
|
# we leak reduction autotune configs here, and will need to refactor to avoid this later
|
|
num_warps = 8
|
|
num_threads = 32 * num_warps
|
|
if numel_hint >= 2 * num_sm: # don't split if there are enough outputs
|
|
return 1
|
|
if reduction_numel_hint <= 8192:
|
|
return 1
|
|
if reduction_numel_hint * numel_hint <= min_elements_per_device:
|
|
split_size = min_elements_per_thread
|
|
elif reduction_numel_hint * numel_hint < max_elements_per_device:
|
|
target_blocks = num_sm * threads_per_sm // (2 * num_threads)
|
|
blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
|
|
tmp_split_size = (
|
|
reduction_numel_hint + num_threads * blocks_per_output - 1
|
|
) // (num_threads * blocks_per_output)
|
|
divisors = sympy.divisors(reduction_numel_hint)
|
|
closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
|
|
if abs(closest - tmp_split_size) < 30:
|
|
# prefer even splits, but never smalle than min_elements_per_thread
|
|
split_size = max(closest, min_elements_per_thread)
|
|
else:
|
|
split_size = tmp_split_size
|
|
else:
|
|
divisors = sympy.divisors(reduction_numel_hint)
|
|
closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
|
|
if abs(closest - max_elements_per_thread) < 50:
|
|
# prefer even splits
|
|
split_size = closest
|
|
else:
|
|
split_size = max_elements_per_thread
|
|
return (reduction_numel_hint + split_size * num_threads - 1) // (
|
|
split_size * num_threads
|
|
)
|
|
|
|
def outer_reduction_splits(reduction_numel_hint, numel_hint):
|
|
# TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128
|
|
# extend to even smaller number of outputs
|
|
num_warps = 8
|
|
num_threads = num_warps * 32
|
|
rvals_per_thread = 4 # comes from heuristics, refactor to not leak here
|
|
xvals_per_block = 128
|
|
xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
|
|
if reduction_numel_hint * numel_hint < min_elements_per_device:
|
|
split_size = min_elements_per_thread
|
|
elif reduction_numel_hint * numel_hint < max_elements_per_device:
|
|
target_blocks = num_sm * threads_per_sm // (num_threads)
|
|
target_blocks = (target_blocks + xblocks - 1) // xblocks
|
|
tmp_split_size = (
|
|
reduction_numel_hint + rvals_per_thread * target_blocks - 1
|
|
) // (rvals_per_thread * target_blocks)
|
|
divisors = sympy.divisors(reduction_numel_hint)
|
|
closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
|
|
if abs(tmp_split_size - closest) < 20:
|
|
split_size = max(closest, min_elements_per_thread)
|
|
else:
|
|
split_size = tmp_split_size
|
|
else:
|
|
divisors = sympy.divisors(reduction_numel_hint)
|
|
closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
|
|
if abs(closest - max_elements_per_thread) < 50:
|
|
# prefer even splits
|
|
split_size = closest
|
|
else:
|
|
split_size = max_elements_per_thread
|
|
|
|
return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
|
|
rvals_per_thread * split_size
|
|
)
|
|
|
|
reduction_numel_hint = V.graph.sizevars.size_hint(reduction_numel)
|
|
numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
|
|
# easy cases
|
|
if numel_hint == 1:
|
|
return ReductionHint.INNER, inner_reduction_splits(
|
|
reduction_numel_hint, numel_hint
|
|
)
|
|
if (
|
|
reduction_numel_hint <= min_elements_per_thread
|
|
or numel_hint >= num_sm * 2 * 32
|
|
):
|
|
return ReductionHint.DEFAULT, 1
|
|
|
|
r = Reduction(
|
|
device,
|
|
dst_dtype,
|
|
inner_fn,
|
|
ranges,
|
|
reduction_ranges,
|
|
reduction_type,
|
|
src_dtype,
|
|
ReductionHint.DEFAULT,
|
|
)
|
|
|
|
def get_read_indices(r):
|
|
cb = ComputedBuffer(
|
|
name=None,
|
|
layout=FlexibleLayout(
|
|
device=r.get_device(),
|
|
dtype=r.get_dtype(),
|
|
size=r.get_size(),
|
|
),
|
|
data=r,
|
|
)
|
|
read_writes = cb.get_read_writes()
|
|
# try finding the full size producer
|
|
# TODO this will fail for something like ((1, N) * (N, 1)).sum()
|
|
# this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare
|
|
range_vars = [
|
|
r
|
|
for r in read_writes.range_vars
|
|
if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number)
|
|
]
|
|
indices = []
|
|
changed = False
|
|
for md in sorted(read_writes.reads, key=lambda x: x.name):
|
|
if all([r in md.index.free_symbols for r in range_vars]):
|
|
indices.append(md.index)
|
|
if md.name in V.graph.name_to_buffer:
|
|
buf = V.graph.name_to_buffer[md.name]
|
|
original_stride = buf.layout.stride
|
|
buf.decide_layout()
|
|
if buf.layout.stride != original_stride:
|
|
changed = True
|
|
return indices, changed
|
|
|
|
indices, changed = get_read_indices(r)
|
|
if changed:
|
|
indices, _ = get_read_indices(r)
|
|
|
|
if len(indices) == 0:
|
|
# TODO determine splits when all inputs are broadcast
|
|
return ReductionHint.DEFAULT, 1
|
|
|
|
_, (_, reduction_vars), _ = dependencies.index_vars_squeeze(
|
|
r.get_size(), r.get_reduction_size()
|
|
)
|
|
num_outer = 0
|
|
num_inner = 0
|
|
for i in indices:
|
|
strides = V.graph.sizevars.stride_hints(i, reduction_vars)
|
|
outer = all([s > 1 for s in strides])
|
|
if outer:
|
|
num_outer += 1
|
|
else:
|
|
num_inner += 1
|
|
if num_inner > num_outer:
|
|
return ReductionHint.INNER, inner_reduction_splits(
|
|
reduction_numel_hint, numel_hint
|
|
)
|
|
else:
|
|
return ReductionHint.OUTER, outer_reduction_splits(
|
|
reduction_numel_hint, numel_hint
|
|
)
|
|
|
|
@staticmethod
|
|
def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type):
|
|
"""Convert inner_fn from a reduction to an pointwise"""
|
|
reduction_ranges = [
|
|
V.graph.sizevars.guard_static_shape(x) for x in reduction_ranges
|
|
]
|
|
|
|
if reduction_type == "sum":
|
|
|
|
def combine_fn(a, b):
|
|
return ops.add(a, b)
|
|
|
|
elif reduction_type == "min":
|
|
|
|
def combine_fn(a, b):
|
|
return ops.minimum(a, b)
|
|
|
|
elif reduction_type == "max":
|
|
|
|
def combine_fn(a, b):
|
|
return ops.maximum(a, b)
|
|
|
|
elif reduction_type == "any":
|
|
|
|
def combine_fn(a, b):
|
|
return ops.logical_or(a, b)
|
|
|
|
elif reduction_type == "argmin":
|
|
|
|
def combine_fn(a, b):
|
|
return ops.minimum(a[0], b[0]), ops.where(
|
|
ops.lt(b[0], a[0]), b[1], a[1]
|
|
)
|
|
|
|
elif reduction_type == "argmax":
|
|
|
|
def combine_fn(a, b):
|
|
return ops.maximum(a[0], b[0]), ops.where(
|
|
ops.gt(b[0], a[0]), b[1], a[1]
|
|
)
|
|
|
|
else:
|
|
raise NotImplementedError(f"unknown reduction_type={reduction_type}")
|
|
|
|
def fn(index):
|
|
return functools.reduce(
|
|
combine_fn,
|
|
(
|
|
value_fn(index, rindex)
|
|
for rindex in itertools.product(
|
|
*[range(x) for x in reduction_ranges]
|
|
)
|
|
),
|
|
)
|
|
|
|
if reduction_type in ("argmin", "argmax"):
|
|
flatten_index = FixedLayout(
|
|
None,
|
|
None,
|
|
reduction_ranges,
|
|
FlexibleLayout.contiguous_strides(reduction_ranges),
|
|
).make_indexer()
|
|
|
|
def value_fn(index, rindex):
|
|
rindex = [sympy.expand(i) for i in rindex]
|
|
return (
|
|
inner_fn(index, rindex),
|
|
ops.index_expr(flatten_index(rindex), torch.int64),
|
|
)
|
|
|
|
return lambda index: fn(index)[1]
|
|
else:
|
|
value_fn = inner_fn
|
|
return fn
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
device: torch.device,
|
|
dst_dtype: torch.dtype,
|
|
src_dtype: torch.dtype,
|
|
inner_fn: Callable,
|
|
ranges: List[Expr],
|
|
reduction_ranges: List[Expr],
|
|
reduction_type: str,
|
|
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
|
|
):
|
|
reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
|
|
|
|
if reduction_numel == 0:
|
|
# N.B. This is a hack to generate the literal of the given type
|
|
# Ideally, we should be fixing `def constant` in triton.py
|
|
# but it breaks due to hardcoded dtypes in other places
|
|
def py_cnst(val):
|
|
return (
|
|
bool(val)
|
|
if dst_dtype == torch.bool
|
|
else float(val)
|
|
if dst_dtype.is_floating_point
|
|
else int(val)
|
|
)
|
|
|
|
rtypes_to_inits = {
|
|
"sum": py_cnst(0),
|
|
"prod": py_cnst(1),
|
|
"any": py_cnst(0),
|
|
# "all" is desugared to `!any(!val)`
|
|
}
|
|
|
|
assert (
|
|
reduction_type in rtypes_to_inits.keys()
|
|
), f"{reduction_type} not supported for zero-dimension tensors!"
|
|
|
|
def const_fn(index):
|
|
return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)
|
|
|
|
return Pointwise.create(
|
|
device=device,
|
|
dtype=src_dtype,
|
|
inner_fn=const_fn,
|
|
ranges=list(ranges),
|
|
)
|
|
|
|
if reduction_numel == 1:
|
|
# this reduction is actually a pointwise op
|
|
if reduction_type in ("argmin", "argmax"):
|
|
|
|
def fn(index):
|
|
return ops.constant(0, dst_dtype)
|
|
|
|
else:
|
|
|
|
def fn(index):
|
|
reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
|
|
return inner_fn(index, reduction_index)
|
|
|
|
return Pointwise.create(device, dst_dtype, fn, ranges)
|
|
|
|
if (
|
|
isinstance(reduction_numel, sympy.Integer)
|
|
and V.graph.sizevars.size_hint(reduction_numel)
|
|
< config.unroll_reductions_threshold
|
|
and sympy_product(ranges) != 1
|
|
):
|
|
return Pointwise.create(
|
|
device,
|
|
dst_dtype,
|
|
cls._unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type),
|
|
ranges,
|
|
)
|
|
|
|
split_reduction = (
|
|
is_triton(device)
|
|
and reduction_type
|
|
not in {
|
|
"argmax",
|
|
"argmin",
|
|
}
|
|
and config.split_reductions
|
|
)
|
|
if split_reduction and not dynamo_config.dynamic_shapes:
|
|
# triton doesn't support reduce to single element well, so break it up
|
|
hint, split = cls.num_splits(
|
|
device,
|
|
dst_dtype,
|
|
src_dtype,
|
|
inner_fn,
|
|
ranges,
|
|
reduction_ranges,
|
|
reduction_type,
|
|
reduction_numel,
|
|
)
|
|
# intermediate reduction in split can contain complex indexing,
|
|
# and num_splits will fail to correctly set the hint
|
|
# reuse the passed hint if available
|
|
if reduction_hint == ReductionHint.DEFAULT:
|
|
reduction_hint = hint
|
|
if split > 1:
|
|
# triton doesn't support reduce to single element well, so break it up
|
|
return cls.create_multilayer(
|
|
device,
|
|
dst_dtype,
|
|
src_dtype,
|
|
inner_fn,
|
|
ranges,
|
|
reduction_ranges,
|
|
reduction_type,
|
|
split,
|
|
reduction_hint,
|
|
)
|
|
elif split_reduction and dynamo_config.dynamic_shapes:
|
|
torch._logging.warning_once(
|
|
log,
|
|
"Could not do split reduction due to dynamic shapes; performance may be worse",
|
|
)
|
|
|
|
return TensorBox.create(
|
|
Reduction(
|
|
device,
|
|
dst_dtype,
|
|
inner_fn,
|
|
ranges,
|
|
reduction_ranges,
|
|
reduction_type,
|
|
src_dtype,
|
|
reduction_hint,
|
|
)
|
|
)
|
|
|
|
@staticmethod
|
|
def default_value(reduction_type, dtype):
|
|
if reduction_type in {"max", "argmax"}:
|
|
if is_float_dtype(dtype):
|
|
return float("-inf")
|
|
elif is_boolean_dtype(dtype):
|
|
return 0
|
|
else:
|
|
return torch.iinfo(dtype).min
|
|
if reduction_type in {"min", "argmin"}:
|
|
if is_float_dtype(dtype):
|
|
return float("inf")
|
|
elif is_boolean_dtype(dtype):
|
|
return 1
|
|
else:
|
|
return torch.iinfo(dtype).max
|
|
|
|
return {
|
|
"sum": 0,
|
|
"any": 0,
|
|
}[reduction_type]
|
|
|
|
@classmethod
|
|
def create_multilayer(
|
|
cls,
|
|
device: torch.device,
|
|
dst_dtype: torch.dtype,
|
|
src_dtype: torch.dtype,
|
|
inner_fn: Callable,
|
|
ranges: List[Expr],
|
|
reduction_ranges: List[Expr],
|
|
reduction_type: str,
|
|
split: int,
|
|
reduction_hint: ReductionHint,
|
|
):
|
|
"""
|
|
Break a large reduction up into multiple smaller reductions
|
|
recursively
|
|
"""
|
|
reduction_numel = sympy_product(reduction_ranges)
|
|
|
|
# TODO(jansel): convert this to dynamic shapes
|
|
# TODO(jansel): realize the reduction so we can do dynamic indexing
|
|
reduction_ranges = [
|
|
sympy.Integer(V.graph.sizevars.guard_static_shape(s))
|
|
for s in reduction_ranges
|
|
]
|
|
reduction_numel = sympy.Integer(
|
|
V.graph.sizevars.guard_static_shape(reduction_numel)
|
|
)
|
|
|
|
if V.graph.sizevars.size_hint(reduction_numel) % split == 0:
|
|
need_mask = False
|
|
else:
|
|
need_mask = True
|
|
|
|
split = sympy.Integer(split)
|
|
block_size = FloorDiv(reduction_numel + (split - 1), split)
|
|
|
|
reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])
|
|
|
|
def wrapper_fn(index, reduction_index):
|
|
(reduction_index,) = reduction_index
|
|
*new_index, reduction_block = index
|
|
indices = block_size * reduction_block + reduction_index
|
|
|
|
def body():
|
|
return inner_fn(new_index, reindex([indices]))
|
|
|
|
if need_mask:
|
|
mask = ops.lt(
|
|
ops.index_expr(indices, torch.int32),
|
|
ops.index_expr(reduction_numel, torch.int32),
|
|
)
|
|
return ops.masked(
|
|
mask, body, cls.default_value(reduction_type, dst_dtype)
|
|
)
|
|
else:
|
|
return body()
|
|
|
|
# triton will automatically compute reductions in fp32 if reducing over fp16/bf16
|
|
# within the kernel. keep the intermediate in fp32 so as to keep the whole reduction
|
|
# in fp32 and not reduce precision by breaking up the kernel into multiple layers
|
|
intermediate_dtype = (
|
|
dst_dtype
|
|
if dst_dtype not in (torch.float16, torch.bfloat16)
|
|
else torch.float
|
|
)
|
|
intermediate = Reduction.create(
|
|
device,
|
|
intermediate_dtype,
|
|
src_dtype,
|
|
wrapper_fn,
|
|
[*ranges, split],
|
|
[block_size],
|
|
reduction_type,
|
|
reduction_hint,
|
|
)
|
|
intermediate.realize()
|
|
intermediate_loader = intermediate.make_loader()
|
|
|
|
def intermediate_fn(index, reduction_index):
|
|
return intermediate_loader([*index, *reduction_index])
|
|
|
|
numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
|
|
if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER:
|
|
reduction_hint = ReductionHint.OUTER_TINY
|
|
if (
|
|
split <= 1024
|
|
and numel_hint <= 256
|
|
and reduction_hint == ReductionHint.OUTER
|
|
):
|
|
reduction_hint = ReductionHint.OUTER_TINY
|
|
return TensorBox.create(
|
|
Reduction(
|
|
device,
|
|
dst_dtype,
|
|
intermediate_fn,
|
|
ranges,
|
|
[split],
|
|
reduction_type,
|
|
src_dtype,
|
|
reduction_hint,
|
|
)
|
|
)
|
|
|
|
|
|
def is_storage_and_layout(x):
|
|
try:
|
|
as_storage_and_layout(x, freeze=False)
|
|
return True
|
|
except NotImplementedError:
|
|
return False
|
|
|
|
|
|
def is_contiguous_storage_and_layout(x):
|
|
try:
|
|
buffer, layout = as_storage_and_layout(x, freeze=False)
|
|
return layout.is_contiguous()
|
|
except NotImplementedError:
|
|
return False
|
|
|
|
|
|
def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=None):
|
|
"""Try to simplify x into a StorageBox and a Layout"""
|
|
if isinstance(x, TensorBox):
|
|
return as_storage_and_layout(
|
|
x.data,
|
|
freeze=freeze,
|
|
want_contiguous=want_contiguous,
|
|
stride_order=stride_order,
|
|
)
|
|
if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
|
|
if freeze:
|
|
if want_contiguous:
|
|
x.data.freeze_layout()
|
|
assert x.data.layout.is_contiguous()
|
|
elif stride_order is not None:
|
|
x.data.freeze_layout_with_stride_order(stride_order)
|
|
else:
|
|
x.data.decide_layout()
|
|
return x, x.data.layout
|
|
if isinstance(x, ReinterpretView):
|
|
# making the base of x contiguous or stride_ordered will not necessarily make
|
|
# the ReinterpretedView either, so dont pass along those arguments
|
|
buffer, _ = as_storage_and_layout(
|
|
x.data,
|
|
freeze=freeze,
|
|
)
|
|
return buffer, x.layout
|
|
raise NotImplementedError
|
|
|
|
|
|
as_contiguous_storage_and_layout = functools.partial(
|
|
as_storage_and_layout, want_contiguous=True
|
|
)
|
|
|
|
|
|
def is_stride_order_storage_and_layout(x, stride_order):
|
|
try:
|
|
buffer, layout = as_storage_and_layout(x, freeze=False)
|
|
return layout.is_stride_ordered(stride_order)
|
|
except NotImplementedError:
|
|
return False
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class BaseView(IRNode):
|
|
data: IRNode
|
|
|
|
def get_dtype(self):
|
|
return self.data.get_dtype()
|
|
|
|
def get_device(self):
|
|
return self.data.get_device()
|
|
|
|
def get_name(self):
|
|
return self.data.get_name()
|
|
|
|
def mark_reuse(self, users):
|
|
return self.data.mark_reuse(users)
|
|
|
|
def has_exceeded_max_reads(self):
|
|
return self.data.has_exceeded_max_reads()
|
|
|
|
def realize(self):
|
|
return self.data.realize()
|
|
|
|
def realize_hint(self):
|
|
return self.data.realize_hint()
|
|
|
|
def get_storage_numel(self):
|
|
return self.data.get_storage_numel()
|
|
|
|
def is_extern(self):
|
|
return self.data.is_extern()
|
|
|
|
@cache_on_self
|
|
def get_reads(self):
|
|
with patch.object(FlexibleLayout, "allow_indexing", True):
|
|
return extract_read_writes(
|
|
self.make_loader(),
|
|
self.get_size(),
|
|
).reads
|
|
|
|
def unwrap_view(self):
|
|
x = self
|
|
while isinstance(x, BaseView):
|
|
x = x.data
|
|
return x
|
|
|
|
def constant_to_device(self, device):
|
|
"""Move this to a given device. Requires that all reads are to constants."""
|
|
loader = self.make_loader()
|
|
loader = patch.object(ConstantBuffer, "override_device", device)(loader)
|
|
return Pointwise(device, self.get_dtype(), loader, self.get_size())
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ExpandView(BaseView):
|
|
size: List[Expr]
|
|
|
|
@staticmethod
|
|
def _normalize_size(x, new_size):
|
|
"""Replace `-1` with correct sizes"""
|
|
new_size = list(map(sympy.expand, new_size))
|
|
old_size = x.get_size()
|
|
old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
|
|
assert len(new_size) == len(old_size)
|
|
for i in range(len(new_size)):
|
|
if new_size[i] == -1:
|
|
assert old_size[i] is not None
|
|
new_size[i] = old_size[i]
|
|
return new_size
|
|
|
|
@classmethod
|
|
def create(cls, x, new_size):
|
|
new_size = cls._normalize_size(x, new_size)
|
|
|
|
if is_storage_and_layout(x):
|
|
storage, old_layout = as_storage_and_layout(x)
|
|
skip = len(new_size) - len(old_layout.size)
|
|
assert skip >= 0
|
|
new_stride = [sympy.Integer(0)] * skip
|
|
for stride, size in zip(old_layout.stride, old_layout.size):
|
|
new_stride.append(stride if size != 1 else sympy.Integer(0))
|
|
new_layout = FixedLayout(
|
|
old_layout.device,
|
|
old_layout.dtype,
|
|
list(new_size),
|
|
new_stride,
|
|
old_layout.offset,
|
|
)
|
|
return ReinterpretView(storage, new_layout)
|
|
|
|
return ExpandView(x, new_size)
|
|
|
|
def get_size(self):
|
|
return self.size
|
|
|
|
def make_loader(self):
|
|
target = self.get_size()
|
|
actual = self.data.get_size()
|
|
skip = len(target) - len(actual)
|
|
inner = self.data.make_loader()
|
|
|
|
def load(index):
|
|
index = list(index[skip:])
|
|
assert len(index) == len(actual)
|
|
for i in range(len(actual)):
|
|
if actual[i] == 1:
|
|
# zero out broadcast dimension
|
|
index[i] = sympy.Integer(0)
|
|
return inner(index)
|
|
|
|
return load
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class PermuteView(BaseView):
|
|
dims: List[Expr]
|
|
|
|
@classmethod
|
|
def create(cls, x, dims):
|
|
dims = cls._map_neg_dims(dims)
|
|
assert set(dims) == set(range(len(dims)))
|
|
|
|
if is_storage_and_layout(x):
|
|
storage, old_layout = as_storage_and_layout(x)
|
|
new_layout = FixedLayout(
|
|
old_layout.device,
|
|
old_layout.dtype,
|
|
[old_layout.size[i] for i in dims],
|
|
[old_layout.stride[i] for i in dims],
|
|
old_layout.offset,
|
|
)
|
|
return ReinterpretView(storage, new_layout)
|
|
|
|
return PermuteView(x, dims)
|
|
|
|
@classmethod
|
|
def _map_neg_dims(cls, dims):
|
|
return [dim if dim >= 0 else len(dims) + dim for dim in dims]
|
|
|
|
def get_size(self):
|
|
assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims)))
|
|
size = self.data.get_size()
|
|
return [size[i] for i in self.dims]
|
|
|
|
def make_loader(self):
|
|
inner = self.data.make_loader()
|
|
inv = {j: i for i, j in enumerate(self.dims)}
|
|
inv = [inv[i] for i in range(len(self.dims))]
|
|
assert set(inv) == set(range(len(self.dims)))
|
|
|
|
def load(index):
|
|
index = [index[i] for i in inv]
|
|
return inner(index)
|
|
|
|
return load
|
|
|
|
|
|
class SqueezeView(BaseView):
|
|
@classmethod
|
|
def create(cls, x, *, dim=None):
|
|
if is_storage_and_layout(x):
|
|
storage, old_layout = as_storage_and_layout(x)
|
|
new_size = []
|
|
new_stride = []
|
|
if dim is not None:
|
|
assert isinstance(dim, int), "expected integer dim argument"
|
|
assert 0 <= dim and dim < len(old_layout.size)
|
|
|
|
for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)):
|
|
if dim is None:
|
|
if size != 1:
|
|
new_size.append(size)
|
|
new_stride.append(stride)
|
|
else:
|
|
if i != dim:
|
|
new_size.append(size)
|
|
new_stride.append(stride)
|
|
else:
|
|
assert size == 1, "expected squeezed size to be 1"
|
|
|
|
new_layout = FixedLayout(
|
|
old_layout.device,
|
|
old_layout.dtype,
|
|
new_size,
|
|
new_stride,
|
|
old_layout.offset,
|
|
)
|
|
return ReinterpretView(storage, new_layout)
|
|
|
|
if dim is None:
|
|
# redirect to a generic view
|
|
return View.create(x, [s for s in x.get_size() if s != 1])
|
|
else:
|
|
assert x.get_size()[dim] == 1
|
|
return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim])
|
|
|
|
@staticmethod
|
|
def squeezer(size: Tuple[sympy.Expr, ...]):
|
|
new_size = [s for s in size if s != 1]
|
|
not_one = [i for i, s in enumerate(size) if s != 1]
|
|
length = len(size)
|
|
|
|
def reindex(index: List[sympy.Expr]) -> List[sympy.Expr]:
|
|
assert len(index) == len(not_one), f"{index} {not_one}"
|
|
new_index = [sympy.Integer(0)] * length
|
|
for idx, s in zip(not_one, index):
|
|
new_index[idx] = s
|
|
return tuple(new_index)
|
|
|
|
return new_size, reindex
|
|
|
|
def __init__(self, data):
|
|
raise AssertionError("use SqueezeView.create()")
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class View(BaseView):
|
|
size: List[Expr]
|
|
reindex: Callable
|
|
|
|
def make_indexer(self):
|
|
base_indexer = self.data.make_indexer()
|
|
|
|
def indexer(idx):
|
|
return base_indexer(self.reindex(idx))
|
|
|
|
return indexer
|
|
|
|
@staticmethod
|
|
def handle_negative_index(idx, size):
|
|
idx = sympy.expand(idx)
|
|
size = sympy.expand(size)
|
|
sizevars = V.graph.sizevars
|
|
if sizevars.size_hint(idx) < 0:
|
|
sizevars.guard_lt(idx, 0)
|
|
idx = idx + size
|
|
return idx
|
|
|
|
def reindex_str(self):
|
|
index_old = [sympy_symbol(f"i{n}") for n in range(len(self.size))]
|
|
index_new = list(self.reindex(index_old))
|
|
return f"lambda {', '.join(map(str, index_old))}: {index_new}"
|
|
|
|
def __str__(self):
|
|
return self.str_helper(
|
|
[self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
|
|
)
|
|
|
|
__repr__ = __str__
|
|
|
|
@classmethod
|
|
def create(cls, x, new_size):
|
|
assert isinstance(new_size, (tuple, list))
|
|
old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)
|
|
|
|
# Skip pointless views
|
|
if V.graph.sizevars.maybe_guard_list_equals(old_size, new_size):
|
|
return x
|
|
|
|
# TODO: a new class for FixedTransferLayout that output layout is constrained by input layout
|
|
if is_contiguous_storage_and_layout(x) and not isinstance(
|
|
x.data, ExternKernelAlloc
|
|
):
|
|
storage, old_layout = as_contiguous_storage_and_layout(x)
|
|
new_layout = FixedLayout(
|
|
old_layout.device,
|
|
old_layout.dtype,
|
|
new_size,
|
|
FlexibleLayout.contiguous_strides(new_size),
|
|
old_layout.offset,
|
|
)
|
|
return ReinterpretView(storage, new_layout)
|
|
|
|
reindex = cls.dynamic_reshape_indexer(old_size, new_size)
|
|
return cls(x, tuple(new_size), reindex)
|
|
|
|
@staticmethod
|
|
def resolve_negative_size(old_size, new_size):
|
|
new_size = [V.graph.sizevars.simplify(x) for x in new_size]
|
|
old_size = [V.graph.sizevars.simplify(x) for x in old_size]
|
|
|
|
new_size = list(new_size)
|
|
for i in range(len(new_size)):
|
|
if new_size[i] == -1:
|
|
new_size[i] = sympy.Integer(1)
|
|
new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
|
|
break
|
|
|
|
V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size))
|
|
return old_size, new_size
|
|
|
|
@classmethod
|
|
def dynamic_reshape_indexer(cls, old_size, new_size):
|
|
try:
|
|
reindex = cls._dynamic_reshape_indexer(old_size, new_size)
|
|
except (AssertionError, IndexError):
|
|
# optimistic algorithm failed, lets do a fallback
|
|
flat = [sympy_product(old_size)]
|
|
reindex1 = cls._dynamic_reshape_indexer(old_size, flat)
|
|
reindex2 = cls._dynamic_reshape_indexer(flat, new_size)
|
|
reindex = fuse_reindexing(reindex1, reindex2)
|
|
return reindex
|
|
|
|
@staticmethod
|
|
def _dynamic_reshape_indexer(old_size, new_size):
|
|
"""
|
|
Perform a reshape entirely by modifying indexing math
|
|
"""
|
|
size_hint = V.graph.sizevars.size_hint
|
|
vars = [sympy_symbol(f"view{i}") for i in range(len(new_size))]
|
|
|
|
stack_new = list(zip(vars, new_size))
|
|
stack_old = list(old_size)
|
|
|
|
view_expr = []
|
|
while stack_new and stack_old:
|
|
size_old = stack_old.pop()
|
|
var, size_new = stack_new.pop()
|
|
if size_old == 1:
|
|
view_expr.append(sympy.Integer(0))
|
|
stack_new.append((var, size_new)) # re-add
|
|
elif size_new == 1:
|
|
stack_old.append(size_old) # re-add
|
|
elif size_hint(size_new) == size_hint(size_old):
|
|
view_expr.append(var)
|
|
V.graph.sizevars.guard_equals(size_new, size_old)
|
|
elif size_hint(size_new) < size_hint(size_old):
|
|
while size_hint(size_new) < size_hint(size_old):
|
|
var2, size_new2 = stack_new.pop()
|
|
var = var2 * size_new + var
|
|
size_new = size_new * size_new2
|
|
view_expr.append(var)
|
|
V.graph.sizevars.guard_equals(size_new, size_old)
|
|
elif size_hint(size_new) > size_hint(size_old):
|
|
divisor = sympy.Integer(1)
|
|
modulus = size_old
|
|
view_expr.append(ModularIndexing(var, divisor, modulus))
|
|
divisor = divisor * modulus
|
|
while size_hint(size_new) > size_hint(size_old):
|
|
modulus = stack_old.pop()
|
|
view_expr.append(ModularIndexing(var, divisor, modulus))
|
|
divisor = divisor * modulus
|
|
size_old = size_old * modulus
|
|
V.graph.sizevars.guard_equals(size_new, size_old)
|
|
else:
|
|
raise AssertionError()
|
|
|
|
while stack_old:
|
|
size_old = stack_old.pop()
|
|
V.graph.sizevars.guard_equals(size_old, 1)
|
|
view_expr.append(sympy.Integer(0))
|
|
|
|
while stack_new:
|
|
var, size_new = stack_new.pop()
|
|
V.graph.sizevars.guard_equals(size_new, 1)
|
|
|
|
view_expr = list(reversed(view_expr))
|
|
assert len(view_expr) == len(old_size)
|
|
|
|
def reindex(index):
|
|
assert len(index) == len(vars), (len(index), len(vars))
|
|
replacements = dict(zip(vars, index))
|
|
return tuple(sympy_subs(x, replacements) for x in view_expr)
|
|
|
|
return reindex
|
|
|
|
def get_size(self):
|
|
return self.size
|
|
|
|
def make_loader(self):
|
|
def load(index):
|
|
return inner(self.reindex(index))
|
|
|
|
inner = self.data.make_loader()
|
|
return load
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ReinterpretView(BaseView):
|
|
"""Pretend our storage has a different layout"""
|
|
|
|
layout: "Layout"
|
|
|
|
def __post_init__(self):
|
|
if isinstance(self.data, BaseView):
|
|
self.data = self.data.unwrap_view()
|
|
|
|
def __str__(self):
|
|
return self.str_helper(
|
|
[
|
|
self.data,
|
|
self.layout,
|
|
]
|
|
)
|
|
|
|
__repr__ = __str__
|
|
|
|
def get_name(self):
|
|
return self.data.get_name()
|
|
|
|
def get_device(self):
|
|
return self.layout.device
|
|
|
|
def get_dtype(self):
|
|
return self.layout.dtype
|
|
|
|
def get_size(self):
|
|
return list(self.layout.size)
|
|
|
|
def get_stride(self):
|
|
return list(self.layout.stride)
|
|
|
|
def make_loader(self):
|
|
def loader(index):
|
|
indexer = self.layout.make_indexer()
|
|
return ops.load(self.get_name(), indexer(index))
|
|
|
|
return loader
|
|
|
|
def make_indexer(self):
|
|
return self.layout.make_indexer()
|
|
|
|
def get_layout(self):
|
|
return self.layout
|
|
|
|
def freeze_layout(self):
|
|
pass
|
|
|
|
def codegen_reference(self):
|
|
size = V.graph.wrapper_code.codegen_shape_tuple(self.layout.size)
|
|
stride = V.graph.wrapper_code.codegen_shape_tuple(self.layout.stride)
|
|
offset = V.graph.wrapper_code.codegen_sizevar(self.layout.offset)
|
|
namespace = V.graph.wrapper_code.namespace
|
|
if offset != "0":
|
|
return (
|
|
f"{namespace}as_strided({self.get_name()}, {size}, {stride}, {offset})"
|
|
)
|
|
return f"{namespace}as_strided({self.get_name()}, {size}, {stride})"
|
|
|
|
|
|
class SliceView(View):
|
|
@classmethod
|
|
def create(cls, x, dim, start, end, step=1):
|
|
step = sympy.expand(step)
|
|
assert step > 0
|
|
try:
|
|
if start == 0 and end >= 2**63 and step == 1:
|
|
return x
|
|
except TypeError:
|
|
pass
|
|
|
|
sizevars = V.graph.sizevars
|
|
new_size = list(x.get_size())
|
|
|
|
start = cls.handle_negative_index(start, new_size[dim])
|
|
end = cls.handle_negative_index(end, new_size[dim])
|
|
|
|
end = sizevars.guard_min(end, new_size[dim])
|
|
start = sizevars.guard_min(sizevars.guard_min(start, new_size[dim]), end)
|
|
if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1:
|
|
sizevars.guard_equals(end, new_size[dim])
|
|
return x
|
|
|
|
new_size[dim] = FloorDiv(end - start + (step - 1), step)
|
|
|
|
if is_storage_and_layout(x):
|
|
# Fast path
|
|
storage, old_layout = as_storage_and_layout(x)
|
|
new_stride = list(old_layout.stride)
|
|
new_stride[dim] = new_stride[dim] * step
|
|
new_layout = FixedLayout(
|
|
old_layout.device,
|
|
old_layout.dtype,
|
|
new_size,
|
|
new_stride,
|
|
old_layout.offset + old_layout.stride[dim] * start,
|
|
)
|
|
return ReinterpretView(storage, new_layout)
|
|
|
|
def reindex(index):
|
|
assert len(index) == len(new_size), f"wrong ndim {index} {new_size}"
|
|
index = list(index)
|
|
index[dim] = index[dim] * step + start
|
|
return index
|
|
|
|
# redirect to a generic view
|
|
return SliceView(x, size=new_size, reindex=reindex)
|
|
|
|
|
|
class BaseConstant(IRNode):
|
|
def get_size(self):
|
|
return ()
|
|
|
|
def get_dtype(self):
|
|
return self.dtype
|
|
|
|
def get_device(self):
|
|
return self.device
|
|
|
|
def mark_reuse(self, users):
|
|
pass
|
|
|
|
def has_exceeded_max_reads(self):
|
|
return False
|
|
|
|
def get_reads(self):
|
|
return ()
|
|
|
|
def is_extern(self):
|
|
return False
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class Constant(BaseConstant):
|
|
value: Any
|
|
dtype: torch.dtype
|
|
device: torch.device
|
|
|
|
def make_loader(self):
|
|
def loader(index):
|
|
return ops.constant(self.value, self.dtype)
|
|
|
|
return loader
|
|
|
|
def realize(self):
|
|
pass
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class IndexingConstant(BaseConstant):
|
|
index: Any
|
|
dtype: torch.dtype
|
|
device: torch.device
|
|
|
|
def make_loader(self):
|
|
def loader(index):
|
|
return ops.index_expr(self.index, self.dtype)
|
|
|
|
return loader
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class Layout(IRNode):
|
|
def __init__(
|
|
self,
|
|
device: torch.device,
|
|
dtype: torch.dtype,
|
|
size: List[Expr],
|
|
stride: List[Expr],
|
|
offset: Expr = Integer(0),
|
|
):
|
|
assert stride is None or len(size) == len(
|
|
stride
|
|
), f"size={size}, stride={stride}"
|
|
self.device = device
|
|
self.dtype = dtype
|
|
assert all(isinstance(s, (Expr, int)) for s in size)
|
|
self.size = size
|
|
self._stride = stride
|
|
self.offset = offset
|
|
|
|
@property
|
|
def stride(self):
|
|
return self._stride
|
|
|
|
def __str__(self):
|
|
offset = ""
|
|
if self.offset != 0:
|
|
offset = f", offset={self.offset}"
|
|
return (
|
|
f"{type(self).__name__}('{self.device.type}', {self.dtype}, "
|
|
f"size={self.size}, stride={self.stride}{offset})"
|
|
)
|
|
|
|
__repr__ = __str__
|
|
|
|
def is_contiguous(self):
|
|
for left, right, size in zip(
|
|
self.stride, FlexibleLayout.contiguous_strides(self.size), self.size
|
|
):
|
|
if size != 1 and left != right:
|
|
return False
|
|
return True
|
|
|
|
def is_channels_last_contiguous(self):
|
|
ndim = len(self.size)
|
|
if ndim not in [4, 5]:
|
|
return False
|
|
for left, right, size in zip(
|
|
self.stride, make_channels_last_strides_for(self.size), self.size
|
|
):
|
|
if size != 1 and left != right:
|
|
return False
|
|
return True
|
|
|
|
def is_transposed(self):
|
|
for left, right, size in zip(
|
|
self.stride,
|
|
reversed(FlexibleLayout.contiguous_strides(self.size)),
|
|
self.size,
|
|
):
|
|
if size != 1 and left != right:
|
|
return False
|
|
return True
|
|
|
|
def is_stride_ordered(self, order):
|
|
assert len(self.stride) == len(order)
|
|
# reorder the stride given order
|
|
stride_ordered = [None] * len(order)
|
|
for i in range(len(order)):
|
|
stride_ordered[order[i]] = V.graph.sizevars.size_hint(self.stride[i])
|
|
# check if it is in ascending order
|
|
for i in range(len(order) - 1):
|
|
if stride_ordered[i] > stride_ordered[i + 1]:
|
|
return False
|
|
return True
|
|
|
|
def is_channels_last_stride_ordered(self):
|
|
# create channels_last order(NCHW, NCDHW, the C is the first order).
|
|
order = [0] + list(reversed(range(1, len(self.stride) - 1)))
|
|
order = [len(order)] + order
|
|
return self.is_stride_ordered(order)
|
|
|
|
def as_fixed(self):
|
|
return FixedLayout(
|
|
self.device,
|
|
self.dtype,
|
|
self.size,
|
|
self.stride,
|
|
self.offset,
|
|
)
|
|
|
|
def make_indexer(self):
|
|
assert (
|
|
FlexibleLayout.allow_indexing
|
|
), f"convert {type(self).__name__} to FixedLayout first"
|
|
return self.as_fixed().make_indexer()
|
|
|
|
def __eq__(self, other) -> bool:
|
|
return (
|
|
self.device == other.device
|
|
and self.dtype == other.dtype
|
|
and self.size == other.size
|
|
and self.stride == other.stride
|
|
and self.offset == other.offset
|
|
)
|
|
|
|
def storage_size(self) -> sympy.Expr:
|
|
return compute_required_storage_length(self.size, self.stride, self.offset)
|
|
|
|
|
|
class FixedLayout(Layout):
|
|
"""A Tensor layout we cannot change"""
|
|
|
|
def __init__(
|
|
self,
|
|
device: torch.device,
|
|
dtype: torch.dtype,
|
|
size: List[Expr],
|
|
stride: List[Expr] = None,
|
|
offset: Expr = Integer(0),
|
|
):
|
|
if stride is None:
|
|
stride = FlexibleLayout.contiguous_strides(size)
|
|
super().__init__(
|
|
device,
|
|
dtype,
|
|
size,
|
|
stride,
|
|
offset,
|
|
)
|
|
|
|
def make_indexer(self):
|
|
"""A closure containing math to read a given element"""
|
|
|
|
def indexer(index):
|
|
assert len(index) == len(self.stride) == len(self.size)
|
|
result = self.offset
|
|
for idx, stride, sz in zip(index, self.stride, self.size):
|
|
if sz != 1:
|
|
result = result + idx * stride
|
|
return result
|
|
|
|
return indexer
|
|
|
|
|
|
class FlexibleLayout(Layout):
|
|
"""A Tensor layout we are allowed to change"""
|
|
|
|
allow_indexing = False
|
|
|
|
@staticmethod
|
|
def contiguous_strides(sizes):
|
|
if len(sizes) == 0:
|
|
return []
|
|
reversed_strides = [sympy.Integer(1)]
|
|
for size in reversed(sizes[1:]):
|
|
reversed_strides.append(size * reversed_strides[-1])
|
|
return list(reversed(reversed_strides))
|
|
|
|
@staticmethod
|
|
def fill_ordered(sizes, order):
|
|
"""
|
|
Create a stride based on the order the dimensions should be filled in.
|
|
|
|
In this format, channels last would be:
|
|
[1, 3, 2, 0]
|
|
"""
|
|
assert set(range(len(sizes))) == set(order)
|
|
next_stride = sympy.Integer(1)
|
|
strides = [None] * len(order)
|
|
|
|
for i in order:
|
|
strides[i] = next_stride
|
|
next_stride = next_stride * sizes[i]
|
|
return strides
|
|
|
|
@staticmethod
|
|
def stride_ordered(sizes, order):
|
|
"""
|
|
Create a stride based on the sorted order of a permuted range.
|
|
|
|
In this format, channels last would be:
|
|
[3, 0, 2, 1]
|
|
"""
|
|
assert set(range(len(sizes))) == set(order)
|
|
fill_order = stride_order2fill_order(order)
|
|
return FlexibleLayout.fill_ordered(sizes, fill_order)
|
|
|
|
@staticmethod
|
|
def same_ordered(sizes, stride):
|
|
"""
|
|
Create a stride that has the same stride order as given stride
|
|
|
|
For example, if given stride is [1000, 1, 100, 10],
|
|
the fill order should be [1, 3, 2, 0]
|
|
"""
|
|
assert len(sizes) == len(stride)
|
|
stride = [V.graph.sizevars.size_hint(x) for x in stride]
|
|
fill_order = sorted(range(len(stride)), key=stride.__getitem__)
|
|
return FlexibleLayout.fill_ordered(sizes, fill_order)
|
|
|
|
def as_stride_order(self, order):
|
|
return FixedLayout(
|
|
self.device,
|
|
self.dtype,
|
|
self.size,
|
|
self.stride_ordered(self.size, order),
|
|
self.offset,
|
|
)
|
|
|
|
def as_fill_order(self, order):
|
|
return FixedLayout(
|
|
self.device,
|
|
self.dtype,
|
|
self.size,
|
|
self.fill_ordered(self.size, order),
|
|
self.offset,
|
|
)
|
|
|
|
def as_same_order(self, stride):
|
|
return FixedLayout(
|
|
self.device,
|
|
self.dtype,
|
|
self.size,
|
|
self.same_ordered(self.size, stride),
|
|
self.offset,
|
|
)
|
|
|
|
def __init__(self, device, dtype, size, stride_order=None):
|
|
if stride_order:
|
|
strides = FlexibleLayout.fill_ordered(size, stride_order)
|
|
else:
|
|
strides = FlexibleLayout.contiguous_strides(size)
|
|
super().__init__(device, dtype, size, strides)
|
|
|
|
|
|
class AliasedLayout(Layout):
|
|
"""Shares the same storage as another tensor"""
|
|
|
|
def __init__(self, view: "ReinterpretView"):
|
|
layout = view.get_layout()
|
|
super().__init__(
|
|
layout.device,
|
|
layout.dtype,
|
|
layout.size,
|
|
layout.stride,
|
|
)
|
|
self.view = view
|
|
|
|
def make_indexer(self):
|
|
return self.as_fixed().make_indexer()
|
|
|
|
def maybe_guard_aligned(self):
|
|
offset = self.view.get_layout().offset
|
|
if offset == 0:
|
|
return True
|
|
from .compile_fx import ALIGNMENT
|
|
|
|
return V.graph.sizevars.maybe_guard_multiple_of(offset, ALIGNMENT)
|
|
|
|
|
|
class MutationLayout(Layout):
|
|
def __init__(self, target: IRNode):
|
|
super().__init__(
|
|
target.get_device(),
|
|
target.get_dtype(),
|
|
target.get_size(),
|
|
None, # type: ignore[arg-type]
|
|
)
|
|
self.target = target
|
|
|
|
@Layout.stride.getter
|
|
def stride(self):
|
|
return self.real_layout().stride
|
|
|
|
def storage_size(self) -> sympy.Expr:
|
|
return self.real_layout().storage_size()
|
|
|
|
def real_layout(self):
|
|
def unwrap_views(target):
|
|
if isinstance(target, MutationLayout):
|
|
return unwrap_views(target.target)
|
|
if isinstance(target, BaseView):
|
|
return unwrap_views(target.unwrap_view())
|
|
if isinstance(target, MutableBox):
|
|
return unwrap_views(target.data)
|
|
return target
|
|
|
|
return unwrap_views(self.target).layout
|
|
|
|
@classmethod
|
|
def realize_into(cls, src, dst):
|
|
dst.realize()
|
|
V.graph.realize_users_of(dst.get_name())
|
|
|
|
if isinstance(src, TensorBox):
|
|
src = src.data
|
|
|
|
if not isinstance(src, StorageBox) or src.is_user_of(dst.get_name()):
|
|
need_copy = True
|
|
else:
|
|
src.realize()
|
|
need_copy = not isinstance(src.data.layout, FlexibleLayout)
|
|
|
|
if need_copy:
|
|
src = Pointwise.create(
|
|
device=src.get_device(),
|
|
dtype=src.get_dtype(),
|
|
inner_fn=src.make_loader(),
|
|
ranges=[
|
|
V.graph.sizevars.guard_equals(a, b)
|
|
for a, b in zip(src.get_size(), dst.get_size())
|
|
],
|
|
).data
|
|
src.realize()
|
|
|
|
assert isinstance(src.data.layout, FlexibleLayout)
|
|
src.data.layout = MutationLayout(dst)
|
|
return src.data
|
|
|
|
def as_fixed(self):
|
|
return self
|
|
|
|
def make_indexer(self):
|
|
return self.target.make_indexer()
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class Buffer(IRNode):
|
|
name: str
|
|
layout: Layout
|
|
|
|
def make_indexer(self):
|
|
return self.layout.make_indexer()
|
|
|
|
def get_name(self):
|
|
assert self.name
|
|
return self.name
|
|
|
|
def get_device(self):
|
|
return self.layout.device
|
|
|
|
def get_dtype(self):
|
|
return getattr(self.layout, "dtype", None)
|
|
|
|
def get_size(self):
|
|
return list(self.layout.size)
|
|
|
|
def get_stride(self):
|
|
return list(self.layout.stride)
|
|
|
|
def get_layout(self):
|
|
return self.layout
|
|
|
|
def get_storage_numel(self):
|
|
return self.get_numel()
|
|
|
|
def is_extern(self):
|
|
return False
|
|
|
|
def freeze_layout(self):
|
|
if not isinstance(self.layout, (MultiOutputLayout, AliasedLayout)):
|
|
self.layout = self.layout.as_fixed()
|
|
|
|
def freeze_layout_with_stride_order(self, order):
|
|
assert isinstance(self.layout, FlexibleLayout)
|
|
self.layout = self.layout.as_stride_order(order)
|
|
|
|
def freeze_layout_with_fill_order(self, order):
|
|
assert isinstance(self.layout, FlexibleLayout)
|
|
self.layout = self.layout.as_fill_order(order)
|
|
|
|
def freeze_layout_with_same_order(self, stride):
|
|
assert isinstance(self.layout, FlexibleLayout)
|
|
self.layout = self.layout.as_same_order(stride)
|
|
|
|
def make_loader(self):
|
|
def loader(index):
|
|
indexer = self.layout.make_indexer()
|
|
return ops.load(self.name, indexer(index))
|
|
|
|
return loader
|
|
|
|
def is_no_op(self):
|
|
return False
|
|
|
|
def codegen_reference(self):
|
|
return self.get_name()
|
|
|
|
def decide_layout(self):
|
|
pass
|
|
|
|
def get_alias_names(self):
|
|
if isinstance(self.layout, AliasedLayout):
|
|
return [self.layout.view.get_name()]
|
|
return ()
|
|
|
|
def get_mutation_names(self):
|
|
if isinstance(self.layout, MutationLayout):
|
|
return [self.layout.target.get_name()]
|
|
return ()
|
|
|
|
@cache_on_self
|
|
def get_read_writes(self):
|
|
with patch.object(FlexibleLayout, "allow_indexing", True):
|
|
return extract_read_writes(
|
|
self.make_loader(),
|
|
self.get_size(),
|
|
)
|
|
|
|
def get_reads(self):
|
|
return self.get_read_writes().reads
|
|
|
|
def realize(self):
|
|
pass
|
|
|
|
|
|
class InputBuffer(Buffer):
|
|
pass
|
|
|
|
|
|
class ConstantBuffer(InputBuffer):
|
|
override_device = None
|
|
|
|
def make_loader(self):
|
|
def loader(index):
|
|
indexer = self.layout.make_indexer()
|
|
return ops.load(
|
|
V.graph.constant_name(self.name, self.override_device), indexer(index)
|
|
)
|
|
|
|
return loader
|
|
|
|
def constant_to_device(self, device):
|
|
return ConstantBuffer(V.graph.constant_name(self.name, device), self.layout)
|
|
|
|
|
|
class RandSeedBuffer(ConstantBuffer):
|
|
def codegen_reference(self):
|
|
# Clone makes sure if we pass this from forwards to backwards
|
|
# the value does not get clobbered by the time backwards is run.
|
|
return self.get_name() + ".clone()"
|
|
|
|
|
|
class NoneAsConstantBuffer(IRNode):
|
|
def codegen_reference(self):
|
|
return "None"
|
|
|
|
|
|
class ShapeAsConstantBuffer(IRNode):
|
|
def __init__(self, shape):
|
|
super().__init__()
|
|
self.shape = shape
|
|
|
|
def codegen_reference(self):
|
|
from torch._inductor.codegen.wrapper import pexpr
|
|
|
|
expr = pexpr(V.graph.sizevars.simplify(self.shape))
|
|
if V.graph.cpp_wrapper:
|
|
# wrap scalar to 0-d tensor for cpp wrapper
|
|
return f"torch::tensor({expr})"
|
|
else:
|
|
return expr
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ComputedBuffer(Buffer):
|
|
data: Loops
|
|
|
|
@cache_on_self
|
|
def get_read_writes(self):
|
|
with patch.object(FlexibleLayout, "allow_indexing", True):
|
|
if self.data.get_reduction_type():
|
|
return extract_read_writes(
|
|
self.get_store_function(),
|
|
self.data.get_size(),
|
|
self.data.get_reduction_size(),
|
|
)
|
|
else:
|
|
return extract_read_writes(
|
|
self.get_store_function(),
|
|
self.data.get_size(),
|
|
)
|
|
|
|
def get_store_function(self):
|
|
indexer = self.layout.as_fixed().make_indexer()
|
|
if self.data.get_reduction_type():
|
|
return partial(self.data.store_reduction, self.name, indexer)
|
|
else:
|
|
return partial(self.data.store_output, self.name, indexer)
|
|
|
|
def get_fill_order(self):
|
|
"""
|
|
If our layout is still flexible, try to determine the stride order based on stride orders of reads.
|
|
|
|
TODO(jansel): A better algorithm here would look at downstream consumers of this
|
|
value and try to do global graph-level layout optimization.
|
|
This is also something just begging to be autotuned.
|
|
"""
|
|
if isinstance(self.layout, FlexibleLayout):
|
|
_, (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(
|
|
self.data.get_size(), self.data.get_reduction_size()
|
|
)
|
|
reads = self.get_read_writes().reads
|
|
reads_bufs = [
|
|
V.graph.name_to_buffer[r.name]
|
|
if r.name in V.graph.name_to_buffer.keys()
|
|
else None
|
|
for r in reads
|
|
]
|
|
# only consider reads to buffer of same size
|
|
reads = [
|
|
sympy_subs(
|
|
r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0}
|
|
)
|
|
for r in reads
|
|
]
|
|
|
|
if reads:
|
|
stride_lengths = [
|
|
V.graph.sizevars.stride_hints(expr, index_vars) for expr in reads
|
|
]
|
|
from .scheduler import pick_loop_order
|
|
|
|
return pick_loop_order(stride_lengths, self.get_size())
|
|
|
|
return None
|
|
|
|
def decide_layout(self):
|
|
if isinstance(self.layout, FlexibleLayout):
|
|
order = self.get_fill_order()
|
|
if order:
|
|
self.freeze_layout_with_fill_order(order)
|
|
else:
|
|
self.freeze_layout()
|
|
|
|
def simplify_and_reorder(self):
|
|
"""
|
|
This is a main place where we do loop transformations in a
|
|
backend-agnostic way.
|
|
|
|
Here we:
|
|
1) Remove any 1 dimensions
|
|
2) Fuse contiguous dimensions together
|
|
3) Reorder dimensions based on stride orders
|
|
"""
|
|
_, args, var_ranges = dependencies.index_vars_squeeze(
|
|
self.data.get_size(), self.data.get_reduction_size(), prefix="q"
|
|
)
|
|
with patch.object(ConstantBuffer, "override_device", self.get_device()):
|
|
body = LoopBody(
|
|
self.get_store_function(),
|
|
(args if self.get_reduction_type() else args[:1]),
|
|
var_ranges,
|
|
)
|
|
index_formulas = [*body.indexing_exprs.values()]
|
|
reads_bufs = [
|
|
V.graph.name_to_buffer[reads_name]
|
|
if reads_name in V.graph.name_to_buffer.keys()
|
|
else None
|
|
for reads_name in body.reads_name2expr.keys()
|
|
]
|
|
memory_addrs = [
|
|
*body.reads_name2expr.values(),
|
|
*body.writes_name2expr.values(),
|
|
]
|
|
index_vars = []
|
|
reduce_vars = []
|
|
index_size = []
|
|
reduce_size = []
|
|
for v, s in var_ranges.items():
|
|
if v in args[0]:
|
|
assert not reduce_vars
|
|
index_vars.append(v)
|
|
index_size.append(s)
|
|
else:
|
|
assert v in args[1]
|
|
reduce_vars.append(v)
|
|
reduce_size.append(s)
|
|
|
|
# reuse the reordering_reindex recorded by the reads' own simplify_and_reorder
|
|
reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs)
|
|
for i, reads_buf in enumerate(reads_bufs):
|
|
if isinstance(reads_buf, ComputedBuffer) and hasattr(
|
|
reads_buf, "iter_reordering_reindex"
|
|
):
|
|
reordering_reindex[i] = reads_buf.iter_reordering_reindex
|
|
|
|
def simplify_and_reorder(x_vars, sizes, reordering_reindex=None):
|
|
sizes, reindex0, reindex1 = self._apply_loop_reordering(
|
|
x_vars, sizes, memory_addrs, reordering_reindex
|
|
)
|
|
# for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1]
|
|
x_vars = reindex0(x_vars)
|
|
sizes, reindex2, prune = V.graph.sizevars._simplify_loops(
|
|
x_vars,
|
|
sizes,
|
|
index_prevent_reordering(index_formulas, x_vars, sizes),
|
|
)
|
|
x_vars = prune(x_vars)
|
|
# sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas)
|
|
# x_vars = prune(x_vars)
|
|
# sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs)
|
|
reindex = fuse_reindexing(reindex1, reindex2)
|
|
return sizes, reindex, reindex1
|
|
|
|
iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder(
|
|
index_vars, index_size, reordering_reindex
|
|
)
|
|
reduce_ranges, reduce_reindex, _ = simplify_and_reorder(
|
|
reduce_vars, reduce_size
|
|
)
|
|
|
|
# remember the reordering if the loops were not collapsed.
|
|
if len(iter_ranges) == len(index_vars):
|
|
self.iter_reordering_reindex = iter_reordering_reindex
|
|
# retrace the loop body with simplification and reordering applied
|
|
(iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
|
|
iter_ranges, reduce_ranges, prefix="z"
|
|
)
|
|
body = LoopBody(
|
|
body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges
|
|
)
|
|
return (iter_ranges, reduce_ranges), body
|
|
|
|
@staticmethod
|
|
def _apply_loop_reordering(
|
|
index_vars, sizes, memory_addrs, reordering_reindex=None, priority_idx=None
|
|
):
|
|
"""
|
|
Shuffle the order of loops around to hopefully improve performance.
|
|
"""
|
|
from .scheduler import pick_loop_order
|
|
|
|
if priority_idx is None:
|
|
priority_idx = []
|
|
|
|
try:
|
|
strides = [
|
|
V.graph.sizevars.stride_hints(expr, index_vars) for expr in memory_addrs
|
|
]
|
|
assert len(strides) == len(memory_addrs) and len(strides[0]) == len(
|
|
index_vars
|
|
)
|
|
# consider both layout(strides) and reordering(reordering_reindex)
|
|
if reordering_reindex is not None:
|
|
for i in range(len(memory_addrs)):
|
|
try:
|
|
strides[i] = reordering_reindex[i](strides[i])
|
|
# if len(order) != len(strides), do not reorder
|
|
except AssertionError:
|
|
pass
|
|
order = list(reversed(pick_loop_order(strides, sizes, priority_idx)))
|
|
except Exception:
|
|
if config.debug:
|
|
log.warning(
|
|
"Did not simplify complex index:\n%s\n%s",
|
|
dict(zip(index_vars, sizes)),
|
|
memory_addrs,
|
|
)
|
|
order = list(range(len(sizes)))
|
|
sizes = [sizes[i] for i in order]
|
|
return sizes, same_reorder(order), inverse_reorder(order)
|
|
|
|
def get_reduction_size(self):
|
|
return self.data.get_reduction_size()
|
|
|
|
def get_reduction_type(self):
|
|
return self.data.get_reduction_type()
|
|
|
|
def is_no_op(self):
|
|
return self.data.is_zero_elements()
|
|
|
|
def should_allocate(self):
|
|
return True
|
|
|
|
def constant_to_device(self, device):
|
|
"""Move this to a given device. Requires that all reads are to constants."""
|
|
return self.data.constant_to_device(device)
|
|
|
|
|
|
class TemplateBuffer(Buffer):
|
|
"""
|
|
Represents a Triton (and, in the future, other types of) template operator
|
|
that we can fuse an epilogue onto.
|
|
"""
|
|
|
|
def __init__(self, layout, inputs, make_kernel_render):
|
|
super().__init__(name=None, layout=layout)
|
|
self.inputs = InputsKernel.unwrap_storage(inputs)
|
|
self.make_kernel_render = make_kernel_render
|
|
self.name = V.graph.register_buffer(self)
|
|
|
|
def get_read_writes(self):
|
|
return self.normalized_read_writes()
|
|
|
|
@cache_on_self
|
|
def normalized_read_writes(self):
|
|
name = self.get_name()
|
|
indexer = self.layout.make_indexer()
|
|
|
|
def dummy(index, rindex):
|
|
assert len(rindex) == 0
|
|
return ops.store(name, indexer(index), "fake")
|
|
|
|
deps = dependencies.extract_read_writes(
|
|
dummy, self.get_size(), (), normalize=True
|
|
)
|
|
deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}
|
|
return deps
|
|
|
|
def get_reduction_size(self):
|
|
return 1
|
|
|
|
def get_reduction_type(self):
|
|
return None
|
|
|
|
def is_no_op(self):
|
|
return False
|
|
|
|
def should_allocate(self):
|
|
return True
|
|
|
|
def simplify_and_reorder(self):
|
|
return (
|
|
(
|
|
self.get_size(),
|
|
(),
|
|
),
|
|
None,
|
|
)
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class InputsKernel(Buffer):
|
|
inputs: List[Buffer]
|
|
|
|
def get_read_writes(self):
|
|
return dependencies.ReadWrites(
|
|
{dependencies.StarDep(x.get_name()) for x in self.inputs},
|
|
{dependencies.StarDep(self.get_name())},
|
|
set(),
|
|
[],
|
|
None,
|
|
)
|
|
|
|
@staticmethod
|
|
def unwrap_storage(inputs):
|
|
inputs_new = []
|
|
for x in inputs:
|
|
if isinstance(x, TensorBox):
|
|
x = x.data
|
|
if isinstance(x, StorageBox):
|
|
x = x.data
|
|
if isinstance(x, BaseView) and not isinstance(x, ReinterpretView):
|
|
x = ExternKernel.realize_input(x)
|
|
assert isinstance(x, (Buffer, ReinterpretView)), x
|
|
inputs_new.append(x)
|
|
return inputs_new
|
|
|
|
def is_extern(self):
|
|
return True
|
|
|
|
|
|
class NopKernel(InputsKernel):
|
|
def is_no_op(self):
|
|
return True
|
|
|
|
|
|
class ConcatKernel(NopKernel):
|
|
"""
|
|
There isn't actually a real kernel for concat; we just change the
|
|
storage for the upstream data.
|
|
"""
|
|
|
|
@classmethod
|
|
def create(cls, inputs, dim):
|
|
device = inputs[0].get_device()
|
|
dtype = inputs[0].get_dtype()
|
|
new_size = list(inputs[0].get_size())
|
|
offsets_start = [0]
|
|
offsets_end = [new_size[dim]]
|
|
assert 0 <= dim < len(new_size)
|
|
for i in range(1, len(inputs)):
|
|
input_size = inputs[i].get_size()
|
|
offsets_start.append(new_size[dim])
|
|
assert len(input_size) == len(new_size)
|
|
assert inputs[i].get_dtype() == dtype
|
|
assert inputs[i].get_device() == device
|
|
for j in range(len(new_size)):
|
|
if j == dim:
|
|
new_size[j] = new_size[j] + input_size[j]
|
|
else:
|
|
new_size[j] = V.graph.sizevars.guard_equals(
|
|
new_size[j], input_size[j]
|
|
)
|
|
offsets_end.append(new_size[dim])
|
|
|
|
output_stride = FlexibleLayout.contiguous_strides(new_size)
|
|
# If any of the inputs is in CL format, use CL format for the output
|
|
for i in range(len(inputs)):
|
|
x = inputs[i]
|
|
if is_storage_and_layout(x):
|
|
layout = x.get_layout()
|
|
if (
|
|
isinstance(layout, FixedLayout)
|
|
and layout.is_channels_last_contiguous()
|
|
):
|
|
# use CL stride for the output
|
|
output_stride = make_channels_last_strides_for(new_size)
|
|
break
|
|
|
|
kernel = ConcatKernel(
|
|
name=None,
|
|
layout=FixedLayout(
|
|
device=device,
|
|
dtype=dtype,
|
|
size=new_size,
|
|
stride=output_stride,
|
|
),
|
|
inputs=[],
|
|
)
|
|
kernel = StorageBox(kernel)
|
|
for i in range(len(inputs)):
|
|
kernel.data.inputs.append(
|
|
cls.realize_into(
|
|
inputs[i],
|
|
SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]),
|
|
)
|
|
)
|
|
kernel.data.name = V.graph.register_buffer(kernel.data)
|
|
kernel.data.inputs = cls.unwrap_storage(kernel.data.inputs)
|
|
|
|
return kernel
|
|
|
|
@classmethod
|
|
def realize_into(cls, src, dst):
|
|
# Attempt to turn this into a ReinterpretView rather than assert.
|
|
# This has concessions around layout, since as_storage_and_layout
|
|
# can cause us to go from flexible to fixed layout.
|
|
if not isinstance(dst, ReinterpretView):
|
|
if is_storage_and_layout(dst):
|
|
storage, layout = as_storage_and_layout(dst)
|
|
dst = ReinterpretView(storage, layout)
|
|
assert isinstance(dst, ReinterpretView), dst
|
|
if isinstance(src, TensorBox):
|
|
# unwrap a TensorBox
|
|
return cls.realize_into(src.data, dst)
|
|
if isinstance(src, StorageBox):
|
|
src.realize()
|
|
# ExternKernelAlloc has specific requirements for output layout, should create a copy
|
|
if isinstance(src.data.layout, FlexibleLayout) and not isinstance(
|
|
src.data, ExternKernelAlloc
|
|
):
|
|
src.data.layout = AliasedLayout(dst)
|
|
return src.data
|
|
# introduce a copy
|
|
pw = Pointwise.create(
|
|
device=src.get_device(),
|
|
dtype=src.get_dtype(),
|
|
inner_fn=src.make_loader(),
|
|
ranges=[
|
|
V.graph.sizevars.guard_equals(a, b)
|
|
for a, b in zip(src.get_size(), dst.get_size())
|
|
],
|
|
)
|
|
return cls.realize_into(pw, dst)
|
|
|
|
def should_allocate(self):
|
|
return True
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ExternKernel(InputsKernel):
|
|
constant_args: Tuple[Any, ...] = ()
|
|
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
|
output_view: Optional[ReinterpretView] = None
|
|
|
|
def decide_layout(self):
|
|
if isinstance(self.layout, FlexibleLayout):
|
|
self.apply_constraint()
|
|
self.freeze_layout()
|
|
|
|
def codegen(self, wrapper):
|
|
raise NotImplementedError()
|
|
|
|
@staticmethod
|
|
def copy_input(x):
|
|
pw = Pointwise.create(
|
|
device=x.get_device(),
|
|
dtype=x.get_dtype(),
|
|
inner_fn=x.make_loader(),
|
|
ranges=x.get_size(),
|
|
)
|
|
pw.realize()
|
|
return pw
|
|
|
|
@classmethod
|
|
def process_kernel(cls, kernel, *args, **kwargs):
|
|
binded_args = signature(kernel).bind(*args, **kwargs).arguments
|
|
args_flat, args_spec = pytree.tree_flatten(binded_args)
|
|
|
|
is_arg_tensor = []
|
|
tensor_args = []
|
|
non_tensor_args = []
|
|
for arg in args_flat:
|
|
is_arg_tensor.append(isinstance(arg, IRNode))
|
|
if is_arg_tensor[-1]:
|
|
tensor_args.append(arg)
|
|
else:
|
|
if isinstance(arg, sympy.Expr):
|
|
arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
|
|
non_tensor_args.append(arg)
|
|
|
|
def unflatten_args(new_tensor_args, new_non_tensor_args):
|
|
result = []
|
|
it_tensors = iter(new_tensor_args)
|
|
it_non_tensors = iter(new_non_tensor_args)
|
|
for is_tensor in is_arg_tensor:
|
|
if is_tensor:
|
|
result.append(next(it_tensors))
|
|
else:
|
|
result.append(next(it_non_tensors))
|
|
result = pytree.tree_unflatten(result, args_spec)
|
|
return result.get("args", []), result.get("kwargs", {})
|
|
|
|
tensor_args = [cls.realize_input(x) for x in tensor_args]
|
|
|
|
# freeze layout otherwise our output stride calculation might
|
|
# become incorrect
|
|
for x in tensor_args:
|
|
if is_storage_and_layout(x):
|
|
as_storage_and_layout(x, freeze=True)
|
|
|
|
# We don't have generic shape formulas, so just burn in the
|
|
# shapes and run an example input.
|
|
# TODO(jansel): replace this with dynamic shape formulas
|
|
example_args = []
|
|
|
|
# We need to retain the constant values of fake tensors that we originally
|
|
# propagated the graph with, because for some operators running without a
|
|
# constant would trigger an error / DataDependentException
|
|
for x in tensor_args:
|
|
if x.get_name() in V.graph.constants:
|
|
example_args.append(V.graph.constants[x.get_name()])
|
|
else:
|
|
example_args.append(ir_node_to_tensor(x, guard_shape=True))
|
|
|
|
new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
|
|
example_output = kernel(*new_args, **new_kwargs)
|
|
|
|
return example_output, tensor_args, non_tensor_args, unflatten_args
|
|
|
|
@classmethod
|
|
def convert_to_reinterpret_view(cls, x):
|
|
"""
|
|
In order to pass this to an extern kernel we need a
|
|
ReinterpretView not a View. This allows us to avoid some
|
|
unneeded copies.
|
|
"""
|
|
assert isinstance(x, BaseView)
|
|
if isinstance(x, ReinterpretView):
|
|
return x
|
|
|
|
x.unwrap_view().freeze_layout()
|
|
rw = extract_read_writes(x.make_loader(), x.get_size(), normalize=False)
|
|
assert len(rw.reads) == 1
|
|
|
|
index = V.graph.sizevars.simplify_with_ranges(
|
|
list(rw.reads)[0].index, rw.var_ranges
|
|
)
|
|
strides = V.graph.sizevars.stride_vars(index, rw.range_vars)
|
|
offset = V.graph.sizevars.offset_var(index, rw.range_vars)
|
|
expected = sympy_dot(rw.range_vars, strides) + offset
|
|
|
|
if index != expected:
|
|
log.debug(
|
|
"convert_to_reinterpret_view failed: stride=%s offset=%s index=%s",
|
|
strides,
|
|
offset,
|
|
index,
|
|
)
|
|
raise NotImplementedError()
|
|
|
|
return ReinterpretView(
|
|
data=x.data,
|
|
layout=FixedLayout(
|
|
device=x.get_device(),
|
|
dtype=x.get_dtype(),
|
|
size=x.get_size(),
|
|
stride=strides,
|
|
offset=offset,
|
|
),
|
|
)
|
|
|
|
@classmethod
|
|
def realize_input(cls, x):
|
|
if x is None:
|
|
return NoneAsConstantBuffer()
|
|
if isinstance(x, (sympy.Expr, int)):
|
|
return ShapeAsConstantBuffer(x)
|
|
if isinstance(x, Constant):
|
|
return V.graph.add_tensor_constant(
|
|
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
|
|
)
|
|
if isinstance(x, ConstantBuffer):
|
|
return x
|
|
if isinstance(x, TensorBox):
|
|
return cls.realize_input(x.data)
|
|
if isinstance(x, ReinterpretView):
|
|
return x
|
|
if isinstance(x, BaseView):
|
|
x.realize()
|
|
if is_storage_and_layout(x.unwrap_view()) and not isinstance(
|
|
x.unwrap_view().data, ExternKernelAlloc
|
|
):
|
|
try:
|
|
return cls.convert_to_reinterpret_view(x)
|
|
except NotImplementedError:
|
|
pass
|
|
if isinstance(x, StorageBox):
|
|
# TODO(jansel): impose layout preference on realized buffer
|
|
x.realize()
|
|
return x
|
|
return cls.copy_input(x)
|
|
|
|
@classmethod
|
|
def require_stride1(cls, x):
|
|
if is_storage_and_layout(x):
|
|
if len(x.get_stride()) == 0:
|
|
return x
|
|
for stride in x.get_stride():
|
|
if stride == 1:
|
|
return x
|
|
return cls.copy_input(x)
|
|
|
|
@classmethod
|
|
def require_stride_order(cls, x, order):
|
|
if x.get_numel() == 0: # Layout doesn't matter
|
|
return x
|
|
|
|
# require x to have the layout as strided_ordered as order
|
|
if is_storage_and_layout(x):
|
|
if isinstance(x.get_layout(), FlexibleLayout):
|
|
# fix flexiblelayout to be FixedLayout with stride_order
|
|
as_storage_and_layout(
|
|
x, freeze=True, want_contiguous=False, stride_order=order
|
|
)
|
|
return x
|
|
elif isinstance(
|
|
x.get_layout(), FixedLayout
|
|
) and x.get_layout().is_stride_ordered(order):
|
|
return x
|
|
elif isinstance(x.get_layout(), MutationLayout):
|
|
if isinstance(x.get_layout().real_layout(), FlexibleLayout):
|
|
raise AssertionError(
|
|
"the MutationLayout's real layout shouldn't be FlexibleLayout"
|
|
)
|
|
elif isinstance(
|
|
x.get_layout().real_layout(), FixedLayout
|
|
) and x.get_layout().real_layout().is_stride_ordered(order):
|
|
return x
|
|
|
|
# TODO - Storage to InputBuffer
|
|
if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order):
|
|
return x
|
|
x = cls.copy_input(x)
|
|
as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order)
|
|
assert is_stride_order_storage_and_layout(x, order)
|
|
return x
|
|
|
|
@classmethod
|
|
def require_contiguous(cls, x):
|
|
return cls.require_stride_order(x, list(reversed(range(len(x.get_size())))))
|
|
|
|
def apply_constraint(self):
|
|
pass
|
|
|
|
def codegen_args(self):
|
|
args = [x.codegen_reference() for x in self.inputs]
|
|
args.extend(map(repr, self.constant_args))
|
|
return args
|
|
|
|
def cpp_wrapper_codegen_args(self):
|
|
args = [x.codegen_reference() for x in self.inputs]
|
|
args.extend(self.cpp_constant_args)
|
|
return args
|
|
|
|
def codegen_kwargs(self):
|
|
kwargs = []
|
|
if self.kwargs:
|
|
if V.graph.cpp_wrapper:
|
|
for arg_name in self.ordered_kwargs_for_cpp_kernel:
|
|
assert arg_name in self.kwargs, (
|
|
"arg %s not found in self.kwargs" % arg_name
|
|
)
|
|
v = self.kwargs.get(arg_name)
|
|
kwargs.append(repr(v))
|
|
else:
|
|
kwargs = [f"{k}={repr(v)}" for k, v in self.kwargs.items()]
|
|
return kwargs
|
|
|
|
def codegen_size_asserts(self, wrapper):
|
|
if config.size_asserts:
|
|
size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
|
|
stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
|
|
wrapper.writeline(
|
|
f"assert_size_stride({self.get_name()}, {size}, {stride})"
|
|
)
|
|
|
|
def get_group_stride(self):
|
|
"""
|
|
get output sizes and strides, for template_codegen
|
|
"""
|
|
_size = self.get_size()
|
|
_stride = self.get_stride()
|
|
# iter_ranges = _size of output tensor, reduce_range = [] because no reduction
|
|
return [_size, []], _stride
|
|
|
|
def canonicalize(self):
|
|
"""
|
|
Manually get the canonicalization of the output index
|
|
"""
|
|
# manually generate index formula for conv
|
|
sizevars = V.graph.sizevars
|
|
sizes = self.get_size()
|
|
strides = self.get_stride()
|
|
strides = [sizevars.size_hint(x) for x in strides]
|
|
index_vars = [sympy_symbol(f"d{i}") for i in range(len(sizes))]
|
|
# reorder index vars according to stride
|
|
index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
|
|
lookup = {pos: idx for idx, pos in enumerate(index_order)}
|
|
order = [lookup[i] for i in range(len(lookup))]
|
|
index_vars = [index_vars[i] for i in order]
|
|
indexer = self.make_indexer()
|
|
index = indexer(index_vars)
|
|
|
|
new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
|
|
index_vars, sizes, [index]
|
|
)
|
|
|
|
# assign new variables for each dimension to deal with numbering mismatches
|
|
# d0, d1, d2 could become d0, d2 -- which won't match d0, d1
|
|
_, add_var = var_builder("c")
|
|
replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
|
|
|
|
index = sympy_subs(sympy.expand(index), replacement)
|
|
return index, tuple(new_sizes)
|
|
|
|
def __str__(self):
|
|
lines = [
|
|
f"{field.name}={getattr(self, field.name)}"
|
|
for field in dataclasses.fields(self)
|
|
]
|
|
return self.str_helper(lines)
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ExternKernelOut(ExternKernel):
|
|
output_view: Optional[ReinterpretView] = None
|
|
|
|
def codegen(self, wrapper):
|
|
args = self.codegen_args()
|
|
kwargs = self.codegen_kwargs()
|
|
if kwargs:
|
|
args.extend(kwargs)
|
|
|
|
wrapper.generate_extern_kernel_out(
|
|
self.output_view,
|
|
self.codegen_reference(),
|
|
args,
|
|
self.kernel,
|
|
self.cpp_kernel,
|
|
)
|
|
|
|
def __init__(
|
|
self,
|
|
layout,
|
|
inputs,
|
|
constant_args=(),
|
|
kwargs=None,
|
|
output_view=None,
|
|
kernel=None,
|
|
cpp_kernel=None,
|
|
ordered_kwargs_for_cpp_kernel=(),
|
|
):
|
|
super().__init__(
|
|
None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
|
|
)
|
|
self.output_view = output_view
|
|
self.cpp_kernel = cpp_kernel
|
|
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
|
|
self.name = V.graph.register_buffer(self)
|
|
if kernel is not None:
|
|
self.kernel = kernel
|
|
|
|
def should_allocate(self):
|
|
return True
|
|
|
|
|
|
class ExternKernelAlloc(ExternKernel):
|
|
def codegen(self, wrapper):
|
|
args = [*self.codegen_args(), *self.codegen_kwargs()]
|
|
wrapper.writeline(f"{self.get_name()} = {self.kernel}({', '.join(args)})")
|
|
if isinstance(self.layout, Layout):
|
|
self.codegen_size_asserts(wrapper)
|
|
|
|
def __init__(
|
|
self,
|
|
layout,
|
|
inputs,
|
|
constant_args=(),
|
|
kwargs=None,
|
|
kernel=None,
|
|
cpp_kernel=None,
|
|
ordered_kwargs_for_cpp_kernel=(),
|
|
cpp_constant_args=(),
|
|
):
|
|
super().__init__(
|
|
None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
|
|
)
|
|
self.cpp_kernel = cpp_kernel
|
|
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
|
|
self.cpp_constant_args = cpp_constant_args
|
|
self.name = V.graph.register_buffer(self)
|
|
if kernel is not None:
|
|
self.kernel = kernel
|
|
|
|
def should_allocate(self):
|
|
return False
|
|
|
|
def apply_constraint(self):
|
|
raise NotImplementedError
|
|
|
|
|
|
class InplaceBernoulliFallback(ExternKernel):
|
|
"""
|
|
This needs to be a custom class to handle mutation properly
|
|
"""
|
|
|
|
kernel = "aten.bernoulli_"
|
|
|
|
def codegen(self, wrapper):
|
|
(x,) = [t.codegen_reference() for t in self.inputs]
|
|
wrapper.writeline(
|
|
f"{self.kernel}({x}, {', '.join(map(repr, self.constant_args))})"
|
|
)
|
|
|
|
def should_allocate(self):
|
|
return False
|
|
|
|
def get_mutation_names(self):
|
|
assert isinstance(self.layout, MutationLayout)
|
|
return (self.layout.target.get_name(),)
|
|
|
|
def __init__(self, x, *constant_args):
|
|
super().__init__(
|
|
None,
|
|
MutationLayout(x),
|
|
self.unwrap_storage([x]),
|
|
constant_args,
|
|
)
|
|
self.name = V.graph.register_buffer(self)
|
|
|
|
|
|
class ScatterFallback(ExternKernel):
|
|
"""
|
|
This needs to be a custom class to handle mutation properly.
|
|
This class handles both aten.scatter_ and aten.scatter_reduce_.
|
|
It also handles the case of `src` being a scalar properly.
|
|
"""
|
|
|
|
def codegen(self, wrapper):
|
|
if self.src_is_tensor:
|
|
(x, index, src) = [t.codegen_reference() for t in self.inputs]
|
|
else:
|
|
(x, index) = [t.codegen_reference() for t in self.inputs]
|
|
src = self.constant_args[1]
|
|
line = f"{self.kernel}({x}, {self.constant_args[0]}, {index}, {src}"
|
|
if self.kernel == "aten.scatter_":
|
|
if self.kwargs["reduce"]:
|
|
line += f", reduce={repr(self.kwargs['reduce'])}"
|
|
else:
|
|
line += ", ".join([""] + self.codegen_kwargs())
|
|
line += ")"
|
|
wrapper.writeline(line)
|
|
|
|
def should_allocate(self):
|
|
return False
|
|
|
|
def __init__(
|
|
self,
|
|
fn,
|
|
x,
|
|
dim: int,
|
|
index,
|
|
src,
|
|
*,
|
|
reduce: str = None,
|
|
include_self: bool = True,
|
|
):
|
|
assert fn in {"aten.scatter_", "aten.scatter_reduce_"}
|
|
self.kernel = fn
|
|
self.src_is_tensor = isinstance(src, TensorBox)
|
|
if self.src_is_tensor:
|
|
tensors = [self.realize_input(t) for t in [x, index, src]]
|
|
constant_args = [dim]
|
|
else:
|
|
tensors = [self.realize_input(t) for t in [x, index]]
|
|
constant_args = [dim, src]
|
|
super().__init__(
|
|
None,
|
|
MutationLayout(x),
|
|
self.unwrap_storage(tensors),
|
|
constant_args,
|
|
{"reduce": reduce, "include_self": include_self},
|
|
)
|
|
self.name = V.graph.register_buffer(self)
|
|
|
|
|
|
class IndexPutFallback(ExternKernel):
|
|
"""
|
|
This needs to be a custom class to handle mutation and indices properly
|
|
"""
|
|
|
|
kernel = "aten.index_put_"
|
|
|
|
def codegen(self, wrapper):
|
|
(x, values, *valid_indices) = [t.codegen_reference() for t in self.inputs]
|
|
indices = []
|
|
iter_valid_indices = iter(valid_indices)
|
|
for i, _ in enumerate(self.indices):
|
|
if self.indices[i] is not None:
|
|
indices.append(next(iter_valid_indices))
|
|
else:
|
|
indices.append("None")
|
|
wrapper.writeline(
|
|
f"{self.kernel}({x}, [{','.join(indices)}], {values}, {repr(self.constant_args[0])})"
|
|
)
|
|
|
|
def should_allocate(self):
|
|
return False
|
|
|
|
def __init__(self, x, indices, values, accumulate):
|
|
self.indices = indices
|
|
valid_indices = [i for i in indices if i is not None]
|
|
tensors = [self.realize_input(x) for x in [x, values, *valid_indices]]
|
|
super().__init__(
|
|
None,
|
|
MutationLayout(x),
|
|
self.unwrap_storage(tensors),
|
|
[accumulate],
|
|
)
|
|
self.name = V.graph.register_buffer(self)
|
|
|
|
|
|
class DeviceCopy(ExternKernelOut):
|
|
@classmethod
|
|
def create(cls, x, device):
|
|
if not x.is_extern() and all(
|
|
(r.name in V.graph.constants and hasattr(r, "index")) for r in x.get_reads()
|
|
):
|
|
return x.constant_to_device(device)
|
|
|
|
V.graph.device_types.add(device.type)
|
|
V.graph.add_device_idx(device.index)
|
|
V.graph.device_types.add(x.get_device().type)
|
|
V.graph.add_device_idx(x.get_device().index)
|
|
|
|
developer_warning("DeviceCopy in input program")
|
|
return DeviceCopy(
|
|
FlexibleLayout(
|
|
device=device,
|
|
dtype=x.get_dtype(),
|
|
size=x.get_size(),
|
|
),
|
|
[cls.realize_input(x)],
|
|
)
|
|
|
|
def codegen(self, wrapper):
|
|
args = self.codegen_args()
|
|
assert len(args) == 1
|
|
if self.output_view:
|
|
wrapper.writeline(
|
|
f"{self.output_view.codegen_reference()}.copy_({args[0]})"
|
|
)
|
|
else:
|
|
wrapper.writeline(f"{self.codegen_reference()}.copy_({args[0]})")
|
|
|
|
|
|
class DynamicScalar(IRNode):
|
|
"""
|
|
The result of a call to aten._local_scalar_dense.
|
|
|
|
This is not yet implemented. The one model (so far) that calls this
|
|
(fastNLP_Bert) does not actually use the result. So we expect this
|
|
node to get dead code eliminated.
|
|
"""
|
|
|
|
def get_reads(self):
|
|
return ()
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class FallbackKernel(ExternKernelAlloc):
|
|
def __init__(
|
|
self,
|
|
layout,
|
|
kernel,
|
|
tensor_args,
|
|
nontensor_args,
|
|
unflatten_args,
|
|
kwargs=None,
|
|
):
|
|
super().__init__(
|
|
layout,
|
|
tuple(tensor_args),
|
|
tuple(nontensor_args),
|
|
)
|
|
if getattr(torch.ops.aten, kernel.__name__, None) is kernel:
|
|
self.kernel = f"aten.{kernel.__name__}"
|
|
else:
|
|
self.kernel = (
|
|
f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
|
|
)
|
|
self.unflatten_args = unflatten_args
|
|
self.kwargs = {} if kwargs is None else kwargs
|
|
V.graph.warn_fallback(self.kernel)
|
|
|
|
def codegen_args(self):
|
|
@dataclasses.dataclass
|
|
class Shim:
|
|
ref: Any
|
|
|
|
def __repr__(self):
|
|
return self.ref
|
|
|
|
def gen_kwarg(k, v):
|
|
return f"{k}={repr(v)}"
|
|
|
|
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
|
|
constant_args = [Shim(repr(x)) for x in self.constant_args]
|
|
args, kwargs = self.unflatten_args(tensor_args, constant_args)
|
|
return list(map(repr, args)) + [gen_kwarg(k, v) for k, v in kwargs.items()]
|
|
|
|
@classmethod
|
|
def create(cls, kernel, *args, **kwargs):
|
|
fake_incorrect_kernels = (
|
|
aten._fft_r2c.default,
|
|
aten._fft_r2c.out,
|
|
aten._fft_c2r.default,
|
|
aten._fft_c2c.default,
|
|
aten._fft_c2c.out,
|
|
aten._linalg_svd.default,
|
|
aten._linalg_svd.U,
|
|
aten._fused_moving_avg_obs_fq_helper_functional,
|
|
)
|
|
context = (
|
|
V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext()
|
|
)
|
|
with context:
|
|
(
|
|
example_output,
|
|
tensor_args,
|
|
non_tensor_args,
|
|
unflatten_args,
|
|
) = cls.process_kernel(kernel, *args, **kwargs)
|
|
|
|
assert tensor_args or isinstance(
|
|
example_output, torch.Tensor
|
|
), "Not sure where to find device info"
|
|
packed = FallbackKernel(
|
|
MultiOutputLayout(
|
|
tensor_args[0].get_device() if tensor_args else example_output.device
|
|
),
|
|
kernel,
|
|
tensor_args,
|
|
non_tensor_args,
|
|
unflatten_args,
|
|
)
|
|
|
|
def generate_output(output, index=""):
|
|
if isinstance(output, (list, tuple)):
|
|
return type(output)(
|
|
generate_output(output[i], f"{index}[{i}]")
|
|
for i in range(len(output))
|
|
)
|
|
elif isinstance(output, torch.Tensor):
|
|
return MultiOutput(
|
|
FixedLayout(
|
|
output.device,
|
|
output.dtype,
|
|
convert_shape_to_inductor(output.size()),
|
|
convert_shape_to_inductor(output.stride()),
|
|
),
|
|
packed,
|
|
index,
|
|
)
|
|
elif isinstance(output, int):
|
|
return output
|
|
else:
|
|
assert output is None, "FallbackKernel output type is not supported"
|
|
return None
|
|
|
|
return generate_output(example_output)
|
|
|
|
def apply_constraint(self):
|
|
return super().apply_constraint()
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class MultiOutputLayout(IRNode):
|
|
device: torch.device
|
|
|
|
|
|
class MultiOutput(ExternKernel):
|
|
def codegen(self, wrapper):
|
|
wrapper.writeline(
|
|
f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}"
|
|
)
|
|
self.codegen_size_asserts(wrapper)
|
|
|
|
def __init__(self, layout, input, index: str):
|
|
super().__init__(None, layout, [input], ())
|
|
self.name = V.graph.register_buffer(self)
|
|
self.index = index
|
|
|
|
def should_allocate(self):
|
|
return False
|
|
|
|
|
|
def _string(shape: tuple):
|
|
from .codegen.wrapper import CppWrapperCodeGen
|
|
|
|
cpp_wrapper_codegen = CppWrapperCodeGen()
|
|
return cpp_wrapper_codegen.codegen_shape_tuple(shape)
|
|
|
|
|
|
def _prepare_convolution_fusion_create(
|
|
cls,
|
|
x: "TensorBox",
|
|
weight: "TensorBox",
|
|
bias: "TensorBox",
|
|
padding_: List[int],
|
|
stride_: List[int],
|
|
dilation_: List[int],
|
|
groups: int,
|
|
transposed: bool = False,
|
|
output_padding_: List[int] = None,
|
|
):
|
|
"""
|
|
This function is a helper to prepare inputs, layout and constant args
for a convolution post-op fusion's create function: it decides the output
layout (channels first or channels last) and realizes the inputs with the
required stride order. It only supports the CPU device since the conv
post-op fusion kernels are only supported on CPU right now.
|
|
"""
|
|
|
|
# Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size
|
|
def _conv_input_size(
|
|
output_size, weight_size, padding, output_padding, stride, dilation, groups
|
|
):
|
|
assert len(output_size) == len(weight_size), "Expect input dim == weight dim"
|
|
dim = len(output_size)
|
|
assert dim > 2, "Expect input dim > 2"
|
|
|
|
BATCH_DIM = 0
|
|
WEIGHT_INPUT_CHANNELS_DIM = 1
|
|
input_size = []
|
|
input_size.append(output_size[BATCH_DIM])
|
|
input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups)
|
|
for d in range(2, dim):
|
|
kernel = (weight_size[d] - 1) * dilation[d - 2] + 1
|
|
input_size_d = (
|
|
(output_size[d] - 1) * stride[d - 2]
|
|
- (padding[d - 2] * 2)
|
|
+ kernel
|
|
+ output_padding[d - 2]
|
|
)
|
|
input_size.append(input_size_d)
|
|
return list(map(int, input_size))
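# Worked example of the formula above (hypothetical values, not from the
# original source): with output_size[d]=4, stride=2, padding=1, dilation=1,
# a kernel extent of 3 and output_padding=0, the input extent is
# (4 - 1) * 2 - 2 * 1 + ((3 - 1) * 1 + 1) + 0 = 7.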
|
|
|
|
# The size of prepacked_weight is the prepacked weight size of deconv:
|
|
# Groups > 1: [g*o, i/g, ...]
|
|
# Groups == 1: [o, i, ...]
|
|
# Returns original weight size in [i, o, ...]
|
|
def _original_deconv_weight_size(
|
|
prepacked_weight,
|
|
groups,
|
|
):
|
|
prepacked_weight_size = prepacked_weight.size()
|
|
dim = len(prepacked_weight_size)
|
|
assert dim > 2, "Expect weight dim > 2"
|
|
if groups > 1:
|
|
weight_size = []
|
|
weight_size.append(prepacked_weight_size[1] * groups)
|
|
weight_size.append(prepacked_weight_size[0] / groups)
|
|
for d in range(2, dim):
|
|
weight_size.append(prepacked_weight_size[d])
|
|
else:
|
|
weight_size = prepacked_weight.transpose(0, 1).size()
|
|
return weight_size
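# Illustration (hypothetical sizes, not from the original source): with
# groups=2 and a prepacked weight of size [6, 4, 3, 3] (i.e. g*o=6, i/g=4),
# this recovers i = 4 * 2 = 8 and o = 6 / 2 = 3, giving the original
# [i, o, ...] layout [8, 3, 3, 3].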
|
|
|
|
stride = tuple(stride_)
|
|
padding = tuple(padding_)
|
|
dilation = tuple(dilation_)
|
|
assert isinstance(groups, int)
|
|
output_padding = tuple(output_padding_) if output_padding_ else (0, 0)
|
|
x.realize()
|
|
weight.realize()
|
|
with V.graph.fake_mode:
|
|
x_fake = ir_node_to_tensor(x, guard_shape=True)
|
|
weight_fake = ir_node_to_tensor(weight, guard_shape=True)
|
|
if transposed:
|
|
# When transposed, the size of the prepacked oneDNN weight is different
|
|
# from the PyTorch weight. We're not able to run aten conv with such
|
|
# size. We infer the output size from the input params here:
|
|
weight_size = _original_deconv_weight_size(weight_fake, groups)
|
|
input_size = x_fake.size()
|
|
output_size = _conv_input_size(
|
|
input_size,
|
|
weight_size,
|
|
padding,
|
|
output_padding,
|
|
stride,
|
|
dilation,
|
|
groups,
|
|
)
|
|
else:
|
|
bias_fake = (
|
|
ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
|
|
)
|
|
output = torch.ops.aten.convolution(
|
|
x_fake,
|
|
weight_fake,
|
|
bias_fake,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
transposed,
|
|
output_padding,
|
|
groups,
|
|
)
|
|
output_size = output.size()
|
|
|
|
req_stride_order = [0] + list(reversed(range(1, len(stride) + 1)))
|
|
req_stride_order = [len(req_stride_order)] + req_stride_order
|
|
output_stride = make_channels_last_strides_for(output_size)
|
|
|
|
x = cls.require_stride_order(x, req_stride_order)
|
|
assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
|
|
inputs = [x, weight]
|
|
|
|
kernel_layout = FixedLayout(
|
|
x.get_device(),
|
|
x.get_dtype(),
|
|
convert_shape_to_inductor(output_size),
|
|
convert_shape_to_inductor(output_stride),
|
|
)
|
|
constant_args = [padding, stride, dilation, groups]
|
|
if transposed:
|
|
constant_args.insert(1, output_padding)
|
|
|
|
if bias is not None:
|
|
inputs.append(bias)
|
|
else:
|
|
constant_args.insert(0, bias)
|
|
return inputs, constant_args, kernel_layout, req_stride_order
|
|
|
|
|
|
class ConvolutionUnary(ExternKernelAlloc):
|
|
kernel = "torch.ops.mkldnn._convolution_pointwise"
|
|
|
|
def __init__(
|
|
self,
|
|
layout,
|
|
inputs,
|
|
constant_args=(),
|
|
kernel="torch.ops.mkldnn._convolution_pointwise",
|
|
):
|
|
super().__init__(layout, inputs, constant_args)
|
|
self.kernel = kernel
|
|
|
|
def codegen(self, wrapper):
|
|
wrapper.writeline(
|
|
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
|
|
)
|
|
if isinstance(self.layout, Layout):
|
|
self.codegen_size_asserts(wrapper)
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
x: "TensorBox",
|
|
weight: "TensorBox",
|
|
bias: "TensorBox",
|
|
padding_: List[int],
|
|
stride_: List[int],
|
|
dilation_: List[int],
|
|
groups: int,
|
|
attr,
|
|
scalars,
|
|
algorithm,
|
|
):
|
|
kernel = "torch.ops.mkldnn._convolution_pointwise"
|
|
(inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
|
|
cls, x, weight, bias, padding_, stride_, dilation_, groups
|
|
)
|
|
constant_args = constant_args + [attr, scalars, algorithm]
|
|
return ConvolutionUnary(
|
|
layout=kernel_layout,
|
|
inputs=inputs,
|
|
constant_args=constant_args,
|
|
kernel=kernel,
|
|
)
|
|
|
|
|
|
class ConvolutionBinary(ExternKernelAlloc):
|
|
kernel = "torch.ops.mkldnn._convolution_pointwise.binary"
|
|
|
|
def __init__(
|
|
self,
|
|
layout,
|
|
inputs,
|
|
constant_args=(),
|
|
kernel="torch.ops.mkldnn._convolution_pointwise.binary",
|
|
):
|
|
super().__init__(layout, inputs, constant_args)
|
|
self.kernel = kernel
|
|
|
|
def codegen(self, wrapper):
|
|
wrapper.writeline(
|
|
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
|
|
)
|
|
if isinstance(self.layout, Layout):
|
|
self.codegen_size_asserts(wrapper)
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
x: "TensorBox",
|
|
other: "TensorBox",
|
|
weight: "TensorBox",
|
|
bias: "TensorBox",
|
|
padding_: List[int],
|
|
stride_: List[int],
|
|
dilation_: List[int],
|
|
groups: int,
|
|
binary_attr: str,
|
|
binary_alpha: Optional[float],
|
|
unary_attr: Optional[str],
|
|
unary_scalars: Optional[List],
|
|
unary_algorithm: Optional[str],
|
|
):
|
|
kernel = "torch.ops.mkldnn._convolution_pointwise.binary"
|
|
(
|
|
inputs,
|
|
constant_args,
|
|
kernel_layout,
|
|
req_stride_order,
|
|
) = _prepare_convolution_fusion_create(
|
|
cls, x, weight, bias, padding_, stride_, dilation_, groups
|
|
)
|
|
other = cls.require_stride_order(other, req_stride_order)
|
|
inputs.insert(1, other)
|
|
constant_args = constant_args + [
|
|
binary_attr,
|
|
binary_alpha,
|
|
unary_attr,
|
|
unary_scalars,
|
|
unary_algorithm,
|
|
]
|
|
return ConvolutionBinary(
|
|
layout=kernel_layout,
|
|
inputs=inputs,
|
|
constant_args=constant_args,
|
|
kernel=kernel,
|
|
)
|
|
|
|
|
|
class ConvolutionBinaryInplace(ExternKernelAlloc):
|
|
kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
|
|
|
|
def __init__(
|
|
self,
|
|
kernel_layout,
|
|
inputs,
|
|
constant_args=(),
|
|
kernel="torch.ops.mkldnn._convolution_pointwise_.binary",
|
|
):
|
|
super().__init__(kernel_layout, inputs, constant_args)
|
|
self.kernel = kernel
|
|
|
|
def codegen(self, wrapper):
|
|
wrapper.writeline(
|
|
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
|
|
)
|
|
|
|
def get_mutation_names(self):
|
|
assert isinstance(self.layout, MutationLayout)
|
|
return (self.layout.target.get_name(),)
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
x: "TensorBox",
|
|
other: "TensorBox",
|
|
weight: "TensorBox",
|
|
bias: "TensorBox",
|
|
padding_: List[int],
|
|
stride_: List[int],
|
|
dilation_: List[int],
|
|
groups: int,
|
|
binary_attr: str,
|
|
binary_alpha: Optional[float],
|
|
unary_attr: Optional[str],
|
|
unary_scalars: Optional[List],
|
|
unary_algorithm: Optional[str],
|
|
):
|
|
kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
|
|
(inputs, constant_args, _, _) = _prepare_convolution_fusion_create(
|
|
cls, x, weight, bias, padding_, stride_, dilation_, groups
|
|
)
|
|
other = cls.realize_input(other)
|
|
V.graph.realize_users_of(other.get_name())
|
|
inputs.insert(1, other)
|
|
constant_args = constant_args + [
|
|
binary_attr,
|
|
binary_alpha,
|
|
unary_attr,
|
|
unary_scalars,
|
|
unary_algorithm,
|
|
]
|
|
return ConvolutionBinaryInplace(
|
|
kernel_layout=MutationLayout(inputs[1]),
|
|
inputs=inputs,
|
|
constant_args=constant_args,
|
|
kernel=kernel,
|
|
)
|
|
|
|
|
|
class MKLPackedLinear(ExternKernelAlloc):
|
|
kernel = "torch.ops.mkl._mkl_linear"
|
|
|
|
def __init__(
|
|
self,
|
|
layout,
|
|
inputs,
|
|
constant_args=(),
|
|
cpp_constant_args=(),
|
|
kernel="torch.ops.mkl._mkl_linear",
|
|
cpp_kernel="mkl::_mkl_linear",
|
|
):
|
|
super().__init__(layout, inputs, constant_args, None, kernel, cpp_kernel)
|
|
self.cpp_kernel_key = "mkl_linear"
|
|
self.cpp_op_schema = """
|
|
at::Tensor(
|
|
const at::Tensor& self,
|
|
const at::Tensor& mkl_weight_t,
|
|
const at::Tensor& origin_weight_t,
|
|
const c10::optional<at::Tensor>& bias_opt,
|
|
const int64_t prepack_batch_size)"""
|
|
self.cpp_constant_args = cpp_constant_args
|
|
|
|
def codegen(self, wrapper):
|
|
from torch._inductor.codegen.wrapper import CppWrapperCodeGen
|
|
|
|
if isinstance(wrapper, CppWrapperCodeGen):
|
|
args = self.cpp_wrapper_codegen_args()
|
|
else:
|
|
args = self.codegen_args()
|
|
|
|
wrapper.generate_fusion_ops_code(
|
|
self.get_name(),
|
|
self.kernel,
|
|
self.cpp_kernel,
|
|
args,
|
|
self.cpp_op_schema,
|
|
self.cpp_kernel_key,
|
|
)
|
|
|
|
@classmethod
|
|
def create(cls, x, packed_w, orig_w, batch_size):
|
|
kernel = "torch.ops.mkl._mkl_linear"
|
|
|
|
x = cls.require_stride1(cls.realize_input(x))
|
|
orig_w = cls.require_stride1(cls.realize_input(orig_w))
|
|
*m, _ = x.get_size()
|
|
oc, _ = orig_w.get_size()
|
|
output_size = list(m) + [oc]
|
|
output_stride = make_contiguous_strides_for(output_size)
|
|
inputs = [x, packed_w, orig_w]
|
|
bias = None
|
|
cpp_bias = "at::Tensor()"
|
|
constant_args = [bias, batch_size]
|
|
cpp_constant_args = [cpp_bias, str(batch_size)]
|
|
|
|
return MKLPackedLinear(
|
|
layout=FixedLayout(
|
|
x.get_device(), x.get_dtype(), output_size, output_stride
|
|
),
|
|
inputs=inputs,
|
|
constant_args=constant_args,
|
|
cpp_constant_args=cpp_constant_args,
|
|
kernel=kernel,
|
|
)
|
|
|
|
|
|
class LinearUnary(ExternKernelAlloc):
|
|
kernel = "torch.ops.mkldnn._linear_pointwise"
|
|
|
|
def __init__(
|
|
self,
|
|
layout,
|
|
inputs,
|
|
constant_args=(),
|
|
kernel="torch.ops.mkldnn._linear_pointwise",
|
|
cpp_kernel="mkldnn::_linear_pointwise",
|
|
cpp_constant_args=(),
|
|
):
|
|
super().__init__(layout, inputs, constant_args, None, kernel, cpp_kernel)
|
|
self.cpp_kernel_key = "linear_pointwise"
|
|
self.cpp_op_schema = """
|
|
at::Tensor(
|
|
const at::Tensor& input_t,
|
|
const at::Tensor& weight_t,
|
|
const c10::optional<at::Tensor>& bias_opt,
|
|
c10::string_view attr,
|
|
torch::List<c10::optional<at::Scalar>> scalars,
|
|
c10::optional<c10::string_view> algorithm)"""
|
|
self.cpp_constant_args = cpp_constant_args
|
|
|
|
def codegen(self, wrapper):
|
|
from torch._inductor.codegen.wrapper import CppWrapperCodeGen
|
|
|
|
if isinstance(wrapper, CppWrapperCodeGen):
|
|
args = self.cpp_wrapper_codegen_args()
|
|
else:
|
|
args = self.codegen_args()
|
|
|
|
wrapper.generate_fusion_ops_code(
|
|
self.get_name(),
|
|
self.kernel,
|
|
self.cpp_kernel,
|
|
args,
|
|
self.cpp_op_schema,
|
|
self.cpp_kernel_key,
|
|
)
|
|
|
|
@classmethod
|
|
def create(cls, x, w, b, attr, scalars, algorithm):
|
|
kernel = "torch.ops.mkldnn._linear_pointwise"
|
|
x = cls.require_stride1(cls.realize_input(x))
|
|
w = cls.require_stride1(cls.realize_input(w))
|
|
|
|
*m, ic = x.get_size()
|
|
oc, ic = w.get_size()
|
|
|
|
inputs = [x, w]
|
|
constant_args = [attr, scalars, algorithm]
|
|
cpp_constant_args = [
|
|
f'"{attr}"',
|
|
_string(scalars) if scalars else "{-1}",
|
|
f'"{algorithm}"',
|
|
]
|
|
if b is not None:
|
|
b = cls.require_stride1(cls.realize_input(b))
|
|
inputs.append(b)
|
|
else:
|
|
constant_args.insert(0, b)
|
|
cpp_constant_args.insert(0, "at::Tensor()")
|
|
|
|
return LinearUnary(
|
|
layout=FlexibleLayout(
|
|
device=x.get_device(),
|
|
dtype=x.get_dtype(),
|
|
size=list(m) + [oc],
|
|
),
|
|
inputs=inputs,
|
|
constant_args=constant_args,
|
|
cpp_constant_args=cpp_constant_args,
|
|
kernel=kernel,
|
|
)
|
|
|
|
def apply_constraint(self):
|
|
pass
|
|
|
|
|
|
class LinearBinary(ExternKernelAlloc):
|
|
kernel = "torch.ops.mkldnn._linear_pointwise.binary"
|
|
|
|
def __init__(
|
|
self,
|
|
layout,
|
|
inputs,
|
|
constant_args=(),
|
|
kernel="torch.ops.mkldnn._linear_pointwise.binary",
|
|
):
|
|
super().__init__(layout, inputs, constant_args)
|
|
self.kernel = kernel
|
|
|
|
def codegen(self, wrapper):
|
|
wrapper.writeline(
|
|
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
|
|
)
|
|
|
|
@classmethod
|
|
def create(cls, x, y, w, b, attr):
|
|
kernel = "torch.ops.mkldnn._linear_pointwise.binary"
|
|
x = cls.require_stride1(cls.realize_input(x))
|
|
y = cls.require_stride1(cls.realize_input(y))
|
|
w = cls.require_stride1(cls.realize_input(w))
|
|
|
|
*m, ic = x.get_size()
|
|
oc, ic = w.get_size()
|
|
|
|
inputs = [x, y, w]
|
|
constant_args = [attr]
|
|
if b is not None:
|
|
b = cls.require_stride1(cls.realize_input(b))
|
|
inputs.append(b)
|
|
else:
|
|
constant_args.insert(0, b)
|
|
|
|
return LinearBinary(
|
|
layout=FlexibleLayout(
|
|
device=x.get_device(),
|
|
dtype=x.get_dtype(),
|
|
size=list(m) + [oc],
|
|
),
|
|
inputs=inputs,
|
|
constant_args=constant_args,
|
|
kernel=kernel,
|
|
)
|
|
|
|
def apply_constraint(self):
|
|
pass
|
|
|
|
|
|
class ConvolutionTransposeUnary(ExternKernelAlloc):
|
|
kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
|
|
|
|
def __init__(
|
|
self,
|
|
layout,
|
|
inputs,
|
|
constant_args=(),
|
|
kernel="torch.ops.mkldnn._convolution_transpose_pointwise",
|
|
):
|
|
super().__init__(layout, inputs, constant_args)
|
|
self.kernel = kernel
|
|
|
|
def codegen(self, wrapper):
|
|
wrapper.writeline(
|
|
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
|
|
)
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
x: "TensorBox",
|
|
weight: "TensorBox",
|
|
bias: "TensorBox",
|
|
padding_: List[int],
|
|
output_padding_: List[int],
|
|
stride_: List[int],
|
|
dilation_: List[int],
|
|
groups_: int,
|
|
attr,
|
|
scalars,
|
|
algorithm,
|
|
):
|
|
kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
|
|
transposed = True
|
|
(
|
|
inputs,
|
|
constant_args,
|
|
kernel_layout,
|
|
_,
|
|
) = _prepare_convolution_fusion_create(
|
|
cls,
|
|
x,
|
|
weight,
|
|
bias,
|
|
padding_,
|
|
stride_,
|
|
dilation_,
|
|
groups_,
|
|
transposed,
|
|
output_padding_,
|
|
)
|
|
constant_args = constant_args + [attr, scalars, algorithm]
|
|
return ConvolutionTransposeUnary(
|
|
layout=kernel_layout,
|
|
inputs=inputs,
|
|
constant_args=constant_args,
|
|
kernel=kernel,
|
|
)
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class MutableBox(IRNode):
|
|
"""
|
|
TensorBox / StorageBox allow in-place mutation of Tensors
|
|
"""
|
|
|
|
data: IRNode
|
|
|
|
def __getattr__(self, name):
|
|
fn = getattr(self.data, name)
|
|
if callable(fn):
|
|
return fn
|
|
raise AttributeError(f"{type(self.data).__name__}.{name} not callable")
|
|
|
|
@property
|
|
def layout(self):
|
|
return self.data.layout
|
|
|
|
def __str__(self):
|
|
if isinstance(self.data, MutableBox):
|
|
line0 = f"{type(self).__name__}({type(self.data).__name__}("
|
|
endl = "))"
|
|
inner = self.data.data
|
|
else:
|
|
line0 = f"{type(self).__name__}("
|
|
inner = self.data
|
|
endl = ")"
|
|
|
|
lines = [
|
|
line0,
|
|
indent(str(inner)),
|
|
endl,
|
|
]
|
|
return "\n".join(lines)
|
|
|
|
__repr__ = __str__
|
|
|
|
|
|
class TensorBox(MutableBox):
|
|
@staticmethod
|
|
def create(data):
|
|
return TensorBox(StorageBox(data))
|
|
|
|
|
|
class StorageBox(MutableBox):
|
|
def is_input_buffer(self):
|
|
if isinstance(self.data, (InputBuffer, ReinterpretView)):
|
|
return self.data.get_name() in V.graph.graph_inputs
|
|
return False
|
|
|
|
def realize(self):
|
|
if isinstance(
|
|
self.data,
|
|
(
|
|
ComputedBuffer,
|
|
InputsKernel,
|
|
InputBuffer,
|
|
ReinterpretView,
|
|
TemplateBuffer,
|
|
),
|
|
):
|
|
return self.data.get_name()
|
|
assert isinstance(self.data, (Pointwise, Reduction)), type(self.data)
|
|
self.data = ComputedBuffer(
|
|
name=None,
|
|
layout=FlexibleLayout(
|
|
device=self.data.get_device(),
|
|
dtype=self.data.get_dtype(),
|
|
size=self.data.get_size(),
|
|
),
|
|
data=self.data,
|
|
)
|
|
self.data.name = V.graph.register_buffer(self.data)
|
|
self.data.origins = self.origins
|
|
return self.data.name
|
|
|
|
def realize_hint(self):
|
|
"""
|
|
Called on buffers we expect to be forced to realize later.
|
|
"""
|
|
if isinstance(self.data, (Pointwise, Reduction)) and self.num_reads() > 1:
|
|
self.realize()
|
|
|
|
def has_exceeded_max_reads(self):
|
|
return isinstance(self.data, Pointwise) and (
|
|
self.num_reads() > config.realize_acc_reads_threshold
|
|
or len(self.inner_fn_str()) > config.realize_bytes_threshold
|
|
)
|
|
|
|
def mark_reuse(self, users):
|
|
"""
|
|
A heuristic to decide if we should realize a tensor
|
|
that is used multiple times.
|
|
"""
|
|
|
|
def should_realize_on_cpu(loops: Union[Pointwise, Reduction]):
|
|
"""
|
|
The heuristic for realizing reused result of heavy ops on cpu
|
|
"""
|
|
heavy_ops = ["exp"] # a list of heavy ops
|
|
fn_str = loops.inner_fn_str()
|
|
return any([(op + "(") in fn_str for op in heavy_ops])
|
|
|
|
if (
|
|
users > 1
|
|
and isinstance(self.data, (Pointwise, Reduction))
|
|
and (
|
|
self.num_reads() > config.realize_reads_threshold
|
|
or len(self.inner_fn_str()) > config.realize_bytes_threshold
|
|
or (is_cpu(self.data) and should_realize_on_cpu(self.data))
|
|
)
|
|
):
|
|
self.realize()
|
|
|
|
@cache_on_self
|
|
def num_reads(self):
|
|
data = self.data
|
|
if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)):
|
|
return 1
|
|
if isinstance(data, ComputedBuffer):
|
|
read_writes = data.get_read_writes()
|
|
else:
|
|
assert isinstance(data, (Pointwise, Reduction)), type(data)
|
|
read_writes = ComputedBuffer(
|
|
name=None,
|
|
layout=FlexibleLayout(
|
|
device=data.get_device(),
|
|
dtype=data.get_dtype(),
|
|
size=data.get_size(),
|
|
),
|
|
data=data,
|
|
).get_read_writes()
|
|
return len(read_writes.reads)
|
|
|
|
|
|
class InterpreterShim(torch.fx.Interpreter):
|
|
@staticmethod
|
|
@functools.lru_cache(None)
|
|
def _dummy_gm():
|
|
return torch.fx.symbolic_trace(identity)
|
|
|
|
def __init__(self, graph, submodules):
|
|
# call super() with a placeholder to avoid constructing a
|
|
# GraphModule which is very expensive (it does codegen).
|
|
super().__init__(self._dummy_gm(), garbage_collect_values=False)
|
|
self.module = self
|
|
self.graph = graph
|
|
self.submodules = submodules
|
|
self.extra_traceback = False
|
|
self.fetch_attr = submodules.__getitem__
|
|
self.current_node = None
|
|
|
|
def run_node(self, n: torch.fx.Node) -> Any:
|
|
self.current_node = n
|
|
return super().run_node(n)
|
|
|
|
def run(self, *args, **kwargs):
|
|
with V.set_interpreter_handler(self):
|
|
return super().run(*args, **kwargs)
|
|
|
|
|
|
class LoopBody:
|
|
"""
|
|
Captures the body of a Loops subclass into an FX graph. Persists any
|
|
indexing simplifications and makes it easier to analyze loop bodies.
|
|
"""
|
|
|
|
def __init__(self, fn, args, var_ranges):
|
|
super().__init__()
|
|
self.var_ranges = var_ranges
|
|
self.indexing_exprs = {}
|
|
self.indexing_exprs_name = {}
|
|
self.reads = []
|
|
self.writes = []
|
|
self.reads_name2expr = {}
|
|
self.writes_name2expr = {}
|
|
self.other = []
|
|
self.submodules = {"get_index": self.get_index}
|
|
self.subblocks = {}
|
|
self.indirect_vars = []
|
|
self.root_block = LoopBodyBlock(self, fn, args)
|
|
self.indexing = None
|
|
|
|
def debug_str(self):
|
|
lines = [f"var_ranges = {dict(self.var_ranges)}"]
|
|
lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
|
|
lines.extend(
|
|
[
|
|
block.debug_str(name)
|
|
for name, block in itertools.chain(
|
|
[("body", self.root_block)], self.subblocks.items()
|
|
)
|
|
]
|
|
)
|
|
return "\n".join(lines)
|
|
|
|
def add_index_expr(self, expr: sympy.Expr, category, buf_name):
|
|
getattr(self, category).append(expr)
|
|
if buf_name is not None:
|
|
getattr(self, f"{category}_name2expr")[buf_name] = expr
|
|
if expr not in self.indexing_exprs_name:
|
|
name = f"index{len(self.indexing_exprs)}"
|
|
self.indexing_exprs_name[expr] = name
|
|
self.indexing_exprs[name] = expr
|
|
return self.indexing_exprs_name[expr]
|
|
|
|
def add_submodule(self, block, prefix):
|
|
"""Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
|
|
if prefix[-1].isnumeric() and prefix not in self.submodules:
|
|
name = prefix
|
|
else:
|
|
name = f"{prefix}{len(self.submodules)}"
|
|
self.submodules[name] = block
|
|
return name
|
|
|
|
def add_indirect(self):
|
|
name = f"indirect{len(self.indirect_vars)}"
|
|
var = sympy_symbol(name)
|
|
self.indirect_vars.append(var)
|
|
return var
|
|
|
|
def replace_indirect(self, old, new):
|
|
"""Swap in a variable used in indirect indexing"""
|
|
if str(old) == str(new):
|
|
return
|
|
self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}
|
|
|
|
def get_index(self, name):
|
|
return self.indexing[name]
|
|
|
|
def __call__(self, *indices):
|
|
index = list(itertools.chain(*indices))
|
|
assert len(index) == len(self.var_ranges), (index, self.var_ranges)
|
|
assert all(v not in self.var_ranges for v in index)
|
|
replacements = dict(zip(self.var_ranges.keys(), index))
|
|
self.indexing = {
|
|
name: sympy_subs(expr, replacements)
|
|
for name, expr in self.indexing_exprs.items()
|
|
}
|
|
result = self.root_block()
|
|
self.indexing = None
|
|
return result
|
|
|
|
|
|
class LoopBodyBlock:
|
|
"""
|
|
Captures the body of a Loops subclass into an FX graph.
|
|
In normal cases there will be a 1:1 mapping between LoopBody and
|
|
LoopBodyBlock, however in the case of ops.masked() the masked out
|
|
operations will manifest as an extra LoopBodyBlock.
|
|
"""
|
|
|
|
def __init__(self, body: LoopBody, fn: Callable, args: List[Any]):
|
|
self.body = body
|
|
|
|
def add_index(expr, category, buf_name=None):
|
|
return tracer.create_proxy(
|
|
"call_module",
|
|
"get_index",
|
|
(self.body.add_index_expr(expr, category, buf_name),),
|
|
{},
|
|
)
|
|
|
|
class CaptureIndexing(V.WrapperHandler):
|
|
self.name = "CaptureIndexing"
|
|
|
|
def load(self, name: str, index: sympy.Expr):
|
|
index = add_index(index, "reads", name)
|
|
return self._inner.load(name, index)
|
|
|
|
def store(self, name, index, value, mode=None):
|
|
index = add_index(index, "writes", name)
|
|
return self._inner.store(name, index, value, mode)
|
|
|
|
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
|
|
index = add_index(index, "writes", name)
|
|
return self._inner.reduction(
|
|
name, dtype, src_dtype, reduction_type, index, value
|
|
)
|
|
|
|
def index_expr(self, index, dtype):
|
|
if isinstance(index, (int, sympy.Integer)):
|
|
return ops.constant(int(index), dtype)
|
|
index = add_index(index, "other")
|
|
return self._inner.index_expr(index, dtype)
|
|
|
|
@staticmethod
|
|
def masked(mask_proxy, masked_body: Callable, other_proxy):
|
|
"""
|
|
Recursively capture the masked out body in another LoopBodyBlock
|
|
"""
|
|
|
|
def shim(mask, other):
|
|
return V.ops.masked(mask, subblock, other)
|
|
|
|
name = self.body.add_submodule(shim, "masked_subblock")
|
|
subblock = LoopBodyBlock(self.body, masked_body, [])
|
|
self.body.subblocks[name] = subblock
|
|
return tracer.create_proxy(
|
|
"call_module", name, (mask_proxy, other_proxy), {}
|
|
)
|
|
|
|
@staticmethod
|
|
def indirect_indexing(index_proxy):
|
|
"""
|
|
Flow data from tensors into indexing formulas.
|
|
Introduce a call_module to update the indexing.
|
|
"""
|
|
|
|
def set_indirect(new_var):
|
|
self.body.replace_indirect(var, V.ops.indirect_indexing(new_var))
|
|
|
|
var = self.body.add_indirect()
|
|
tracer.create_proxy(
|
|
"call_module",
|
|
self.body.add_submodule(set_indirect, f"set_{var}"),
|
|
(index_proxy,),
|
|
{},
|
|
)
|
|
return var
|
|
|
|
tracer = torch.fx.Tracer()
|
|
tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
|
|
proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
|
|
from .sizevars import SimplifyIndexing
|
|
|
|
with V.set_ops_handler(
|
|
SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges)
|
|
):
|
|
tracer.create_proxy("output", "output", (fn(*args),), {})
|
|
self.graph = tracer.graph
|
|
|
|
def __call__(self):
|
|
graph = self.graph
|
|
submodules = self.body.submodules
|
|
|
|
return InterpreterShim(graph, submodules).run(V.get_ops_handler())
|
|
|
|
def debug_str(self, name="block"):
|
|
code = torch.fx.GraphModule(self.body.submodules, self.graph).code
|
|
return re.sub(
|
|
# strip `; del var0` suffixes to make output prettier
|
|
r";[^\n]*",
|
|
"",
|
|
code.strip().replace("def forward(", f"def {name}("),
|
|
)


class Wait(ExternKernelAlloc):
    """
    Wait should not be used by itself. It should always be constructed in tandem
    with a collective op that produces a work to wait on.
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
    ):
        super().__init__(layout, inputs, constant_args)

    def should_allocate(self):
        return False

    def codegen(self, wrapper):
        wrapper.add_import_once(
            "from torch.distributed._functional_collectives import _wait_tensor"
        )
        (input_collective,) = [t.codegen_reference() for t in self.inputs]
        wrapper.writeline(f"{input_collective} = _wait_tensor({input_collective})")

        # wait op still needs to produce a 'buffer' that represents the tensor output.
        # this is a symbolic gesture, and it gets handled by WrapperCodegen.
        # codegen outputs a '# reuse' line that assigns the input buffer here ('input_collective')
        # to a new name (`self.get_name()`) and `del`s the old name.
        wrapper.writeline(f"{self.get_name()} = {input_collective}")

    @classmethod
    def create(cls, collective_op: "TensorBox"):
        # TODO(whc) i'm not sure what's going on here, this probably means I missed something upstream
        collective_op.decide_layout()
        return Wait(
            layout=collective_op.get_layout(),
            inputs=[collective_op],
        )

    def get_alias_names(self):
        # Signal to codegen that our output buffer isn't safe to reuse
        return [self.inputs[0].codegen_reference()]


class CollectiveKernel(ExternKernel):
    """
    Each CollectiveKernel should follow the patterns
    - it writes into a given output buffer
    - the kernel delegates into c10d processgroup, which returns a 'work' obj
    - the work obj is registered via _register_tensor_work so it can be waited on later
    """

    def __init__(self, layout, inputs, constant_args):
        super().__init__(None, layout, inputs, constant_args)
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return True

    def codegen_collective(self, wrapper, output_name, input_names):
        # factor so the boilerplate can be handled in CollectiveKernel.codegen
        raise NotImplementedError("Must implement")

    def codegen(self, wrapper):
        wrapper.add_import_once("import torch.distributed as dist")
        wrapper.add_import_once(
            "from torch.distributed._functional_collectives import _str_to_reduce_op, _register_tensor_work"
        )
        wrapper.add_import_once(
            "from torch.distributed.distributed_c10d import _find_or_create_pg_by_ranks_and_tag"
        )

        # extract references to our args in string form for codegen output
        input_names = [t.codegen_reference() for t in self.inputs]
        output_name = self.get_name()
        tag, ranks, group_size = self.constant_args

        # TODO: avoid more than one ref of the same pg (even though they are cached inside the api)
        wrapper.writeline(
            f"{output_name}_pg = _find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})"
        )

        self.codegen_collective(wrapper, output_name, input_names)

        wrapper.writeline(f"_register_tensor_work({output_name}, {output_name}_work)")


class AllReduce(CollectiveKernel):
    def __init__(self, layout, inputs, constant_args, reduce_op):
        super().__init__(layout, inputs, constant_args)
        self.reduce_op = reduce_op

    @classmethod
    def create(
        cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int
    ):
        x = cls.realize_input(x)

        # is there a difference between literally using x.data.layout below, vs
        # creating a new one that has the same properties?
        new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), x.get_size())

        return AllReduce(
            layout=new_layout,
            inputs=[x],
            constant_args=[tag, ranks, group_size],
            reduce_op=reduce_op,
        )

    def codegen_collective(self, wrapper, output_name, input_names):
        # We must copy our input buffer sometimes, but the scheduler will help us find opportunities
        # to reuse the input buffer. (This requires no other users of the input buffer.)
        if not wrapper.did_reuse(self, self.inputs[0]):
            wrapper.writeline(f"{output_name}.copy_({input_names[0]})")

        # At this point, output_name points to a buffer that is either
        # (1) the input buffer, which we're allowed to inplace modify
        # (2) a freshly allocated buffer, which we've copied the input into above
        wrapper.writeline(
            f"{output_name}_work = dist.all_reduce("
            f"{output_name}, async_op=True, group={output_name}_pg, op=_str_to_reduce_op('{str(self.reduce_op)}'))"
        )
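        # For example, with a hypothetical output buffer "buf0" and reduce_op
        # "sum", the emitted line is roughly:
        #     buf0_work = dist.all_reduce(buf0, async_op=True, group=buf0_pg,
        #                                 op=_str_to_reduce_op('sum'))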


class AllGatherIntoTensor(CollectiveKernel):
    def __init__(self, layout, inputs, constant_args):
        super().__init__(layout, inputs, constant_args)

    @classmethod
    def create(cls, x: "TensorBox", tag: str, ranks: List[int], group_size: int):
        x = cls.realize_input(x)

        # is there a difference between literally using x.data.layout below, vs
        # creating a new one that has the same properties?
        new_size = x.get_size()
        new_size[0] *= group_size
        new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), new_size)

        # The collective returns a 'work' object. But Inductor's scheduler doesn't need to know
        # about that, and we just pretend for scheduling purposes that the work obj is a 1-elem tensor.
        # Nobody should consume the output of the collective except 'Wait', which we control here.
        return AllGatherIntoTensor(
            layout=new_layout,
            inputs=[x],
            constant_args=[tag, ranks, group_size],
        )
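        # Output sizing: all_gather concatenates along dim 0, so a hypothetical
        # per-rank input of size [s, h] on a group of 4 ranks yields an output
        # buffer of size [4 * s, h].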

    def codegen_collective(self, wrapper, output_name, input_names):
        # At this point, output_name points to a fresh buffer
        wrapper.writeline(
            f"{output_name}_work = dist.all_gather_into_tensor("
            f"{output_name}, {input_names[0]}, async_op=True, group={output_name}_pg)"
        )


class ReduceScatterTensor(CollectiveKernel):
    def __init__(self, layout, inputs, constant_args, reduce_op):
        super().__init__(layout, inputs, constant_args)
        self.reduce_op = reduce_op

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        reduce_op: str,
        tag: str,
        ranks: List[int],
        group_size: int,
    ):
        x = cls.realize_input(x)

        # is there a difference between literally using x.data.layout below, vs
        # creating a new one that has the same properties?
        new_size = x.get_size()
        new_size[0] //= group_size
        new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), new_size)

        return ReduceScatterTensor(
            layout=new_layout,
            inputs=[x],
            constant_args=[tag, ranks, group_size],
            reduce_op=reduce_op,
        )
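        # Output sizing mirrors AllGatherIntoTensor: reduce_scatter splits along
        # dim 0, so a hypothetical input of size [4 * s, h] on a group of 4 ranks
        # leaves each rank with an output buffer of size [s, h].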

    def codegen_collective(self, wrapper, output_name, input_names):
        wrapper.writeline(
            f"{output_name}_work = dist.reduce_scatter_tensor("
            f"{output_name}, {input_names[0]}, "
            f"async_op=True, group={output_name}_pg, op=_str_to_reduce_op('{str(self.reduce_op)}'))"
        )
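        # For example, with hypothetical input/output buffers "buf1"/"buf2" and
        # reduce_op "sum", the emitted line is roughly:
        #     buf2_work = dist.reduce_scatter_tensor(buf2, buf1, async_op=True,
        #                                            group=buf2_pg, op=_str_to_reduce_op('sum'))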