# Module for defining "primitive" operations executable by nvFuser. This
# separate list exists to decouple the main set of primitives from the ones
# that provide a lowering of the op to nvFuser's Python interface. For the most
# part, torch.ops.nvprims is a subset of the primitives in torch.ops.prims, but
# additional primitives can be added in the future for the corresponding
# higher-level torch/aten functions.

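# Illustrative sketch (not part of the original module, comment only): once
# register_nvprims() at the bottom of this file has been called, each
# registered primitive is reachable as torch.ops.nvprims.<name>, roughly:
#
#     register_nvprims()
#     a = torch.randn(4)
#     b = torch.ops.nvprims.add(a, a)  # dispatches through the nvprims library
#
# The actual fusion with nvFuser happens through tracing machinery elsewhere
# (e.g. torch._prims.context.NvfuserPrimsMode); calling the ops eagerly as
# above simply runs the registered CompositeExplicitAutograd implementations.
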
from typing import Any, Dict, Optional, Tuple

import torch
import torch._prims_common as utils

from torch._prims_common import (
    DimsSequenceType,
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    getnvFuserDtype,
    make_contiguous_strides_for,
    NumberType,
    ShapeType,
    TensorLikeType,
)

from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    backwards_not_supported,
    elementwise_type_promotion_wrapper,
)

nvprim_namespace = "nvprims"
nvprim = torch.library.Library(nvprim_namespace, "DEF")
nvprim_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
)
nvprim_implicit_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
)
nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")

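# A brief, hedged sketch of how the Library objects above are used in this file
# (the concrete registrations live in the register_* functions below; "foo" and
# the *_fn names here are placeholders for the sketch):
#
#     nvprim.define("foo(Tensor a) -> Tensor")        # declare the op schema
#     nvprim_impl.impl("foo", eager_fn)               # CompositeExplicitAutograd impl
#     nvprim_meta_impl.impl("foo", meta_fn)           # shape/dtype-only "meta" impl
#     nvprim_autograd_impl.impl("foo", autograd_fn)   # Autograd behavior
#
# nvprim_implicit_impl is used for convenience overloads that forward to the
# main overload (see "var_mean" and "view.shape" below).
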
nvprim_names = [
    "abs",
    "acos",
    "asin",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "clone",
    "bitwise_not",
    "ceil",
    "erf",
    "erfc",
    "exp",
    "expm1",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "real",
    "reciprocal",
    "neg",
    "round",
    "rsqrt",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "tan",
    "tanh",
    "transpose",
    "trunc",
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "div",
    "eq",
    "fmod",
    "ge",
    "gt",
    "le",
    "lt",
    "mul",
    "ne",
    "pow",
    "remainder",
    "sub",
    "squeeze",
    "view_of",
    "broadcast_in_dim",
    "where",
    "convert_element_type",
    "sum",
    "var",
    "amax",
    "amin",
]

_nvfuser_impls: Dict[str, Any] = {}

_nvfuser_unary_ops = {
    "abs",
    "acos",
    "asin",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "bitwise_not",
    "ceil",
    "erf",
    "erfc",
    "exp",
    "expm1",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "reciprocal",
    "neg",
    "real",
    "round",
    "rsqrt",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "tan",
    "tanh",
    "trunc",
}


def _assert_nvfuser_op_exists(fname: str):
    try:
        try:
            from nvfuser import (  # type: ignore[import, attr-defined]
                FusionDefinition as fd,
            )
        except ImportError:
            from nvfuser._C import FusionDefinition as fd  # type: ignore[import]

        assert getattr(fd.Operators, fname)
    except ImportError:
        # Not all PyTorch builds have nvfuser
        pass


for fname in _nvfuser_unary_ops:
|
||
exec(
|
||
f"""
|
||
# Ensure that the nvfuser implementation exists
|
||
_assert_nvfuser_op_exists("{fname}")
|
||
|
||
def _{fname}_nvfuser(fd, a):
|
||
return fd.ops.{fname}(a) # type: ignore[attr-defined]
|
||
|
||
_nvfuser_impls["{fname}"] = _{fname}_nvfuser
|
||
"""
|
||
)
|
||
|
||
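# For reference (comment only): for fname == "abs" the exec() template above
# expands to roughly the following module-level code; the binary and ternary
# loops below work the same way with extra arguments:
#
#     _assert_nvfuser_op_exists("abs")
#
#     def _abs_nvfuser(fd, a):
#         return fd.ops.abs(a)
#
#     _nvfuser_impls["abs"] = _abs_nvfuser
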
_nvfuser_binary_ops = {
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "div",
    "eq",
    "fmod",
    "ge",
    "gt",
    "le",
    "lt",
    "mul",
    "ne",
    "pow",
    "remainder",
    "sub",
}

for fname in _nvfuser_binary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a, b):
    return fd.ops.{fname}(a, b)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )

_nvfuser_ternary_ops = {
    "where",
}

for fname in _nvfuser_ternary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a, b, c):
    return fd.ops.{fname}(a, b, c)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )


def _native_batch_norm_nvfuser(
    fd, input, weight, bias, running_mean, running_var, training, momentum, eps
):
    """
    if weight is None:
        weight = fd.define_null_tensor()
    if bias is None:
        bias = fd.define_null_tensor()
    if running_mean is None:
        running_mean = fd.define_null_tensor()
    if running_var is None:
        running_var = fd.define_null_tensor()
    """
    return fd.ops.batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        momentum,
        eps,
        training,
    )


def _broadcast_in_dim_nvfuser(
    fd: Any,
    a: TensorLikeType,
    shape: ShapeType,
    broadcast_dimensions: ShapeType,
):
    return fd.ops.broadcast_in_dim(a, shape, broadcast_dimensions)  # type: ignore[attr-defined]


def _convert_element_type_nvfuser(fd: Any, a: TensorLikeType, dtype: torch.dtype):
    nvfuser_dtype = getnvFuserDtype(dtype)
    return fd.ops.cast(a, nvfuser_dtype)  # type: ignore[attr-defined]


def _transpose_nvfuser(fd, a, dims):
    return fd.ops.permute(a, dims)  # type: ignore[attr-defined]


def _squeeze_nvfuser(fd, a, a_shape, dimensions):
    for idx in sorted(dimensions, reverse=True):
        a = fd.ops.squeeze(a, a_shape, idx)
        a_shape = a_shape[:idx] + a_shape[idx + 1 :]
    return a
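
# Illustrative walk-through of _squeeze_nvfuser above (comment only): given
# a_shape == (2, 1, 3, 1) and dimensions == (1, 3), the loop squeezes index 3
# first and then index 1, shrinking the tracked shape
# (2, 1, 3, 1) -> (2, 1, 3) -> (2, 3). Iterating in reverse keeps the remaining
# indices valid as dimensions are removed.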


def _view_of_nvfuser(fd, a):
    return fd.ops.set(a)


def _view_nvfuser(
    fd,
    a,
    a_shape,
    new_shape,
):
    try:
        return fd.ops.view(a, a_shape, new_shape)
    except AttributeError:
        return fd.ops.reshape(a, a_shape, new_shape)


def _sum_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    try:
        from nvfuser import DataType  # type: ignore[import, attr-defined]
    except ImportError:
        from nvfuser._C import DataType  # type: ignore[import]

    output_dtype = DataType.Null
    return fd.ops.sum(a, dims, keep_dims, output_dtype)


def _var_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
    *,
    correction: float,
):
    keep_dims = False
    return fd.ops.var(a, dims, correction, keep_dims)


def _var_mean_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: float,
):
    # The unbiased arg shouldn't be set when this function is called
    assert unbiased is None
    # Ignore the keepdim arg: currently it would be automatically converted
    # into nvfuser's symbolic scalar; keepdim is instead handled by the
    # reference implementation.
    keepdim = False
    return fd.ops.var_mean(a, dims, correction, keepdim)


def _rand_like_nvfuser(fd: Any, a: TensorLikeType):
    return fd.ops.rand_like(a)


def _amax_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    return fd.ops.max(a, dims, keep_dims)


def _amin_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    return fd.ops.min(a, dims, keep_dims)


def _clone_nvfuser(fd: Any, input: TensorLikeType, *, memory_format=None):
    return fd.ops.set(input)


def _full_nvfuser(
    fd: Any,
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
):
    assert device != torch.device("cpu")
    assert layout is None or layout is torch.strided
    assert pin_memory is False
    assert requires_grad is False
    dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
    nvfuser_dtype = getnvFuserDtype(dtype)
    return fd.ops.full(shape, fill_value, nvfuser_dtype)


_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
_nvfuser_impls["clone"] = _clone_nvfuser
_nvfuser_impls["transpose"] = _transpose_nvfuser
_nvfuser_impls["squeeze"] = _squeeze_nvfuser
_nvfuser_impls["view_of"] = _view_of_nvfuser
_nvfuser_impls["view"] = _view_nvfuser
_nvfuser_impls["rand_like"] = _rand_like_nvfuser
_nvfuser_impls["sum"] = _sum_nvfuser
_nvfuser_impls["var"] = _var_nvfuser
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
_nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser
_nvfuser_impls["full"] = _full_nvfuser


def register_full():
    name = "full"

    nvprim.define(
        "full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, "
        + "bool? pin_memory=None, bool? requires_grad=None) -> Tensor"
    )

    def _meta_impl(
        size,
        fill_value,
        *,
        out=None,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=False,
        requires_grad=False,
    ):
        strides = make_contiguous_strides_for(size)
        return torch._prims.TensorMeta(
            None,
            shape=size,
            strides=strides,
            dtype=dtype,
            device=device,
        )

    def _prim_impl(
        size,
        fill_value,
        *,
        out=None,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=False,
        requires_grad=False,
    ):
        return torch.full(
            size,
            fill_value,
            out=out,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            requires_grad=requires_grad,
        )

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_impl)

    prim_packet = getattr(torch._ops.ops.nvprims, name)
    prim = prim_packet.default
    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
    for p in (prim_packet, prim):
        p.__doc__ = "Creates a tensor of the given size, filled with the given value"
        p.impl_nvfuser = _nvfuser_impls["full"]
        p.is_recomputable = _nvfuser_is_recomputable["full"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
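
# Usage sketch for the "full" nvprim registered above (illustrative, comment
# only; the dtype/device values are placeholders):
#
#     register_nvprims()
#     t = torch.ops.nvprims.full((2, 3), 1.0, dtype=torch.float32, device="cuda")
#
# Eagerly this dispatches to torch.full via _prim_impl; the nvFuser lowering
# goes through _full_nvfuser when tracing for fusion.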


# functorch.compile.min_cut_rematerialization_partition accepts a list of
# operators that can be recomputed in the backward pass. The table below
# records which nvprims are allowed to appear in that list: operators mapped
# to False (or missing from the table) are never recomputed.
_nvfuser_is_recomputable: Dict[str, bool] = {
    # Reductions are not allowed to be recomputed
    "amax": False,
    "amin": False,
    "sum": False,
    "var": False,
    "var_mean": False,
    # Normalizations are not allowed to be recomputed
    "native_batch_norm": False,
    # Random ops are not allowed to be recomputed
    "rand_like": False,
    # Everything else is allowed to be recomputed
    "abs": True,
    "acos": True,
    "add": True,
    "asin": True,
    "atan": True,
    "atan2": True,
    "atanh": True,
    "bitwise_and": True,
    "bitwise_not": True,
    "bitwise_or": True,
    "bitwise_xor": True,
    "broadcast_in_dim": True,
    "ceil": True,
    "clone": True,
    "convert_element_type": True,
    "cos": True,
    "cosh": True,
    "div": True,
    "eq": True,
    "erf": True,
    "erfc": True,
    "exp": True,
    "expm1": True,
    "floor": True,
    "fmod": True,
    "full": True,
    "ge": True,
    "gt": True,
    "imag": True,
    "isfinite": True,
    "le": True,
    "lgamma": True,
    "log": True,
    "log10": True,
    "log1p": True,
    "log2": True,
    "lt": True,
    "mul": True,
    "ne": True,
    "neg": True,
    "pow": True,
    "real": True,
    "reciprocal": True,
    "remainder": True,
    "round": True,
    "rsqrt": True,
    "sign": True,
    "sin": True,
    "sinh": True,
    "sqrt": True,
    "squeeze": True,
    "sub": True,
    "tan": True,
    "tanh": True,
    "transpose": True,
    "trunc": True,
    "view": True,
    "view_of": True,
    "where": True,
}
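
# A hedged sketch (comment only) of how this table might be consumed; the
# helper name below is made up for illustration:
#
#     recomputable_nvprim_ops = [
#         getattr(torch.ops.nvprims, name)
#         for name, ok in _nvfuser_is_recomputable.items()
#         if ok
#     ]
#
# Per the note above, such a list can then be supplied to
# functorch.compile.min_cut_rematerialization_partition to control what may be
# recomputed in the backward pass; the exact argument name is not shown here.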


def register_native_batch_norm():
    """Registers the native_batch_norm function in the torch.ops.nvprims module."""
    name = "native_batch_norm"

    nvprim.define(
        f"{name}(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, "
        + "bool training, float momentum, float eps)"
        + " -> (Tensor, Tensor, Tensor)"
    )

    def _prim_impl(
        input, weight, bias, running_mean, running_var, training, momentum, eps
    ):
        return torch.native_batch_norm(
            input, weight, bias, running_mean, running_var, training, momentum, eps
        )

    nvprim_impl.impl(name, _prim_impl)
    prim_packet = torch._ops.ops.nvprims.native_batch_norm
    prim = prim_packet.default

    def _native_batch_norm_ref(
        input: torch.Tensor,
        weight: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        running_mean: Optional[torch.Tensor],
        running_var: Optional[torch.Tensor],
        training: bool,
        momentum: float,
        eps: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if torch._prims_common.is_complex_dtype(input.dtype):
            raise NotImplementedError("Complex tensors are not supported")

        # Note: batch norm only promotes the input to the dtype of weight/bias
        # for the computation, but keeps the input's dtype for the output
        result_dtype = input.dtype
        computation_dtype, _ = elementwise_dtypes(
            input,
            weight,
            bias,
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
        )

        input_ = _maybe_convert_to_dtype(input, computation_dtype)
        output, mean, rstd = prim(
            input_, weight, bias, running_mean, running_var, training, momentum, eps
        )
        output_ = _maybe_convert_to_dtype(output, result_dtype)  # type: ignore[arg-type]
        return (output_, mean, rstd)  # type: ignore[return-value]

    def _native_batch_norm_autograd(
        input: torch.Tensor,
        weight: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        running_mean: Optional[torch.Tensor],
        running_var: Optional[torch.Tensor],
        training: bool,
        momentum: float,
        eps: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # This wrapper is needed to convert prims calls inside
        # _native_batch_norm_ref to nvprims calls
        from torch._prims.context import NvfuserPrimsMode

        with NvfuserPrimsMode():
            return backwards_not_supported(_native_batch_norm_ref)(
                input, weight, bias, running_mean, running_var, training, momentum, eps
            )

    nvprim_autograd_impl.impl(name, _native_batch_norm_autograd)

    for p in (prim_packet, prim):
        p.__doc__ = "Computes batch normalization."
        p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
        p.is_recomputable = _nvfuser_is_recomputable["native_batch_norm"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
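
# Illustrative usage sketch for the op registered above (comment only), which
# mirrors the (Tensor, Tensor, Tensor) schema; tensor names are placeholders:
#
#     out, mean, rstd = torch.ops.nvprims.native_batch_norm(
#         x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
#     )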


def register_rand_like():
    name = "rand_like"

    nvprim.define(
        "rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, "
        + "Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"
    )

    def _meta_rand_like(
        self,
        *,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=None,
        memory_format=None,
    ):
        strides = make_contiguous_strides_for(self.shape)
        return torch._prims.TensorMeta(
            self,
            shape=self.shape,
            strides=strides,
            dtype=dtype,
            device=device,
        )

    def _prim_impl(
        self,
        *,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=None,
        memory_format=None,
    ):
        return torch.rand_like(
            self,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            memory_format=memory_format,
        )

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_rand_like)

    prim_packet = getattr(torch._ops.ops.nvprims, name)
    prim = prim_packet.default

    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = "Returns a tensor with the same shape as the input, filled with random numbers from a uniform distribution on [0, 1)"
        p.impl_nvfuser = _nvfuser_impls["rand_like"]
        p.is_recomputable = _nvfuser_is_recomputable["rand_like"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]


def register_var_mean():
    """Registers the var_mean function in the torch.ops.nvprims module."""
    name = "var_mean.main"

    # This overload must be default for correct dispatching of var_mean(Tensor, bool)
    nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")

    # This signature combines several overloads of torch.var_mean into a single overload.
    nvprim.define(
        f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, float? correction=None)"
        + " -> (Tensor, Tensor)"
    )

    # This function is used for device="meta" Tensors.
    def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
        if torch._prims_common.is_complex_dtype(inp.dtype):
            output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
        else:
            output_dtype = inp.dtype
        var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
        mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
        if keepdim:
            output_shape = [
                inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
            ]
            broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
            var = torch._ops.ops.nvprims.broadcast_in_dim(
                var, output_shape, broadcast_dims
            )
            mean = torch._ops.ops.nvprims.broadcast_in_dim(
                mean, output_shape, broadcast_dims
            )
        return (var, mean)

    # This function is used under _AutoDispatchBelowAutograd context
    def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
        correction = torch._prims_common.set_correction(unbiased, correction)
        return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_var_mean)

    prim_packet = torch._ops.ops.nvprims.var_mean
    prim = prim_packet.main

    def _unbiased_overload_impl(inp, unbiased):
        return prim(inp, dim=None, unbiased=unbiased)

    nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)

    @elementwise_type_promotion_wrapper(
        type_promoting_args=("a",),
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    )
    def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
        correction = torch._prims_common.set_correction(unbiased, correction)
        # reduces over all dimensions if dim=() is passed
        if dim == () or dim == []:
            dim = None
        dim = torch._prims_common.reduction_dims(a.shape, dim)

        # For complex tensors eager computes the variance as the sum of variances of
        # the real and imaginary parts
        # TODO: Creating a complex tensor from real and imaginary parts is not supported
        if torch._prims_common.is_complex_dtype(a.dtype):
            raise NotImplementedError("Complex tensors are not supported")

        var_mean = prim(a, dim, correction=correction)

        if keepdim:
            output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
            broadcast_dims = [i for i in range(a.ndim) if i not in dim]
            var, mean = var_mean
            var = torch._ops.ops.nvprims.broadcast_in_dim(
                var, output_shape, broadcast_dims
            )
            mean = torch._ops.ops.nvprims.broadcast_in_dim(
                mean, output_shape, broadcast_dims
            )
            var_mean = (var, mean)
        return var_mean

    def _var_mean_autograd(
        a, dim=None, unbiased=None, keepdim=False, *, correction=None
    ):
        # This wrapper is needed to convert prims calls inside
        # elementwise_type_promotion_wrapper to nvprims calls
        from torch._prims.context import NvfuserPrimsMode

        with NvfuserPrimsMode():
            return backwards_not_supported(_var_mean_ref)(
                a, dim, unbiased, keepdim, correction=correction
            )

    nvprim_autograd_impl.impl(name, _var_mean_autograd)

    for p in (prim_packet, prim):
        p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
        p.impl_nvfuser = _nvfuser_impls["var_mean"]
        p.is_recomputable = _nvfuser_is_recomputable["var_mean"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
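
# Illustrative sketch of the two overloads registered above (comment only):
#
#     t = torch.randn(3, 4)
#     # default overload, var_mean(Tensor, bool); forwarded to .main via
#     # _unbiased_overload_impl
#     torch.ops.nvprims.var_mean(t, True)
#     # combined overload with dim/keepdim/correction
#     torch.ops.nvprims.var_mean.main(t, [1], keepdim=True, correction=1.0)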


def _nvprims_view_impl_aten(a, original_shape, new_shape):
    return a.reshape(new_shape)


def register_view():
    """Registers the view function in the torch.ops.nvprims module."""
    # View is implemented as a decomposition into prims.split_dim,
    # prims.collapse_dim, and prims.reshape, but we would like to intercept
    # the non-decomposed view for now
    name = "view"

    nvprim.define("view(Tensor inp, SymInt[] original_shape, SymInt[] shape) -> Tensor")
    nvprim.define("view.shape(Tensor inp, SymInt[] shape) -> Tensor")

    # This function is used under _AutoDispatchBelowAutograd context
    def _prim_impl(a, original_shape, new_shape):
        return a.reshape(new_shape)

    nvprim_impl.impl(name, _prim_impl)

    prim_packet = torch._ops.ops.nvprims.view
    prim = prim_packet.default

    def _view_no_original_shape_overload_impl(a, shape):
        if list(a.shape) == list(shape):
            return torch.ops.nvprims.view_of(a)
        return torch.ops.nvprims.view.default(a, a.shape, shape)

    nvprim_implicit_impl.impl("view.shape", _view_no_original_shape_overload_impl)
    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = "Creates a tensor with the specified shape containing a copy of the data in a."
        p.impl_nvfuser = _nvfuser_impls["view"]
        p.is_recomputable = _nvfuser_is_recomputable["view"]
        p.return_type = torch._prims_common.RETURN_TYPE.VIEW  # type: ignore[attr-defined]
        p.impl_aten = _nvprims_view_impl_aten
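
# Illustrative sketch of the two view overloads registered above (comment only):
#
#     x = torch.randn(2, 3)
#     torch.ops.nvprims.view(x, (2, 3), (6,))  # default overload with explicit original shape
#     torch.ops.nvprims.view.shape(x, (2, 3))  # same shape, so routed to nvprims.view_of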


def register_nvprims():
    """Registers all nvFuser primitives in the torch.ops.nvprims module."""
    register_var_mean()
    register_view()
    register_native_batch_norm()
    register_rand_like()
    register_full()

    for name in nvprim_names:
        main_prim = getattr(torch._ops.ops.prims, name)

        nvprim.define(main_prim.schema)
        nvprim_impl.impl(name, main_prim.prim_impl)
        nvprim_meta_impl.impl(name, main_prim.prim_meta_impl)

        prim_packet = getattr(torch._ops.ops.nvprims, name)
        prim = prim_packet.default

        nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

        for p in (prim_packet, prim):
            p.__doc__ = main_prim.__doc__
            p.impl_nvfuser = _nvfuser_impls[name]
            p.is_recomputable = _nvfuser_is_recomputable.get(name, False)
            p.return_type = main_prim.return_type  # type: ignore[attr-defined]
            p.impl_aten = main_prim.impl_aten
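

# Illustrative sketch (comment only) of inspecting the per-op metadata that
# register_nvprims() attaches; "add" is used here just as an example name:
#
#     register_nvprims()
#     add_prim = torch.ops.nvprims.add.default
#     add_prim.impl_nvfuser      # lowering callable used when building a fusion
#     add_prim.is_recomputable   # True for "add" per _nvfuser_is_recomputable
#     add_prim.return_type       # mirrors torch.ops.prims.add's return type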