Add nvprims.var_mean (#83508)
This PR adds an nvFuser-specific primitive: `var_mean`.
The translation of `torch.var_mean` into `torch.ops.nvprims.var_mean` is handled by the `TorchRefsNvfuserCapabilityMode` context manager.
I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is covered by OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean")`).
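For example, the interception can be observed by tracing a small function under the mode and inspecting the resulting graph. This is a minimal sketch modeled on the `test_var_mean` test added in this PR; the shape and reduction dims are arbitrary, and a CUDA build with nvFuser is assumed:
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx


def fn(a):
    return torch.var_mean(a, [0, 1], correction=1)


with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(fn)(torch.randn(5, 5, device="cuda"))

# The traced graph should contain the nvprims.var_mean primitive
# instead of the generic torch._refs decomposition.
print(any(
    node.op == "call_function" and node.target == torch.ops.nvprims.var_mean.main
    for node in gm.graph.nodes
))
```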
The layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` so that the call can be intercepted. Here's a simple performance comparison between this PR and master (on a 3080 Ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute


def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)


a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser")
```
Run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`:
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives roughly a 35% performance improvement with the nvFuser executor for this specific normalized shape.
This PR also fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).
Ref. https://github.com/pytorch/pytorch/issues/80187
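For completeness, a hedged sketch of what calling the new primitive directly might look like (the tensor shape is arbitrary; eager execution of `nvprims.var_mean` falls back to `torch.var_mean` through its CompositeExplicitAutograd implementation):
```py
import torch

a = torch.randn(10, 512, 1024, device="cuda", dtype=torch.float16)

# "main" overload: mirrors torch.var_mean's dim/correction/keepdim arguments.
var, mean = torch.ops.nvprims.var_mean.main(a, [0, 1], correction=1, keepdim=False)

# Default overload kept for var_mean(Tensor, bool unbiased) dispatch compatibility.
var_all, mean_all = torch.ops.nvprims.var_mean(a, True)
```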
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: 261be8e5c2
Commit: 3aae6ff1e1
@@ -188,6 +188,7 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
            1e-3,
            1e-3,
        ),
        (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
    }
    if (test_dtype, op) in tol_table:
        rtol, atol = tol_table[(decomp.dtype, op)]
@@ -353,6 +353,31 @@ class TestPrims(TestCase):
        self.assertTrue(result.is_contiguous)
        self.assertEqual(_wrapper(a), result)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float16, torch.float32)
    @parametrize("correction", [0, 1])
    @parametrize("keepdim", [True, False])
    def test_var_mean(self, device, dtype, correction, keepdim):
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsNvfuserCapabilityMode

        def _wrapper(a):
            return torch.var_mean(a, [0, 1], correction=correction, keepdim=keepdim)

        make_arg = partial(make_tensor, device=device, dtype=dtype)

        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(_wrapper)(make_arg((5, 5)))

        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        includes_nvprims_var_mean = any(
            torch.ops.nvprims.var_mean.main == node.target
            for node in call_function_nodes
        )
        self.assertTrue(includes_nvprims_var_mean)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
@@ -20,6 +20,7 @@ from torch._prims_common import (
    DimsType,
    Number,
    NumberType,
    RETURN_TYPE,
    ShapeType,
    StrideType,
    TensorLike,
@@ -280,20 +281,6 @@ def TensorMeta(
#
# Common datastructures and helpers
#

# Describes the return type of the primitive:
#
# - NEW, a new tensor is created
# - VIEW, a view of an input tensor is returned
# - INPLACE, one or more input tensors is modified
#
# these descriptors are mututally exclusive and exhaustive.
class RETURN_TYPE(Enum):
    NEW = (0,)
    VIEW = (1,)
    INPLACE = (2,)


def _wrap_tensor_meta(f):
    def wrap(t):
        if (
@@ -12,6 +12,7 @@ import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport

from torch._prims_common import torch_function_passthrough
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule
@@ -204,15 +205,42 @@ def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
    ):
        gm = get_isolated_graphmodule(func, args, kwargs)

        supported_ops = NvfuserPrimOperatorSupport()
        call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
        any_unsupported = any(
            not _is_node_supported_nvfuser(node) for node in call_function_nodes
            not supported_ops.is_node_supported(None, node) for node in call_function_nodes
        )
        return any_unsupported


TorchRefsNvfuserCapabilityMode = functools.partial(
    TorchRefsMode,
    should_fallback_fn=_is_func_unsupported_nvfuser,
    prims_mode_cls=NvfuserPrimsMode,
)
class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
    def __init__(self):
        super().__init__(
            strict=False,
            should_fallback_fn=_is_func_unsupported_nvfuser,
            prims_mode_cls=NvfuserPrimsMode,
        )

    def _is_var_mean(self, func):
        return "torch.var_mean" == torch.overrides.resolve_name(func) or (
            (
                isinstance(func, torch._ops.OpOverload)
                or isinstance(func, torch._ops.OpOverloadPacket)
            )
            and "aten.var_mean" in str(func)
        )

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Dict = None,
    ):
        if kwargs is None:
            kwargs = {}
        # First we intercept calls for nvfuser-specific prims bypassing generic torch._refs
        if self._is_var_mean(orig_func):
            return torch.ops.nvprims.var_mean(*args, **kwargs)
        # Then we use TorchRefsMode to interpret the rest
        return super().__torch_function__(orig_func, types, args, kwargs)
@@ -57,6 +57,8 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
    # PROTOTYPE nvfuser executor
    # Everything in the graph must support nvfuser
    for node in gm.graph.nodes:
        if node.op == "call_function" and "getitem" in node.name:
            continue
        if (
            node.op == "call_function"
            and getattr(node.target, "impl_nvfuser", None) is None
@@ -77,6 +79,10 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):

class FusionInterpreter(torch.fx.Interpreter):
    def call_function(self, target, args, kwargs):
        # This handles tuple unpacking
        if "getitem" in str(target):
            assert isinstance(args[0], tuple)
            return target(*args, **kwargs)
        args = tuple(map(_to_nvfuser_constant, args))
        target = target.impl_nvfuser
        args = (fd,) + args
@@ -132,6 +138,7 @@ class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSuppor
        return (
            node.op == "call_function"
            and getattr(node.target, "impl_nvfuser", None) is not None
            or "getitem" in node.name  # getitem is a special case
        )
@@ -5,24 +5,31 @@
# can be added in the future for the corresponding higher-level torch/aten
# functions.

from typing import Any, Dict
from typing import Any, Dict, Optional

import torch

from torch._prims_common import (
    DimsSequenceType,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    getnvFuserDtype,
    ShapeType,
    TensorLikeType,
)

from torch._prims_common.wrappers import backwards_not_supported
from torch._prims_common.wrappers import (
    backwards_not_supported,
    elementwise_type_promotion_wrapper,
)

nvprim_namespace = "nvprims"
nvprim = torch.library.Library(nvprim_namespace, "DEF")
nvprim_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
)
nvprim_implicit_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
)
nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")

@@ -234,6 +241,23 @@ def _var_nvfuser(
    return fd.ops.var(a, dims, correction, keep_dims)


def _var_mean_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: int,
):
    # Unbiased arg shouldn't be set when this function is called
    assert unbiased is None
    # Ignore keepdim arg, because currently it's automatically converted into nvfuser's symbolic scalar
    # keepdim is handled by the reference implementation
    keepdim = False
    return fd.ops.var_mean(a, dims, correction, keepdim)


def _amax_nvfuser(
    fd: Any,
    a: TensorLikeType,
@@ -256,12 +280,112 @@ _nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
_nvfuser_impls["sum"] = _sum_nvfuser
_nvfuser_impls["var"] = _var_nvfuser
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
_nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser


def register_var_mean():
    """This function is used to register the var_mean function in torch.ops.nvprims module."""
    name = "var_mean.main"

    # This overload must be default for correct dispatching of var_mean(Tensor, bool)
    nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")

    # This signature tries to combine several overloads of the torch.var_mean function into one overload.
    nvprim.define(
        f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, int? correction=None)"
        + " -> (Tensor, Tensor)"
    )

    # This function is used for device="meta" Tensors.
    def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
        if torch._prims_common.is_complex_dtype(inp.dtype):
            output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
        else:
            output_dtype = inp.dtype
        var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
        mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
        if keepdim:
            output_shape = [
                inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
            ]
            broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
            var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
            mean = torch.ops.nvprims.broadcast_in_dim(
                mean, output_shape, broadcast_dims
            )
        return (var, mean)

    # This function is used under _AutoDispatchBelowAutograd context
    def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
        correction = torch._prims_common.set_correction(unbiased, correction)
        return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_var_mean)

    prim_packet = torch.ops.nvprims.var_mean
    prim = prim_packet.main

    def _unbiased_overload_impl(inp, unbiased):
        return prim(inp, dim=None, unbiased=unbiased)

    nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)

    @elementwise_type_promotion_wrapper(
        type_promoting_args=("a",),
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    )
    def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
        correction = torch._prims_common.set_correction(unbiased, correction)
        # reduces over all dimensions if dim=() is passed
        if dim == () or dim == []:
            dim = None
        dim = torch._prims_common.reduction_dims(a.shape, dim)

        # For complex tensors eager computes the variance as the sum of variances of
        # the real and imaginary parts
        # TODO: Creating a complex tensor from real and imaginary parts is not supported
        if torch._prims_common.is_complex_dtype(a.dtype):
            raise NotImplementedError("Complex tensors are not supported")

        var_mean = prim(a, dim, correction=correction)

        if keepdim:
            output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
            broadcast_dims = [i for i in range(a.ndim) if i not in dim]
            var, mean = var_mean
            var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
            mean = torch.ops.nvprims.broadcast_in_dim(
                mean, output_shape, broadcast_dims
            )
            var_mean = (var, mean)
        return var_mean

    def _var_mean_autograd(
        a, dim=None, unbiased=None, keepdim=False, *, correction=None
    ):
        # This wrapper is needed to convert prims calls inside
        # elementwise_type_promotion_wrapper to nvprims calls
        from torch._prims.context import NvfuserPrimsMode

        with NvfuserPrimsMode():
            return backwards_not_supported(_var_mean_ref)(
                a, dim, unbiased, keepdim, correction=correction
            )

    nvprim_autograd_impl.impl(name, _var_mean_autograd)

    for p in (prim_packet, prim):
        p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
        p.impl_nvfuser = _nvfuser_impls["var_mean"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]


def register_nvprims():
    """Registers all nvFuser primitives in the torch.ops.nvprims module."""
    register_var_mean()
    for name in nvprim_names:
        main_prim = getattr(torch.ops.prims, name)
@@ -1033,6 +1033,19 @@ class REDUCTION_OUTPUT_TYPE_KIND(Enum):
    ALWAYS_BOOL = (3,)


# Describes the return type of the primitive:
#
# - NEW, a new tensor is created
# - VIEW, a view of an input tensor is returned
# - INPLACE, one or more input tensors is modified
#
# these descriptors are mututally exclusive and exhaustive.
class RETURN_TYPE(Enum):
    NEW = (0,)
    VIEW = (1,)
    INPLACE = (2,)


# TODO: document type promotion kinds
def elementwise_dtypes(
    *_args,
@@ -1348,6 +1361,23 @@ def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...
    return dims


def set_correction(
    unbiased: Optional[bool] = None,
    correction: Optional[int] = None,
):
    if correction is not None and unbiased is not None:
        raise RuntimeError("cannot specify both correction and unbiased arguments")
    elif correction is None and unbiased is None:
        correction = 1
    elif correction is None and unbiased is not None:
        correction = 0 if unbiased is False else 1
    if not isinstance(correction, int):
        raise ValueError("correction argument should be integer")
    if correction < 0:
        raise ValueError("correction argument should be non-negative")
    return correction


def check_in_bounds_for_storage(
    a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
):
@@ -118,9 +118,11 @@ class elementwise_type_promotion_wrapper(object):

        result = fn(**bound.arguments)

        # FIXME?: assumes result is a single tensor
        assert isinstance(result, TensorLike)
        return _maybe_convert_to_dtype(result, result_dtype)
        if isinstance(result, TensorLike):
            return _maybe_convert_to_dtype(result, result_dtype)
        if isinstance(result, Sequence):
            return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
        raise AssertionError(f"Unhandled result type: {type(result)}")

    _fn.__signature__ = sig  # type: ignore[attr-defined]
    return _fn
@@ -1897,21 +1897,14 @@ def amax(
    )


def _set_correction(
    unbiased: Optional[bool] = None,
    correction: Optional[int] = None,
):
    if correction is not None and unbiased is not None:
        raise RuntimeError("cannot specify both correction and unbiased arguments")
    elif correction is None and unbiased is None:
        correction = 1
    elif correction is None and unbiased is not None:
        correction = 0 if unbiased is False else 1
    if not isinstance(correction, int):
        raise ValueError("correction argument should be integer")
    if correction < 0:
        raise ValueError("correction argument should be non-negative")
    return correction
def _dim_var_dispatch(dim=None, unbiased=None):
    # There's the following overload of torch.var:
    # var(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
    # We need to explicitly convert bool dims to unbiased arg
    if unbiased is None and isinstance(dim, bool):
        unbiased = dim
        dim = None
    return dim, unbiased


@out_wrapper()
@@ -1923,7 +1916,8 @@ def var(
    *,
    correction: Optional[int] = None,
) -> TensorLikeType:
    correction = _set_correction(unbiased, correction)
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None
@@ -1950,7 +1944,8 @@ def std(
    *,
    correction: Optional[int] = None,
) -> TensorLikeType:
    correction = _set_correction(unbiased, correction)
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None
@@ -2024,6 +2019,7 @@ def std_mean(
    keepdim: bool = False,
    correction: Optional[int] = None,
):
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    s = std(a, dim, unbiased, keepdim, correction=correction)
    m = mean(a, dim, keepdim)
    return s, m
@@ -2038,6 +2034,7 @@ def var_mean(
    *,
    correction: Optional[int] = None,
):
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    v = var(a, dim, unbiased, keepdim, correction=correction)
    m = mean(a, dim, keepdim)
    return v, m
@@ -2451,7 +2448,9 @@ def _normalize(
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a_acc = _maybe_convert_to_dtype(a, computation_dtype)
    assert isinstance(a_acc, TensorLike)  # to avoid mypy error for var_mean
    biased_var, mean = var_mean(a_acc, dim=norm_dims, unbiased=False, keepdim=True)
    biased_var, mean = torch.var_mean(
        a_acc, dim=norm_dims, unbiased=False, keepdim=True
    )
    rstd = torch.rsqrt(biased_var + eps)
    out = (a - mean) * rstd
    return out, mean, rstd
@@ -69,6 +69,64 @@ TensorView* variance(
  return y;
}

TORCH_CUDA_CU_API VarMeanResult variance_mean(
    TensorView* x,
    const std::vector<int>& dims,
    int64_t correction,
    bool keepdim) {
  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");

  TORCH_CHECK(
      correction >= 0, "correction must be non-negative, but got ", correction);

  // There are compilation errors for half precision
  auto dtype = x->getDataType().value();
  TORCH_CHECK(
      !(dtype == DataType::Half || dtype == DataType::BFloat16),
      "variance_mean is not supported for ",
      dtype,
      " please upcast to float");

  if (isComplexType(x->getDataType().value())) {
    // There are compilation errors:
    // __tmp_kernel1.cu(6727): error: namespace "CudaCodeGen::std" has no member
    // "imagf"
    // __tmp_kernel1.cu(6753): error: namespace "CudaCodeGen::std" has no member
    // "realf"
    TORCH_CHECK(false, "var_mean is not supported for complex types.");
    auto out_real = variance_mean(real(x), dims, correction, keepdim);
    auto out_imag = variance_mean(imag(x), dims, correction, keepdim);
    // variance of a complex tensor is the sum of real and imaginary variances
    // and is real mean of a complex tensor is complex complex(out_real.mean,
    // out_imag.mean) It seems construction of a complex tensor from two real
    // tensors is not supported yet
    return {add(out_real.var, out_imag.var), nullptr};
  }

  const int kNumberOfDims =
      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();
  auto num_features = numFeatures(x, dims, kNumberOfDims);
  if (correction > 0) {
    num_features =
        sub(num_features, IrBuilder::create<Int>(x->container(), correction));
  }

  auto welford_out = Welford(x, dims);
  auto mean = welford_out.avg;
  auto var = mul(welford_out.var_sum, reciprocal(num_features));

  if (keepdim) {
    std::vector<bool> is_broadcast(kNumberOfDims, false);
    for (auto dim : dims) {
      is_broadcast[dim] = true;
    }
    var = broadcast(var, is_broadcast);
    mean = broadcast(mean, is_broadcast);
  }

  return {var, mean};
}

TensorView* standard_deviation(
    TensorView* x,
    const std::vector<int>& dims,
@@ -38,6 +38,11 @@ struct BackwardRMSNormResult {
  TensorView* grad_weight = nullptr;
};

struct VarMeanResult {
  TensorView* var = nullptr;
  TensorView* mean = nullptr;
};

TORCH_CUDA_CU_API TensorView* mean(
    TensorView* x,
    const std::vector<int>& dims,
@@ -55,6 +60,12 @@ TORCH_CUDA_CU_API TensorView* variance(
    int64_t correction,
    bool keepdim);

TORCH_CUDA_CU_API VarMeanResult variance_mean(
    TensorView* x,
    const std::vector<int>& dims,
    int64_t correction,
    bool keepdim);

TORCH_CUDA_CU_API TensorView* standard_deviation(
    TensorView* x,
    const std::vector<int>& dims,
@@ -362,4 +362,34 @@ struct VarianceOpRecord : RecordFunctor {
  bool keep_dim_;
};

struct VarianceMeanOpRecord : RecordFunctor {
  VarianceMeanOpRecord(
      std::vector<size_t> _args,
      std::vector<size_t> _outputs,
      std::vector<int>& dims,
      int64_t correction,
      bool keepdim)
      : RecordFunctor(std::move(_args), std::move(_outputs)),
        dims_(dims),
        correction_(correction),
        keepdim_(keepdim) {}
  virtual ~VarianceMeanOpRecord() = default;

  void operator()(FusionDefinition& fd) final {
    auto arg = fd.getFusionState(args.at(0))->as<NvfTensorView>();
    auto output = torch::jit::fuser::cuda::variance_mean(
        arg, dims_, correction_, keepdim_);
    fd.setFusionState(outputs.at(0), output.var);
    fd.setFusionState(outputs.at(1), output.mean);
  }

 private:
  //! Dimensions of tensor to reduce for variance calculation
  std::vector<int> dims_;
  //! Bessel's correction value
  int64_t correction_;
  //! Indicates whether to keep the reduced dimension(s).
  bool keepdim_;
};

} // namespace nvfuser
@@ -12,6 +12,7 @@
#include <torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <iostream>
#include <tuple>

namespace torch {
namespace jit {
@@ -931,6 +932,25 @@ void initNvFuserPythonBindings(PyObject* module) {
      },
      py::return_value_policy::reference);

  nvf_ops.def(
      "var_mean",
      [](nvfuser::FusionDefinition::Operators& self,
         nvfuser::Tensor* arg,
         std::vector<int>& dims,
         int64_t correction,
         bool keepdim) -> decltype(auto) {
        nvfuser::Tensor* var = self.fusion_definition->defineTensor();
        nvfuser::Tensor* mean = self.fusion_definition->defineTensor();
        self.fusion_definition->defineRecord(new nvfuser::VarianceMeanOpRecord(
            {arg->index},
            {var->index, mean->index},
            dims,
            correction,
            keepdim));
        return std::make_tuple(var, mean);
      },
      py::return_value_policy::reference);

  nvf_ops.def(
      "broadcast_in_dim",
      [](nvfuser::FusionDefinition::Operators& self,
@@ -242,7 +242,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
        // return an IValue() to denote a NoneType
        return {};
      }
      return toIValue(obj, type->expectRef<OptionalType>().getElementType());
      return toIValue(obj, type->expectRef<OptionalType>().getElementType(), N);
    }
    case TypeKind::ClassType: {
      auto classType = type->expect<ClassType>();
@@ -4615,6 +4615,10 @@ def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs):

        SampleInput(tensor_nd(), kwargs=dict(dim=(1,), correction=S // 2)),
        SampleInput(tensor_nd(), kwargs=dict(dim=None, correction=0, keepdim=True)),

        # Test var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
        SampleInput(tensor_nd(), args=(True,)),
        SampleInput(tensor_nd(), args=(False,)),
    ]

@@ -16970,6 +16974,17 @@ python_ref_db = [
        torch_opinfo_name="var_mean",
        validate_view_consistency=False,
    ),
    PythonRefInfo(
        "ops.nvprims.var_mean",
        torch_opinfo_name="var_mean",
        validate_view_consistency=False,
        # Complex types are currently disabled
        dtypes=floating_types_and(torch.float16, torch.bfloat16),
        # This function is expected not to work with TorchRefsMode(strict=True)
        decorators=(
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
        ),
    ),
    #
    # Linear Algebra Operators
    #