Add nvprims.var_mean (#83508)

This PR adds an nvFuser-specific primitive, `var_mean`.
The translation of `torch.var_mean` into `torch.ops.nvprims.var_mean` is handled by the `TorchRefsNvfuserCapabilityMode` context manager.
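
A minimal sketch of how the interception can be observed (it mirrors the new test in `test_prims.py` and assumes a CUDA build with nvFuser available):

```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx

def fn(a):
    return torch.var_mean(a, [0, 1], correction=1, keepdim=False)

a = torch.randn(5, 5, device="cuda")

# Tracing under the capability mode rewrites torch.var_mean into
# torch.ops.nvprims.var_mean in the captured FX graph.
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(fn)(a)

print(any(
    node.op == "call_function" and node.target == torch.ops.nvprims.var_mean.main
    for node in gm.graph.nodes
))  # expected: True
```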

I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean")`).
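
After the move, those helpers are importable from `torch._prims_common`; a small sketch (based on the diff below, behavior shown only for the simple cases):

```py
from torch._prims_common import RETURN_TYPE, set_correction

# set_correction reconciles the legacy `unbiased` flag with the newer
# `correction` argument; passing both raises a RuntimeError.
assert set_correction(unbiased=None, correction=None) == 1   # default Bessel's correction
assert set_correction(unbiased=False, correction=None) == 0
assert set_correction(unbiased=True, correction=None) == 1

# RETURN_TYPE also lives here now; nvprims.var_mean is tagged as RETURN_TYPE.NEW.
print(RETURN_TYPE.NEW)
```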

The layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` so that the call can be intercepted. Here's a simple performance comparison of this PR against master (on a 3080 Ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)

a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser");
```
Run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`:
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s

# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives roughly a 35% performance improvement (≈660 GB/s vs. ≈480 GB/s in the steady-state runs above) with the nvFuser executor for this specific normalized shape.

This PR also fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`, which forwards the `N` argument when recursing into the element type of an `Optional`).

Ref. https://github.com/pytorch/pytorch/issues/80187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
Author: Ivan Yashchuk
Date: 2022-08-28 18:45:25 +00:00
Committed by: PyTorch MergeBot
Parent: 261be8e5c2
Commit: 3aae6ff1e1
15 changed files with 381 additions and 44 deletions

View File

@ -188,6 +188,7 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
1e-3,
1e-3,
),
(torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
}
if (test_dtype, op) in tol_table:
rtol, atol = tol_table[(decomp.dtype, op)]

View File

@ -353,6 +353,31 @@ class TestPrims(TestCase):
self.assertTrue(result.is_contiguous)
self.assertEqual(_wrapper(a), result)
@onlyCUDA
@skipCUDAIfRocm
@dtypes(torch.float16, torch.float32)
@parametrize("correction", [0, 1])
@parametrize("keepdim", [True, False])
def test_var_mean(self, device, dtype, correction, keepdim):
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
def _wrapper(a):
return torch.var_mean(a, [0, 1], correction=correction, keepdim=keepdim)
make_arg = partial(make_tensor, device=device, dtype=dtype)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(_wrapper)(make_arg((5, 5)))
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
includes_nvprims_var_mean = any(
torch.ops.nvprims.var_mean.main == node.target
for node in call_function_nodes
)
self.assertTrue(includes_nvprims_var_mean)
@onlyCUDA
@skipCUDAIfRocm
@dtypes(torch.float32)

View File

@ -20,6 +20,7 @@ from torch._prims_common import (
DimsType,
Number,
NumberType,
RETURN_TYPE,
ShapeType,
StrideType,
TensorLike,
@ -280,20 +281,6 @@ def TensorMeta(
#
# Common datastructures and helpers
#
# Describes the return type of the primitive:
#
# - NEW, a new tensor is created
# - VIEW, a view of an input tensor is returned
# - INPLACE, one or more input tensors is modified
#
# these descriptors are mututally exclusive and exhaustive.
class RETURN_TYPE(Enum):
NEW = (0,)
VIEW = (1,)
INPLACE = (2,)
def _wrap_tensor_meta(f):
def wrap(t):
if (

View File

@ -12,6 +12,7 @@ import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
from torch._prims_common import torch_function_passthrough
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule
@ -204,15 +205,42 @@ def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
):
gm = get_isolated_graphmodule(func, args, kwargs)
supported_ops = NvfuserPrimOperatorSupport()
call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
any_unsupported = any(
not _is_node_supported_nvfuser(node) for node in call_function_nodes
not supported_ops.is_node_supported(None, node) for node in call_function_nodes
)
return any_unsupported
TorchRefsNvfuserCapabilityMode = functools.partial(
TorchRefsMode,
should_fallback_fn=_is_func_unsupported_nvfuser,
prims_mode_cls=NvfuserPrimsMode,
)
class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
def __init__(self):
super().__init__(
strict=False,
should_fallback_fn=_is_func_unsupported_nvfuser,
prims_mode_cls=NvfuserPrimsMode,
)
def _is_var_mean(self, func):
return "torch.var_mean" == torch.overrides.resolve_name(func) or (
(
isinstance(func, torch._ops.OpOverload)
or isinstance(func, torch._ops.OpOverloadPacket)
)
and "aten.var_mean" in str(func)
)
def __torch_function__(
self,
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Dict = None,
):
if kwargs is None:
kwargs = {}
# First we intercept calls for nvfuser-specific prims bypassing generic torch._refs
if self._is_var_mean(orig_func):
return torch.ops.nvprims.var_mean(*args, **kwargs)
# Then we use TorchRefsMode to interpret the rest
return super().__torch_function__(orig_func, types, args, kwargs)

View File

@ -57,6 +57,8 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
# PROTOTYPE nvfuser executor
# Everything in the graph must support nvfuser
for node in gm.graph.nodes:
if node.op == "call_function" and "getitem" in node.name:
continue
if (
node.op == "call_function"
and getattr(node.target, "impl_nvfuser", None) is None
@ -77,6 +79,10 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
class FusionInterpreter(torch.fx.Interpreter):
def call_function(self, target, args, kwargs):
# This handles tuple unpacking
if "getitem" in str(target):
assert isinstance(args[0], tuple)
return target(*args, **kwargs)
args = tuple(map(_to_nvfuser_constant, args))
target = target.impl_nvfuser
args = (fd,) + args
@ -132,6 +138,7 @@ class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSuppor
return (
node.op == "call_function"
and getattr(node.target, "impl_nvfuser", None) is not None
or "getitem" in node.name # getitem is a special case
)

View File

@ -5,24 +5,31 @@
# can be added in the future for the corresponding higher-level torch/aten
# functions.
from typing import Any, Dict
from typing import Any, Dict, Optional
import torch
from torch._prims_common import (
DimsSequenceType,
ELEMENTWISE_TYPE_PROMOTION_KIND,
getnvFuserDtype,
ShapeType,
TensorLikeType,
)
from torch._prims_common.wrappers import backwards_not_supported
from torch._prims_common.wrappers import (
backwards_not_supported,
elementwise_type_promotion_wrapper,
)
nvprim_namespace = "nvprims"
nvprim = torch.library.Library(nvprim_namespace, "DEF")
nvprim_impl = torch.library.Library(
nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
)
nvprim_implicit_impl = torch.library.Library(
nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
)
nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")
@ -234,6 +241,23 @@ def _var_nvfuser(
return fd.ops.var(a, dims, correction, keep_dims)
def _var_mean_nvfuser(
fd: Any,
a: TensorLikeType,
dims: DimsSequenceType,
unbiased: Optional[bool] = None,
keepdim: bool = False,
*,
correction: int,
):
# Unbiased arg shouldn't be set when this function is called
assert unbiased is None
# Ignore keepdim arg, because currently it's automatically converted into nvfuser's symbolic scalar
# keepdim is handled by the reference implementation
keepdim = False
return fd.ops.var_mean(a, dims, correction, keepdim)
def _amax_nvfuser(
fd: Any,
a: TensorLikeType,
@ -256,12 +280,112 @@ _nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
_nvfuser_impls["sum"] = _sum_nvfuser
_nvfuser_impls["var"] = _var_nvfuser
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
_nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser
def register_var_mean():
"""This function is used to register the var_mean function in torch.ops.nvprims module."""
name = "var_mean.main"
# This overload must be default for correct dispatching of var_mean(Tensor, bool)
nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")
# This signature tries to combine several overloads of the torch.var_mean function into one overload.
nvprim.define(
f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, int? correction=None)"
+ " -> (Tensor, Tensor)"
)
# This function is used for device="meta" Tensors.
def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
if torch._prims_common.is_complex_dtype(inp.dtype):
output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
else:
output_dtype = inp.dtype
var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
if keepdim:
output_shape = [
inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
]
broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
mean = torch.ops.nvprims.broadcast_in_dim(
mean, output_shape, broadcast_dims
)
return (var, mean)
# This function is used under _AutoDispatchBelowAutograd context
def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
correction = torch._prims_common.set_correction(unbiased, correction)
return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)
nvprim_impl.impl(name, _prim_impl)
nvprim_meta_impl.impl(name, _meta_var_mean)
prim_packet = torch.ops.nvprims.var_mean
prim = prim_packet.main
def _unbiased_overload_impl(inp, unbiased):
return prim(inp, dim=None, unbiased=unbiased)
nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
correction = torch._prims_common.set_correction(unbiased, correction)
# reduces over all dimensions if dim=() is passed
if dim == () or dim == []:
dim = None
dim = torch._prims_common.reduction_dims(a.shape, dim)
# For complex tensors eager computes the variance as the sum of variances of
# the real and imaginary parts
# TODO: Creating a complex tensor from real and imaginary parts is not supported
if torch._prims_common.is_complex_dtype(a.dtype):
raise NotImplementedError("Complex tensors are not supported")
var_mean = prim(a, dim, correction=correction)
if keepdim:
output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
broadcast_dims = [i for i in range(a.ndim) if i not in dim]
var, mean = var_mean
var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
mean = torch.ops.nvprims.broadcast_in_dim(
mean, output_shape, broadcast_dims
)
var_mean = (var, mean)
return var_mean
def _var_mean_autograd(
a, dim=None, unbiased=None, keepdim=False, *, correction=None
):
# This wrapper is needed to convert prims calls inside
# elementwise_type_promotion_wrapper to nvprims calls
from torch._prims.context import NvfuserPrimsMode
with NvfuserPrimsMode():
return backwards_not_supported(_var_mean_ref)(
a, dim, unbiased, keepdim, correction=correction
)
nvprim_autograd_impl.impl(name, _var_mean_autograd)
for p in (prim_packet, prim):
p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
p.impl_nvfuser = _nvfuser_impls["var_mean"]
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
def register_nvprims():
"""Registers all nvFuser primitives in the torch.ops.nvprims module."""
register_var_mean()
for name in nvprim_names:
main_prim = getattr(torch.ops.prims, name)

View File

@ -1033,6 +1033,19 @@ class REDUCTION_OUTPUT_TYPE_KIND(Enum):
ALWAYS_BOOL = (3,)
# Describes the return type of the primitive:
#
# - NEW, a new tensor is created
# - VIEW, a view of an input tensor is returned
# - INPLACE, one or more input tensors is modified
#
# these descriptors are mututally exclusive and exhaustive.
class RETURN_TYPE(Enum):
NEW = (0,)
VIEW = (1,)
INPLACE = (2,)
# TODO: document type promotion kinds
def elementwise_dtypes(
*_args,
@ -1348,6 +1361,23 @@ def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...
return dims
def set_correction(
unbiased: Optional[bool] = None,
correction: Optional[int] = None,
):
if correction is not None and unbiased is not None:
raise RuntimeError("cannot specify both correction and unbiased arguments")
elif correction is None and unbiased is None:
correction = 1
elif correction is None and unbiased is not None:
correction = 0 if unbiased is False else 1
if not isinstance(correction, int):
raise ValueError("correction argument should be integer")
if correction < 0:
raise ValueError("correction argument should be non-negative")
return correction
def check_in_bounds_for_storage(
a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
):

View File

@ -118,9 +118,11 @@ class elementwise_type_promotion_wrapper(object):
result = fn(**bound.arguments)
# FIXME?: assumes result is a single tensor
assert isinstance(result, TensorLike)
return _maybe_convert_to_dtype(result, result_dtype)
if isinstance(result, TensorLike):
return _maybe_convert_to_dtype(result, result_dtype)
if isinstance(result, Sequence):
return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
raise AssertionError(f"Unhandled result type: {type(result)}")
_fn.__signature__ = sig # type: ignore[attr-defined]
return _fn

View File

@ -1897,21 +1897,14 @@ def amax(
)
def _set_correction(
unbiased: Optional[bool] = None,
correction: Optional[int] = None,
):
if correction is not None and unbiased is not None:
raise RuntimeError("cannot specify both correction and unbiased arguments")
elif correction is None and unbiased is None:
correction = 1
elif correction is None and unbiased is not None:
correction = 0 if unbiased is False else 1
if not isinstance(correction, int):
raise ValueError("correction argument should be integer")
if correction < 0:
raise ValueError("correction argument should be non-negative")
return correction
def _dim_var_dispatch(dim=None, unbiased=None):
# There's the following overload of torch.var:
# var(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
# We need to explicitly convert bool dims to unbiased arg
if unbiased is None and isinstance(dim, bool):
unbiased = dim
dim = None
return dim, unbiased
@out_wrapper()
@ -1923,7 +1916,8 @@ def var(
*,
correction: Optional[int] = None,
) -> TensorLikeType:
correction = _set_correction(unbiased, correction)
dim, unbiased = _dim_var_dispatch(dim, unbiased)
correction = utils.set_correction(unbiased, correction)
# reduces over all dimensions if dim=() is passed
if dim == () or dim == []:
dim = None
@ -1950,7 +1944,8 @@ def std(
*,
correction: Optional[int] = None,
) -> TensorLikeType:
correction = _set_correction(unbiased, correction)
dim, unbiased = _dim_var_dispatch(dim, unbiased)
correction = utils.set_correction(unbiased, correction)
# reduces over all dimensions if dim=() is passed
if dim == () or dim == []:
dim = None
@ -2024,6 +2019,7 @@ def std_mean(
keepdim: bool = False,
correction: Optional[int] = None,
):
dim, unbiased = _dim_var_dispatch(dim, unbiased)
s = std(a, dim, unbiased, keepdim, correction=correction)
m = mean(a, dim, keepdim)
return s, m
@ -2038,6 +2034,7 @@ def var_mean(
*,
correction: Optional[int] = None,
):
dim, unbiased = _dim_var_dispatch(dim, unbiased)
v = var(a, dim, unbiased, keepdim, correction=correction)
m = mean(a, dim, keepdim)
return v, m
@ -2451,7 +2448,9 @@ def _normalize(
computation_dtype = utils.get_computation_dtype(a.dtype)
a_acc = _maybe_convert_to_dtype(a, computation_dtype)
assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean
biased_var, mean = var_mean(a_acc, dim=norm_dims, unbiased=False, keepdim=True)
biased_var, mean = torch.var_mean(
a_acc, dim=norm_dims, unbiased=False, keepdim=True
)
rstd = torch.rsqrt(biased_var + eps)
out = (a - mean) * rstd
return out, mean, rstd

View File

@ -69,6 +69,64 @@ TensorView* variance(
return y;
}
TORCH_CUDA_CU_API VarMeanResult variance_mean(
TensorView* x,
const std::vector<int>& dims,
int64_t correction,
bool keepdim) {
TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
TORCH_CHECK(
correction >= 0, "correction must be non-negative, but got ", correction);
// There are compilation errors for half precision
auto dtype = x->getDataType().value();
TORCH_CHECK(
!(dtype == DataType::Half || dtype == DataType::BFloat16),
"variance_mean is not supported for ",
dtype,
" please upcast to float");
if (isComplexType(x->getDataType().value())) {
// There are compilation errors:
// __tmp_kernel1.cu(6727): error: namespace "CudaCodeGen::std" has no member
// "imagf"
// __tmp_kernel1.cu(6753): error: namespace "CudaCodeGen::std" has no member
// "realf"
TORCH_CHECK(false, "var_mean is not supported for complex types.");
auto out_real = variance_mean(real(x), dims, correction, keepdim);
auto out_imag = variance_mean(imag(x), dims, correction, keepdim);
// variance of a complex tensor is the sum of real and imaginary variances
// and is real mean of a complex tensor is complex complex(out_real.mean,
// out_imag.mean) It seems construction of a complex tensor from two real
// tensors is not supported yet
return {add(out_real.var, out_imag.var), nullptr};
}
const int kNumberOfDims =
TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();
auto num_features = numFeatures(x, dims, kNumberOfDims);
if (correction > 0) {
num_features =
sub(num_features, IrBuilder::create<Int>(x->container(), correction));
}
auto welford_out = Welford(x, dims);
auto mean = welford_out.avg;
auto var = mul(welford_out.var_sum, reciprocal(num_features));
if (keepdim) {
std::vector<bool> is_broadcast(kNumberOfDims, false);
for (auto dim : dims) {
is_broadcast[dim] = true;
}
var = broadcast(var, is_broadcast);
mean = broadcast(mean, is_broadcast);
}
return {var, mean};
}
TensorView* standard_deviation(
TensorView* x,
const std::vector<int>& dims,

View File

@ -38,6 +38,11 @@ struct BackwardRMSNormResult {
TensorView* grad_weight = nullptr;
};
struct VarMeanResult {
TensorView* var = nullptr;
TensorView* mean = nullptr;
};
TORCH_CUDA_CU_API TensorView* mean(
TensorView* x,
const std::vector<int>& dims,
@ -55,6 +60,12 @@ TORCH_CUDA_CU_API TensorView* variance(
int64_t correction,
bool keepdim);
TORCH_CUDA_CU_API VarMeanResult variance_mean(
TensorView* x,
const std::vector<int>& dims,
int64_t correction,
bool keepdim);
TORCH_CUDA_CU_API TensorView* standard_deviation(
TensorView* x,
const std::vector<int>& dims,

View File

@ -362,4 +362,34 @@ struct VarianceOpRecord : RecordFunctor {
bool keep_dim_;
};
struct VarianceMeanOpRecord : RecordFunctor {
VarianceMeanOpRecord(
std::vector<size_t> _args,
std::vector<size_t> _outputs,
std::vector<int>& dims,
int64_t correction,
bool keepdim)
: RecordFunctor(std::move(_args), std::move(_outputs)),
dims_(dims),
correction_(correction),
keepdim_(keepdim) {}
virtual ~VarianceMeanOpRecord() = default;
void operator()(FusionDefinition& fd) final {
auto arg = fd.getFusionState(args.at(0))->as<NvfTensorView>();
auto output = torch::jit::fuser::cuda::variance_mean(
arg, dims_, correction_, keepdim_);
fd.setFusionState(outputs.at(0), output.var);
fd.setFusionState(outputs.at(1), output.mean);
}
private:
//! Dimensions of tensor to reduce for variance calculation
std::vector<int> dims_;
//! Bessel's correction value
int64_t correction_;
//! Indicates whether to keep the reduced dimension(s).
bool keepdim_;
};
} // namespace nvfuser

View File

@ -12,6 +12,7 @@
#include <torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <iostream>
#include <tuple>
namespace torch {
namespace jit {
@ -931,6 +932,25 @@ void initNvFuserPythonBindings(PyObject* module) {
},
py::return_value_policy::reference);
nvf_ops.def(
"var_mean",
[](nvfuser::FusionDefinition::Operators& self,
nvfuser::Tensor* arg,
std::vector<int>& dims,
int64_t correction,
bool keepdim) -> decltype(auto) {
nvfuser::Tensor* var = self.fusion_definition->defineTensor();
nvfuser::Tensor* mean = self.fusion_definition->defineTensor();
self.fusion_definition->defineRecord(new nvfuser::VarianceMeanOpRecord(
{arg->index},
{var->index, mean->index},
dims,
correction,
keepdim));
return std::make_tuple(var, mean);
},
py::return_value_policy::reference);
nvf_ops.def(
"broadcast_in_dim",
[](nvfuser::FusionDefinition::Operators& self,

View File

@ -242,7 +242,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
// return an IValue() to denote a NoneType
return {};
}
return toIValue(obj, type->expectRef<OptionalType>().getElementType());
return toIValue(obj, type->expectRef<OptionalType>().getElementType(), N);
}
case TypeKind::ClassType: {
auto classType = type->expect<ClassType>();

View File

@ -4615,6 +4615,10 @@ def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs):
SampleInput(tensor_nd(), kwargs=dict(dim=(1,), correction=S // 2)),
SampleInput(tensor_nd(), kwargs=dict(dim=None, correction=0, keepdim=True)),
# Test var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
SampleInput(tensor_nd(), args=(True,)),
SampleInput(tensor_nd(), args=(False,)),
]
@ -16970,6 +16974,17 @@ python_ref_db = [
torch_opinfo_name="var_mean",
validate_view_consistency=False,
),
PythonRefInfo(
"ops.nvprims.var_mean",
torch_opinfo_name="var_mean",
validate_view_consistency=False,
# Complex types are currently disabled
dtypes=floating_types_and(torch.float16, torch.bfloat16),
# This function is expected not to work with TorchRefsMode(strict=True)
decorators=(
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
),
),
#
# Linear Algebra Operators
#