# Module for defining "primitive" operations executable by nvFuser. This list
# exists to decouple the main set of primitives from the ones that provide a
# lowering of the op to nvFuser's Python interface. For the most part,
# torch.ops.nvprims is a subset of the primitives in torch.ops.prims, but
# additional primitives can be added in the future for the corresponding
# higher-level torch/aten functions.
from typing import Any, Dict, Optional, Tuple
import torch
import torch._prims_common as utils
from torch._prims_common import (
    DimsSequenceType,
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    getnvFuserDtype,
    make_contiguous_strides_for,
    NumberType,
    ShapeType,
    TensorLikeType,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    backwards_not_supported,
    elementwise_type_promotion_wrapper,
)

nvprim_namespace = "nvprims"
nvprim = torch.library.Library(nvprim_namespace, "DEF")
nvprim_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
)
nvprim_implicit_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
)
nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")
nvprim_names = [
    "abs",
    "acos",
    "asin",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "clone",
    "bitwise_not",
    "ceil",
    "erf",
    "erfc",
    "exp",
    "expm1",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "real",
    "reciprocal",
    "neg",
    "round",
    "rsqrt",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "tan",
    "tanh",
    "transpose",
    "trunc",
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "div",
    "eq",
    "fmod",
    "ge",
    "gt",
    "le",
    "lt",
    "mul",
    "ne",
    "pow",
    "remainder",
    "sub",
    "squeeze",
    "view_of",
    "broadcast_in_dim",
    "where",
    "convert_element_type",
    "sum",
    "var",
    "amax",
    "amin",
]

_nvfuser_impls: Dict[str, Any] = {}

_nvfuser_unary_ops = {
    "abs",
    "acos",
    "asin",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "bitwise_not",
    "ceil",
    "erf",
    "erfc",
    "exp",
    "expm1",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "reciprocal",
    "neg",
    "real",
    "round",
    "rsqrt",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "tan",
    "tanh",
    "trunc",
}
def _assert_nvfuser_op_exists(fname: str):
    try:
        try:
            from nvfuser import (  # type: ignore[import, attr-defined]
                FusionDefinition as fd,
            )
        except ImportError:
            from nvfuser._C import FusionDefinition as fd  # type: ignore[import]
        assert getattr(fd.Operators, fname)
    except ImportError:
        # Not all PyTorch builds have nvfuser
        pass


for fname in _nvfuser_unary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a):
    return fd.ops.{fname}(a)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )
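
# Illustrative expansion of the exec() template above for a single unary op:
# for fname == "abs", the generated module-level code is equivalent to
#
#     _assert_nvfuser_op_exists("abs")
#
#     def _abs_nvfuser(fd, a):
#         return fd.ops.abs(a)
#
#     _nvfuser_impls["abs"] = _abs_nvfuser
#
# i.e. each entry in _nvfuser_unary_ops gets a thin wrapper that forwards its
# argument `a` to the like-named op on the nvFuser FusionDefinition `fd`.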
_nvfuser_binary_ops = {
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "div",
    "eq",
    "fmod",
    "ge",
    "gt",
    "le",
    "lt",
    "mul",
    "ne",
    "pow",
    "remainder",
    "sub",
}

for fname in _nvfuser_binary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a, b):
    return fd.ops.{fname}(a, b)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )

_nvfuser_ternary_ops = {
    "where",
}

for fname in _nvfuser_ternary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a, b, c):
    return fd.ops.{fname}(a, b, c)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )
def _native_batch_norm_nvfuser(
    fd, input, weight, bias, running_mean, running_var, training, momentum, eps
):
    """
    if weight is None:
        weight = fd.define_null_tensor()
    if bias is None:
        bias = fd.define_null_tensor()
    if running_mean is None:
        running_mean = fd.define_null_tensor()
    if running_var is None:
        running_var = fd.define_null_tensor()
    """
    return fd.ops.batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        momentum,
        eps,
        training,
    )


def _broadcast_in_dim_nvfuser(
    fd: Any,
    a: TensorLikeType,
    shape: ShapeType,
    broadcast_dimensions: ShapeType,
):
    return fd.ops.broadcast_in_dim(a, shape, broadcast_dimensions)  # type: ignore[attr-defined]


def _convert_element_type_nvfuser(fd: Any, a: TensorLikeType, dtype: torch.dtype):
    nvfuser_dtype = getnvFuserDtype(dtype)
    return fd.ops.cast(a, nvfuser_dtype)  # type: ignore[attr-defined]


def _transpose_nvfuser(fd, a, dims):
    return fd.ops.permute(a, dims)  # type: ignore[attr-defined]


def _squeeze_nvfuser(fd, a, a_shape, dimensions):
    for idx in sorted(dimensions, reverse=True):
        a = fd.ops.squeeze(a, a_shape, idx)
        a_shape = a_shape[:idx] + a_shape[idx + 1 :]
    return a
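
# Worked example (hypothetical shapes) for _squeeze_nvfuser: with
# a_shape == [1, 3, 1, 4] and dimensions == (0, 2), the loop visits the dims in
# descending order so that earlier indices stay valid after each removal:
#     idx == 2 -> squeeze dim 2, a_shape becomes [1, 3, 4]
#     idx == 0 -> squeeze dim 0, a_shape becomes [3, 4]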
def _view_of_nvfuser(fd, a):
    return fd.ops.set(a)


def _view_nvfuser(
    fd,
    a,
    a_shape,
    new_shape,
):
    try:
        return fd.ops.view(a, a_shape, new_shape)
    except AttributeError:
        return fd.ops.reshape(a, a_shape, new_shape)


def _sum_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    try:
        from nvfuser import DataType  # type: ignore[import, attr-defined]
    except ImportError:
        from nvfuser._C import DataType  # type: ignore[import]
    output_dtype = DataType.Null
    return fd.ops.sum(a, dims, keep_dims, output_dtype)


def _var_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
    *,
    correction: float,
):
    keep_dims = False
    return fd.ops.var(a, dims, correction, keep_dims)
def _var_mean_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: float,
):
    # The unbiased arg should not be set when this function is called
    assert unbiased is None
    # Ignore the keepdim arg: it is currently converted into an nvFuser symbolic
    # scalar automatically, and keepdim is handled by the reference implementation
    keepdim = False
    return fd.ops.var_mean(a, dims, correction, keepdim)
def _rand_like_nvfuser(fd: Any, a: TensorLikeType):
    return fd.ops.rand_like(a)


def _amax_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    return fd.ops.max(a, dims, keep_dims)


def _amin_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    return fd.ops.min(a, dims, keep_dims)


def _clone_nvfuser(fd: Any, input: TensorLikeType, *, memory_format=None):
    return fd.ops.set(input)


def _full_nvfuser(
    fd: Any,
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
):
    assert device != torch.device("cpu")
    assert layout is None or layout is torch.strided
    assert pin_memory is False
    assert requires_grad is False
    dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
    nvfuser_dtype = getnvFuserDtype(dtype)
    return fd.ops.full(shape, fill_value, nvfuser_dtype)
_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
_nvfuser_impls["clone"] = _clone_nvfuser
_nvfuser_impls["transpose"] = _transpose_nvfuser
_nvfuser_impls["squeeze"] = _squeeze_nvfuser
_nvfuser_impls["view_of"] = _view_of_nvfuser
_nvfuser_impls["view"] = _view_nvfuser
_nvfuser_impls["rand_like"] = _rand_like_nvfuser
_nvfuser_impls["sum"] = _sum_nvfuser
_nvfuser_impls["var"] = _var_nvfuser
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
_nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser
_nvfuser_impls["full"] = _full_nvfuser
def register_full():
    name = "full"

    nvprim.define(
        "full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, "
        + "bool? pin_memory=None, bool? requires_grad=None) -> Tensor"
    )

    def _meta_impl(
        size,
        fill_value,
        *,
        out=None,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=False,
        requires_grad=False,
    ):
        strides = make_contiguous_strides_for(size)
        return torch._prims.TensorMeta(
            None,
            shape=size,
            strides=strides,
            dtype=dtype,
            device=device,
        )

    def _prim_impl(
        size,
        fill_value,
        *,
        out=None,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=False,
        requires_grad=False,
    ):
        return torch.full(
            size,
            fill_value,
            out=out,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            requires_grad=requires_grad,
        )

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_impl)

    prim_packet = getattr(torch._ops.ops.nvprims, name)
    prim = prim_packet.default

    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = "Create a tensor with given size and filled with value"
        p.impl_nvfuser = _nvfuser_impls["full"]
        p.is_recomputable = _nvfuser_is_recomputable["full"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
# functorch.compile.min_cut_rematerialization_partition accepts a list of
# operators that may be recomputed in the backward pass. The table below marks
# which nvprims are allowed to be recomputed; an operator that is not in the
# table (or is marked False) will not be recomputed. See the illustrative
# sketch after the table for how such a list can be derived from it.
_nvfuser_is_recomputable: Dict[str, bool] = {
    # Reductions are not allowed to be recomputed
    "amax": False,
    "amin": False,
    "sum": False,
    "var": False,
    "var_mean": False,
    # Normalizations are not allowed to be recomputed
    "native_batch_norm": False,
    # Random ops are not allowed to be recomputed
    "rand_like": False,
    # Everything else is allowed to be recomputed
    "abs": True,
    "acos": True,
    "add": True,
    "asin": True,
    "atan": True,
    "atan2": True,
    "atanh": True,
    "bitwise_and": True,
    "bitwise_not": True,
    "bitwise_or": True,
    "bitwise_xor": True,
    "broadcast_in_dim": True,
    "ceil": True,
    "clone": True,
    "convert_element_type": True,
    "cos": True,
    "cosh": True,
    "div": True,
    "eq": True,
    "erf": True,
    "erfc": True,
    "exp": True,
    "expm1": True,
    "floor": True,
    "fmod": True,
    "full": True,
    "ge": True,
    "gt": True,
    "imag": True,
    "isfinite": True,
    "le": True,
    "lgamma": True,
    "log": True,
    "log10": True,
    "log1p": True,
    "log2": True,
    "lt": True,
    "mul": True,
    "ne": True,
    "neg": True,
    "pow": True,
    "real": True,
    "reciprocal": True,
    "remainder": True,
    "round": True,
    "rsqrt": True,
    "sign": True,
    "sin": True,
    "sinh": True,
    "sqrt": True,
    "squeeze": True,
    "sub": True,
    "tan": True,
    "tanh": True,
    "transpose": True,
    "trunc": True,
    "view": True,
    "view_of": True,
    "where": True,
}
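
# Illustrative sketch (not part of the registration path, and assuming
# register_nvprims() has already run): a partitioner could derive the set of
# recomputable nvprims overloads from this table along the lines of
#
#     recomputable_nvprims = [
#         getattr(torch.ops.nvprims, op_name).default
#         for op_name, allowed in _nvfuser_is_recomputable.items()
#         if allowed and hasattr(torch.ops.nvprims, op_name)
#     ]
#
# and hand that list to functorch.compile.min_cut_rematerialization_partition,
# which, as noted above, accepts a list of operators that may be recomputed.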
def register_native_batch_norm():
    """Registers the native_batch_norm function in the torch.ops.nvprims module."""
    name = "native_batch_norm"

    nvprim.define(
        f"{name}(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, "
        + "bool training, float momentum, float eps)"
        + " -> (Tensor, Tensor, Tensor)"
    )

    def _prim_impl(
        input, weight, bias, running_mean, running_var, training, momentum, eps
    ):
        return torch.native_batch_norm(
            input, weight, bias, running_mean, running_var, training, momentum, eps
        )

    nvprim_impl.impl(name, _prim_impl)

    prim_packet = torch._ops.ops.nvprims.native_batch_norm
    prim = prim_packet.default

    def _native_batch_norm_ref(
        input: torch.Tensor,
        weight: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        running_mean: Optional[torch.Tensor],
        running_var: Optional[torch.Tensor],
        training: bool,
        momentum: float,
        eps: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if torch._prims_common.is_complex_dtype(input.dtype):
            raise NotImplementedError("Complex tensors are not supported")

        # Note: BN only promotes input to the dtype of weight/bias, but keeps
        # the same output dtype
        result_dtype = input.dtype
        computation_dtype, _ = elementwise_dtypes(
            input,
            weight,
            bias,
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
        )
        input_ = _maybe_convert_to_dtype(input, computation_dtype)
        output, mean, rstd = prim(
            input_, weight, bias, running_mean, running_var, training, momentum, eps
        )
        output_ = _maybe_convert_to_dtype(output, result_dtype)  # type: ignore[arg-type]
        return (output_, mean, rstd)  # type: ignore[return-value]

    def _native_batch_norm_autograd(
        input: torch.Tensor,
        weight: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        running_mean: Optional[torch.Tensor],
        running_var: Optional[torch.Tensor],
        training: bool,
        momentum: float,
        eps: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # This wrapper is needed to convert prims calls inside
        # _native_batch_norm_ref to nvprims calls
        from torch._prims.context import NvfuserPrimsMode

        with NvfuserPrimsMode():
            return backwards_not_supported(_native_batch_norm_ref)(
                input, weight, bias, running_mean, running_var, training, momentum, eps
            )

    nvprim_autograd_impl.impl(name, _native_batch_norm_autograd)

    for p in (prim_packet, prim):
        p.__doc__ = "Computes batch normalization."
        p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
        p.is_recomputable = _nvfuser_is_recomputable["native_batch_norm"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
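
# Worked example (hypothetical dtypes) for the dtype handling in
# _native_batch_norm_ref above: with a float16 input and float32 weight/bias,
# NO_OPMATH type promotion yields computation_dtype == torch.float32, so the
# input is upcast to float32 for the prim call, while result_dtype stays
# torch.float16 and the normalized output is cast back to float16.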
def register_rand_like():
    name = "rand_like"

    nvprim.define(
        "rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, "
        + "Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"
    )

    def _meta_rand_like(
        self,
        *,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=None,
        memory_format=None,
    ):
        strides = make_contiguous_strides_for(self.shape)
        return torch._prims.TensorMeta(
            self,
            shape=self.shape,
            strides=strides,
            dtype=dtype,
            device=device,
        )

    def _prim_impl(
        self,
        *,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=None,
        memory_format=None,
    ):
        return torch.rand_like(
            self,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            memory_format=memory_format,
        )

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_rand_like)

    prim_packet = getattr(torch._ops.ops.nvprims, name)
    prim = prim_packet.default

    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = "Computes rand_like"
        p.impl_nvfuser = _nvfuser_impls["rand_like"]
        p.is_recomputable = _nvfuser_is_recomputable["rand_like"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
def register_var_mean():
    """Registers the var_mean function in the torch.ops.nvprims module."""
    name = "var_mean.main"

    # This overload must be default for correct dispatching of var_mean(Tensor, bool)
    nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")

    # This signature combines several overloads of torch.var_mean into one overload.
    nvprim.define(
        f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, float? correction=None)"
        + " -> (Tensor, Tensor)"
    )

    # This function is used for device="meta" Tensors.
    def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
        if torch._prims_common.is_complex_dtype(inp.dtype):
            output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
        else:
            output_dtype = inp.dtype
        var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
        mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
        if keepdim:
            output_shape = [
                inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
            ]
            broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
            var = torch._ops.ops.nvprims.broadcast_in_dim(
                var, output_shape, broadcast_dims
            )
            mean = torch._ops.ops.nvprims.broadcast_in_dim(
                mean, output_shape, broadcast_dims
            )
        return (var, mean)

    # This function is used under the _AutoDispatchBelowAutograd context
    def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
        correction = torch._prims_common.set_correction(unbiased, correction)
        return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_var_mean)

    prim_packet = torch._ops.ops.nvprims.var_mean
    prim = prim_packet.main

    def _unbiased_overload_impl(inp, unbiased):
        return prim(inp, dim=None, unbiased=unbiased)

    nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)

    @elementwise_type_promotion_wrapper(
        type_promoting_args=("a",),
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    )
    def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
        correction = torch._prims_common.set_correction(unbiased, correction)
        # Reduces over all dimensions if dim=() is passed
        if dim == () or dim == []:
            dim = None
        dim = torch._prims_common.reduction_dims(a.shape, dim)

        # For complex tensors eager computes the variance as the sum of variances of
        # the real and imaginary parts
        # TODO: Creating a complex tensor from real and imaginary parts is not supported
        if torch._prims_common.is_complex_dtype(a.dtype):
            raise NotImplementedError("Complex tensors are not supported")

        var_mean = prim(a, dim, correction=correction)

        if keepdim:
            output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
            broadcast_dims = [i for i in range(a.ndim) if i not in dim]
            var, mean = var_mean
            var = torch._ops.ops.nvprims.broadcast_in_dim(
                var, output_shape, broadcast_dims
            )
            mean = torch._ops.ops.nvprims.broadcast_in_dim(
                mean, output_shape, broadcast_dims
            )
            var_mean = (var, mean)
        return var_mean

    def _var_mean_autograd(
        a, dim=None, unbiased=None, keepdim=False, *, correction=None
    ):
        # This wrapper is needed to convert prims calls inside
        # elementwise_type_promotion_wrapper to nvprims calls
        from torch._prims.context import NvfuserPrimsMode

        with NvfuserPrimsMode():
            return backwards_not_supported(_var_mean_ref)(
                a, dim, unbiased, keepdim, correction=correction
            )

    nvprim_autograd_impl.impl(name, _var_mean_autograd)

    for p in (prim_packet, prim):
        p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
        p.impl_nvfuser = _nvfuser_impls["var_mean"]
        p.is_recomputable = _nvfuser_is_recomputable["var_mean"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
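
# Worked example (hypothetical shapes) for the keepdim handling above: with
# a.shape == [2, 3, 4], dim == (1,) and keepdim=True, the reduction produces
# shape [2, 4]; output_shape == [2, 1, 4] and broadcast_dims == [0, 2], so
# nvprims.broadcast_in_dim re-inserts the reduced dimension as size 1 for both
# the variance and the mean.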
def _nvprims_view_impl_aten(a, original_shape, new_shape):
    return a.reshape(new_shape)


def register_view():
    """Registers the view function in the torch.ops.nvprims module."""
    # View is implemented as a decomposition into prims.split_dim,
    # prims.collapse_dim, and prims.reshape, but we would like to intercept
    # the non-decomposed view for now
    name = "view"

    nvprim.define("view(Tensor inp, SymInt[] original_shape, SymInt[] shape) -> Tensor")
    nvprim.define("view.shape(Tensor inp, SymInt[] shape) -> Tensor")

    # This function is used under the _AutoDispatchBelowAutograd context
    def _prim_impl(a, original_shape, new_shape):
        return a.reshape(new_shape)

    nvprim_impl.impl(name, _prim_impl)

    prim_packet = torch._ops.ops.nvprims.view
    prim = prim_packet.default

    def _view_no_original_shape_overload_impl(a, shape):
        if list(a.shape) == list(shape):
            return torch.ops.nvprims.view_of(a)
        return torch.ops.nvprims.view.default(a, a.shape, shape)

    nvprim_implicit_impl.impl("view.shape", _view_no_original_shape_overload_impl)

    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = "Creates a tensor with the specified shape containing a copy of the data in a."
        p.impl_nvfuser = _nvfuser_impls["view"]
        p.is_recomputable = _nvfuser_is_recomputable["view"]
        p.return_type = torch._prims_common.RETURN_TYPE.VIEW  # type: ignore[attr-defined]
        p.impl_aten = _nvprims_view_impl_aten
def register_nvprims():
    """Registers all nvFuser primitives in the torch.ops.nvprims module."""
    register_var_mean()
    register_view()
    register_native_batch_norm()
    register_rand_like()
    register_full()

    for name in nvprim_names:
        main_prim = getattr(torch._ops.ops.prims, name)

        nvprim.define(main_prim.schema)
        nvprim_impl.impl(name, main_prim.prim_impl)
        nvprim_meta_impl.impl(name, main_prim.prim_meta_impl)

        prim_packet = getattr(torch._ops.ops.nvprims, name)
        prim = prim_packet.default

        nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

        for p in (prim_packet, prim):
            p.__doc__ = main_prim.__doc__
            p.impl_nvfuser = _nvfuser_impls[name]
            p.is_recomputable = _nvfuser_is_recomputable.get(name, False)
            p.return_type = main_prim.return_type  # type: ignore[attr-defined]
            p.impl_aten = main_prim.impl_aten
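
# Minimal usage sketch (assuming a PyTorch build with nvFuser available):
# calling register_nvprims() populates torch.ops.nvprims so that each
# registered prim mirrors its torch.ops.prims counterpart, e.g.
#
#     register_nvprims()
#     a = torch.randn(4, device="cuda")
#     b = torch.randn(4, device="cuda")
#     out = torch.ops.nvprims.add(a, b)  # dispatches to the prims implementation
#
# The impl_nvfuser callables attached above (from _nvfuser_impls) are what the
# nvFuser lowering path uses to translate these ops into FusionDefinition calls.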