Update amax/amin/norm/count_nonzero signatures with int[*]? dim (#83300)

Changes the `dim` arg to use the `int[*]?` type for the following functions in `native_functions.yaml` (a brief call-site sketch follows the list):
* `amax`
* `amin`
* `norm`
* `frobenius_norm`
* `native_norm`
* `count_nonzero`
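
In practice, `int[*]?` makes `dim` genuinely optional instead of defaulting to an empty list. Below is a minimal, untested C++ call-site sketch; the exact generated ATen signatures (`at::OptionalIntArrayRef` with a `c10::nullopt` default) are an assumption based on the diffs below:

```cpp
#include <ATen/ATen.h>
#include <vector>

void sketch() {
  at::Tensor t = at::randn({3, 4});
  // Explicit dims still work as before.
  std::vector<int64_t> dims = {0};
  at::Tensor a = at::amax(t, dims, /*keepdim=*/false);
  // With the optional schema, the dim list can be omitted entirely
  // (c10::nullopt) to reduce over all dimensions.
  at::Tensor b = at::amax(t, c10::nullopt, /*keepdim=*/false);
}
```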

Part of #29137

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83300
Approved by: https://github.com/ngimel, https://github.com/albanD, https://github.com/kulinseth
Author: Kurt Mohler
Date: 2022-09-28 01:56:37 +00:00
Committed by: PyTorch MergeBot
Parent: 80b8886223
Commit: 1c0f0b33a0
18 changed files with 124 additions and 94 deletions

@ -1 +1 @@
76c65b13280cd5782ace8050df45564ef17891f9
ec3fd75b2d4bbcf62917d0818378aa3a7d0edcb7

@ -79,6 +79,17 @@ inline void maybe_wrap_dims(Container& dims, int64_t dim_post_expr) {
return maybe_wrap_dims_n(dims.data(), dims.size(), dim_post_expr);
}
inline void maybe_wrap_dims(
at::OptionalIntArrayRef opt_dims,
int64_t dim_post_expr) {
if (opt_dims.has_value()) {
at::IntArrayRef dims = opt_dims.value();
// TODO: This const_cast is probably not a good idea...
int64_t* dims_ptr = const_cast<int64_t*>(dims.data());
maybe_wrap_dims_n(dims_ptr, dims.size(), dim_post_expr);
}
}
// previously, size [0] tensors were the only possible empty tensors; thus, it
// wasn't possible to cat empty tensors unless all the other tensors were
// 1-dimensional, so we allowed these tensors to be "skipped" (both for wrap
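
A hedged usage sketch of the new `maybe_wrap_dims` overload above: when the optional holds a list, each entry is wrapped in place (written back through the `const_cast` pointer into the caller's storage), and a disengaged optional is a no-op.

```cpp
// Untested sketch; assumes the overload shown above is in scope.
std::vector<int64_t> dims = {-1, 0};
at::OptionalIntArrayRef opt_dims(dims);  // views the vector's storage
maybe_wrap_dims(opt_dims, /*dim_post_expr=*/3);
// dims now holds {2, 0}: -1 was wrapped to ndim - 1 = 2.

at::OptionalIntArrayRef none;                // disengaged
maybe_wrap_dims(none, /*dim_post_expr=*/3);  // no-op
```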

@ -412,7 +412,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
&ADD_NS(native_layer_norm)>::type::call)));
KERNEL(ADD_NS(group_norm), "group_norm", Tensor (const Tensor &, int64_t, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), fp32)
KERNEL(ADD_NS(frobenius_norm), "frobenius_norm", Tensor (const Tensor &), fp32)
KERNEL(ADD_NS(frobenius_norm), "frobenius_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32)
KERNEL(ADD_NS(frobenius_norm), "frobenius_norm.dim", Tensor (const Tensor &, OptionalIntArrayRef, bool), fp32)
KERNEL(ADD_NS(nuclear_norm), "nuclear_norm", Tensor (const Tensor &, bool), fp32)
KERNEL(ADD_NS(nuclear_norm), "nuclear_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32)
KERNEL(ADD_NS(cosine_similarity), "cosine_similarity", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32)
@ -461,7 +461,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
// norm does not implicitly promote, but be aware when adding new ops to this policy.
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, const Scalar&), Tensor (const Tensor &, const c10::optional<Scalar>&, ScalarType), fp32_append_dtype)
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool, ScalarType), fp32_append_dtype)
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, OptionalIntArrayRef, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, OptionalIntArrayRef, bool, ScalarType), fp32_append_dtype)
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool, ScalarType), fp32_append_dtype)
// promote
KERNEL(ADD_NS(addcdiv), "addcdiv", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&), promote)
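
For context, a conceptual sketch of what the `fp32` cast policy registered above amounts to. This is an illustration only, not the actual macro expansion (the real kernels also redispatch below the Autocast key and use a cached cast):

```cpp
// Conceptual only: under autocast, inputs are cast up to float32 so
// precision-sensitive reductions like frobenius_norm run in full precision.
at::Tensor frobenius_norm_fp32(const at::Tensor& self,
                               at::OptionalIntArrayRef dim,
                               bool keepdim) {
  return at::frobenius_norm(self.to(at::kFloat), dim, keepdim);
}
```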

@ -2672,19 +2672,19 @@ Tensor frobenius_norm(const Tensor& self) {
return at::norm(self);
}
Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
Tensor frobenius_norm(const Tensor& self, OptionalIntArrayRef dim, bool keepdim) {
TORCH_CHECK(
dim.size() <= 2,
!dim.has_value() || dim.value().size() <= 2,
"Expected at most 2 dimensions, but got ",
dim.size(),
dim.value().size(),
" dimensions instead.");
Tensor result;
if (dim.size() == 1 || dim.size() == 0) {
if (!dim.has_value() || dim.value().size() == 1 || dim.value().size() == 0) {
result = at::norm(self, 2, dim, keepdim);
} else {
auto dim_ = dim.vec();
auto dim_ = dim.value().vec();
maybe_wrap_dims(dim_, self.dim());
TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim.value(), " instead");
if (self.is_complex()) {
result = at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim));
} else {
@ -2697,7 +2697,7 @@ Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
}
Tensor &frobenius_norm_out(const Tensor& self,
IntArrayRef dim,
OptionalIntArrayRef dim,
bool keepdim,
Tensor& result) {
auto result_ = at::native::frobenius_norm(self, dim, keepdim);
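
A hedged call-site sketch of the updated `frobenius_norm`: an absent `dim` (or a 0- or 1-element list) lowers to `at::norm(self, 2, dim, keepdim)`, while exactly two distinct dims take the matrix branch.

```cpp
// Untested sketch assuming the post-change signature.
at::Tensor x = at::randn({3, 4});
at::Tensor vec_path = at::frobenius_norm(x, c10::nullopt, /*keepdim=*/false);
std::vector<int64_t> d = {0, 1};
at::Tensor mat_path = at::frobenius_norm(x, d, /*keepdim=*/false);
```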

@ -235,7 +235,7 @@ ScalarType get_result_or_self_value_dtype(
}
TORCH_META_FUNC2(norm, ScalarOpt_dim)
(const Tensor& self, const OptionalScalarRef p, IntArrayRef dim, bool keepdim) {
(const Tensor& self, const OptionalScalarRef p, at::OptionalIntArrayRef dim, bool keepdim) {
TORCH_CHECK(
at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
"norm(): input dtype should be either floating point or complex. "
@ -248,7 +248,7 @@ TORCH_META_FUNC2(norm, ScalarOpt_dim)
TORCH_META_FUNC2(norm, ScalarOpt_dim_dtype)
(const Tensor& self,
const OptionalScalarRef p,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim,
ScalarType dtype) {
TORCH_CHECK(
@ -282,7 +282,7 @@ TORCH_META_FUNC(aminmax)
}
TORCH_META_FUNC(amax)
(const Tensor& self, IntArrayRef dim, bool keepdim) {
(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim) {
auto maybe_result = maybe_get_output();
if (maybe_result.defined()) {
TORCH_CHECK(self.scalar_type() == maybe_result.scalar_type(), "Expected the dtype for input and out to match, but got ",
@ -296,7 +296,7 @@ TORCH_META_FUNC(amax)
}
TORCH_META_FUNC(amin)
(const Tensor& self, IntArrayRef dim, bool keepdim) {
(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim) {
auto maybe_result = maybe_get_output();
if (maybe_result.defined()) {
TORCH_CHECK(self.scalar_type() == maybe_result.scalar_type(), "Expected the dtype for input and out to match, but got ",
@ -1359,7 +1359,7 @@ Tensor& special_logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim
void impl_func_norm(
const Tensor& self,
const OptionalScalarRef& opt_p,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim,
optional<ScalarType> opt_dtype,
const Tensor& result) {
@ -1396,7 +1396,7 @@ void impl_func_norm(
TORCH_IMPL_FUNC(norm_out)
(const Tensor& self,
const OptionalScalarRef p,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim,
const Tensor& result) {
impl_func_norm(self, p, dim, keepdim, c10::nullopt, result);
@ -1405,7 +1405,7 @@ TORCH_IMPL_FUNC(norm_out)
TORCH_IMPL_FUNC(norm_dtype_out)
(const Tensor& self,
const OptionalScalarRef p,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim,
ScalarType dtype,
const Tensor& result) {
@ -1415,7 +1415,7 @@ TORCH_IMPL_FUNC(norm_dtype_out)
Tensor sparse_norm(
const Tensor& self,
const optional<Scalar>& p,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim) {
return at::native_norm(self, p, dim, keepdim, c10::nullopt);
}
@ -1423,7 +1423,7 @@ Tensor sparse_norm(
Tensor sparse_dtype_norm(
const Tensor& self,
const optional<Scalar>& p,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim,
ScalarType dtype) {
return at::native_norm(self, p, dim, keepdim, dtype);
@ -1488,7 +1488,7 @@ TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
allany_impl<0>(self, result, {}, false, or_stub);
}
TORCH_IMPL_FUNC(amin_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
TORCH_IMPL_FUNC(amin_out) (const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
auto iter =
meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
if (iter.numel() != 0) {
@ -1496,7 +1496,7 @@ TORCH_IMPL_FUNC(amin_out) (const Tensor& self, IntArrayRef dim, bool keepdim, co
}
}
TORCH_IMPL_FUNC(amax_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
TORCH_IMPL_FUNC(amax_out) (const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
auto iter =
meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
if (iter.numel() != 0) {

@ -275,7 +275,9 @@ static void zero_numel_check_dims(const Tensor& self, const int64_t dim, const c
}
}
static void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
static void zero_numel_check_dims(const Tensor& self, const at::OptionalIntArrayRef opt_dim, const char *fn_name) {
if (opt_dim.has_value()) {
const IntArrayRef dim = opt_dim.value();
TORCH_CHECK(
!dim.empty(),
fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
@ -284,6 +286,7 @@ static void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, con
zero_numel_check_dims(self, d, fn_name);
}
}
}
static std::vector<int64_t> get_zero_numel_tensor_size(
const Tensor& self,

@ -1956,12 +1956,12 @@ int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) {
return num_nonzero;
}
Tensor count_nonzero_cuda(const Tensor& self, IntArrayRef dims){
Tensor count_nonzero_cuda(const Tensor& self, OptionalIntArrayRef dims){
return (self != 0).sum(dims);
}
Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){
if (dims.size() > 0) {
Tensor count_nonzero_cpu(const Tensor& self, OptionalIntArrayRef dims){
if (dims.has_value() && dims.value().size() > 0) {
return (self != 0).sum(dims);
}
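
A hedged sketch of the dispatch logic above: with no `dims` (or an empty list) the count runs over the whole tensor and returns a 0-dim result; otherwise it reduces over the listed dims.

```cpp
// Untested sketch mirroring count_nonzero_cpu above.
at::Tensor count_nonzero_sketch(const at::Tensor& self,
                                at::OptionalIntArrayRef dims) {
  if (dims.has_value() && !dims.value().empty()) {
    return (self != 0).sum(dims.value());  // reduce over requested dims
  }
  return (self != 0).sum();  // count over the whole tensor
}
```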

@ -315,7 +315,7 @@ inline ScalarType get_dtype_from_self(
TORCH_IMPL_FUNC(amax_out_mps)
(const Tensor& input_t,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim,
const Tensor& output_t) {
@ -324,7 +324,7 @@ TORCH_IMPL_FUNC(amax_out_mps)
TORCH_IMPL_FUNC(amin_out_mps)
(const Tensor& input_t,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim,
const Tensor& output_t) {
@ -354,7 +354,7 @@ Tensor prod_mps(const Tensor &self, c10::optional<ScalarType> opt_dtype) {
}
Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){
Tensor count_nonzero_mps(const Tensor& self, OptionalIntArrayRef dims){
NSMutableArray<NSNumber*> *axes = nil;
NSMutableArray<NSNumber*> *apparent_input_shape = nil;
NSMutableArray<NSNumber*> *apparent_output_shape = nil;
@ -395,7 +395,7 @@ TORCH_IMPL_FUNC(mean_out_mps)
TORCH_IMPL_FUNC(norm_out_mps)
(const Tensor& input_tensor,
const OptionalScalarRef opt_p,
IntArrayRef dim,
OptionalIntArrayRef opt_dim,
bool keepdim,
const Tensor& output_t)
{
@ -406,11 +406,14 @@ TORCH_IMPL_FUNC(norm_out_mps)
IntArrayRef input_shape = input_t.sizes();
if (opt_dim.has_value()) {
IntArrayRef dim = opt_dim.value();
for(int i = 0; i < dim.size(); i++) {
auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
TORCH_CHECK(wrap_dim < input_shape.size(),
"norm_out_mps: reduction dim must be in the range of input shape")
}
}
namespace native_mps = at::native::mps;
using CachedGraph = native_mps::MPSUnaryCachedGraph;
@ -424,7 +427,7 @@ TORCH_IMPL_FUNC(norm_out_mps)
bool pIsNegInf = (p == -numeric_limits<double>::infinity());
int64_t num_input_dims = input_shape.size();
int64_t num_reduce_dims = dim.size();
int64_t num_reduce_dims = opt_dim.has_value() ? opt_dim.value().size() : 0;
int64_t num_output_dims;
// For output shape calculation, assume that keepdim is true
@ -434,7 +437,7 @@ TORCH_IMPL_FUNC(norm_out_mps)
// Reduction axes
NSMutableArray<NSNumber *> *axes;
set_axes(axes, num_reduce_dims, dim, input_shape.size());
set_axes(axes, num_reduce_dims, opt_dim, input_shape.size());
set_apparent_shapes(apparent_output_shape,
apparent_input_shape,

@ -1574,7 +1574,7 @@
- func: cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
- func: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
- func: count_nonzero.dim_IntList(Tensor self, int[]? dim) -> Tensor
variants: function, method
dispatch:
CPU: count_nonzero_cpu
@ -3311,11 +3311,11 @@
device_check: NoCheck
device_guard: False
- func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
- func: amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
variants: function, method
structured_delegate: amax.out
- func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
- func: amax.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA: amax_out
@ -3485,11 +3485,11 @@
- func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
- func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
- func: amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
variants: function, method
structured_delegate: amin.out
- func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
- func: amin.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA: amin_out
@ -5671,7 +5671,7 @@
SparseCPU, SparseCUDA: norm_sparse
autogen: native_norm.out
- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor
- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1]? dim, bool keepdim, ScalarType? dtype) -> Tensor
dispatch:
SparseCPU, SparseCUDA: norm_sparse
autogen: native_norm.ScalarOpt_dim_dtype_out
@ -5768,27 +5768,27 @@
CompositeExplicitAutograd: norm
autogen: norm.Scalar_out
- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1]? dim, bool keepdim, *, ScalarType dtype) -> Tensor
structured_delegate: norm.dtype_out
device_check: NoCheck # TensorIterator
variants: function, method
dispatch:
SparseCPU, SparseCUDA: sparse_dtype_norm
- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1]? dim, bool keepdim=False) -> Tensor
structured_delegate: norm.out
device_check: NoCheck # TensorIterator
variants: function, method
dispatch:
SparseCPU, SparseCUDA: sparse_norm
- func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
- func: norm.dtype_out(Tensor self, Scalar? p, int[1]? dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
structured: True
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: norm_dtype_out
- func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
- func: norm.out(Tensor self, Scalar? p, int[1]? dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
device_check: NoCheck # TensorIterator
dispatch:
@ -5824,11 +5824,11 @@
variants: function
# Deprecated (v.1.12)
- func: frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
- func: frobenius_norm.dim(Tensor self, int[1]? dim, bool keepdim=False) -> Tensor
variants: function
# Deprecated (v.1.12)
- func: frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
- func: frobenius_norm.out(Tensor self, int[1]? dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
# Deprecated (v.1.12)
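
For reference, the schema-to-C++ mapping at work in this file is that `int[1]? dim=None` binds as `at::OptionalIntArrayRef` with a `c10::nullopt` default. An assumed sketch of the generated declaration (not copied from the codegen output):

```cpp
namespace at {
TORCH_API Tensor amax(const Tensor& self,
                      OptionalIntArrayRef dim = c10::nullopt,
                      bool keepdim = false);
}  // namespace at
```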

@ -372,14 +372,14 @@ Tensor norm_sparse(const SparseTensor& self, const Scalar& p) {
return norm_sparse(self, p, IntArrayRef{}, false, c10::nullopt);
}
Tensor norm_sparse(const SparseTensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype) {
Tensor norm_sparse(const SparseTensor& self, const optional<Scalar>& p, OptionalIntArrayRef dim, bool keepdim, optional<ScalarType> dtype) {
AT_ASSERT(self.is_sparse());
if (dim.size() > 0) {
if (dim.has_value() && dim.value().size() > 0) {
// Only full reductions are supported, so check if that is the case
int64_t ndim = self.dim();
bool passed_full_reduction_check = static_cast<size_t>(ndim) == dim.size();
bool passed_full_reduction_check = static_cast<size_t>(ndim) == dim.value().size();
if (passed_full_reduction_check) {
auto dim_ = dim.vec();
auto dim_ = dim.value().vec();
maybe_wrap_dims(dim_, ndim);
std::vector<bool> dims_check(ndim, false);
// Need to check for duplicates, and fail if any are found

@ -93,6 +93,10 @@ ALLOW_LIST = [
("aten::_linalg_inv_out_helper.out", datetime.date(2022, 10, 1)),
("aten::_linalg_inv_out_helper_", datetime.date(2022, 10, 1)),
("aten::_linalg_inv_out_helper", datetime.date(2022, 10, 1)),
("aten::amax", datetime.date(2022, 10, 25)),
("aten::amax.out", datetime.date(2022, 10, 25)),
("aten::amin", datetime.date(2022, 10, 25)),
("aten::amin.out", datetime.date(2022, 10, 25)),
("aten::solve", datetime.date(9999, 1, 1)),
("aten::solve.solution", datetime.date(9999, 1, 1)),
("aten::_solve_helper", datetime.date(9999, 1, 1)),

@ -1190,6 +1190,8 @@ class TestNamedTensor(TestCase):
'var_mean',
'nanmean',
'nansum',
'amax',
'amin',
]
if op.__name__ in ops_support_dim_none:
check_output(op(t, None), [])

@ -465,9 +465,10 @@
self: grad * self.sinh().conj()
result: auto_element_wise
- name: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
- name: count_nonzero.dim_IntList(Tensor self, int[]? dim) -> Tensor
output_differentiability: [False]
# TODO: Should probably remove this overload
- name: count_nonzero(Tensor self, int? dim=None) -> Tensor
output_differentiability: [False]
@ -1105,11 +1106,11 @@
other: grad.masked_fill((self <= other).logical_or_(other.isnan()), 0)
result: other_t + (self_p <= other_p).logical_or_(other_p.isnan()) * (self_t - other_t)
- name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
- name: amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)
result: amaxamin_jvp(self_p, self_t, result, dim, keepdim)
- name: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
- name: amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)
result: amaxamin_jvp(self_p, self_t, result, dim, keepdim)
@ -1190,7 +1191,7 @@
self: norm_backward(grad, self, p, result)
result: norm_jvp(self_p, self_t, p, result)
- name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
- name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1]? dim, bool keepdim=False) -> Tensor
self: norm_backward(grad, self, p, result, dim, keepdim)
result: norm_jvp(self_p, self_t, p, result, dim, keepdim)
@ -1198,7 +1199,7 @@
self: norm_backward(grad, self.to(grad.scalar_type()), p, result)
result: norm_jvp(self_p, self_t, p, result)
- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1]? dim, bool keepdim, *, ScalarType dtype) -> Tensor
self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim)
result: norm_jvp(self_p, self_t, p, result, dim, keepdim)

@ -6698,7 +6698,7 @@ dimension(s) :attr:`dim`.
Args:
{input}
{dim}
{opt_dim}
{keepdim}
Keyword args:
@ -7320,7 +7320,7 @@ dimension(s) :attr:`dim`.
Args:
{input}
{dim}
{opt_dim}
{keepdim}
Keyword args:

@ -160,13 +160,21 @@ Tensor handle_r_to_c(Tensor self, Tensor gradient_result) {
Tensor restore_reduced_dims(
const Tensor& output,
IntArrayRef dims,
at::OptionalIntArrayRef opt_dims,
bool keepdim) {
if (keepdim) {
return output;
}
int64_t total_dims = output.dim() + dims.size();
int64_t total_dims;
if (opt_dims.has_value()) {
total_dims = output.dim() + opt_dims.value().size();
} else {
total_dims = 0;
}
std::vector<int64_t> target_shape(total_dims, 0);
if (opt_dims.has_value()) {
IntArrayRef dims = opt_dims.value();
for (int64_t i : dims) {
if (i < 0) {
i = total_dims + i;
@ -179,13 +187,14 @@ Tensor restore_reduced_dims(
j++;
target_shape[j++] = i;
}
}
return output.reshape(target_shape);
}
Tensor scale_grad_by_count(
const Tensor& grad,
const Tensor& mask,
IntArrayRef dims) {
at::OptionalIntArrayRef dims) {
return (grad / mask.sum(dims, true)) * mask;
}
@ -193,7 +202,7 @@ Tensor amaxamin_jvp(
const Tensor& x,
const Tensor& dx,
const Tensor& result,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim) {
auto mask = x == restore_reduced_dims(result, dim, keepdim);
return at::where(mask, dx, 0.).sum(dim, keepdim) / mask.sum(dim, keepdim);
@ -228,7 +237,7 @@ Tensor norm_backward(
const Tensor& self,
const optional<Scalar>& p_,
Tensor norm,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim) {
// NB: We mask fill the NaNs in the output to be zero but still do float
// division
@ -282,7 +291,7 @@ Tensor norm_jvp(
const Tensor& self_t,
const optional<Scalar>& p_,
Tensor norm,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim) {
// NB: currently norm_jvp is also reused for dist's jvp (which has two
// differentiable inputs)
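
A hedged sketch of how the backward helpers above consume the optional dim list: `restore_reduced_dims` re-inserts each reduced dim as size 1 so the gradient broadcasts against the input, and a disengaged list leaves the fully reduced 0-dim output as-is (it already broadcasts against everything).

```cpp
// Untested sketch under the post-change signatures.
at::Tensor x = at::randn({3, 5});
std::vector<int64_t> dim = {1};
at::Tensor r = at::amax(x, dim, /*keepdim=*/false);  // shape [3]
// In the backward formulas, restore_reduced_dims(r, dim, /*keepdim=*/false)
// reshapes r to [3, 1] so it broadcasts against x; with dim = c10::nullopt,
// the 0-dim result of a full reduction is returned unchanged.
```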

@ -58,12 +58,12 @@ at::Tensor maybe_multiply(const at::Tensor& t, const at::Scalar& s);
int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim);
Tensor restore_reduced_dims(
const Tensor& output,
IntArrayRef dims,
at::OptionalIntArrayRef dims,
bool keepdim);
Tensor scale_grad_by_count(
const Tensor& grad,
const Tensor& mask,
IntArrayRef dims);
at::OptionalIntArrayRef dims);
at::Tensor norm_backward(
const at::Tensor& grad,
const at::Tensor& self,
@ -74,14 +74,14 @@ at::Tensor norm_backward(
const at::Tensor& self,
const optional<at::Scalar>& p_,
at::Tensor norm,
at::IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim);
Tensor norm_jvp(
const Tensor& self_p,
const Tensor& self_t,
const optional<Scalar>& p_,
Tensor norm,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim);
Tensor norm_jvp(
const Tensor& grad,
@ -754,7 +754,7 @@ Tensor amaxamin_jvp(
const Tensor& x,
const Tensor& dx,
const Tensor& result,
IntArrayRef dim,
at::OptionalIntArrayRef dim,
bool keepdim);
std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
const Tensor& input,

@ -3043,8 +3043,8 @@ class IrParser {
{
std::array<const char*, kNumAminAmaxOps> BinaryFloatOp = {
"aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor",
"aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"};
"aten::amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor",
"aten::amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor"};
for (auto signature : BinaryFloatOp) {
auto ptr_op = getOperatorForLiteral(signature);
REGISTER_PARSE_RULE(
@ -4008,11 +4008,11 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
static auto amax_schema =
getOperatorForLiteral(
"aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor")
"aten::amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor")
->schema();
static auto amin_schema =
getOperatorForLiteral(
"aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor")
"aten::amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor")
->schema();
if (node->matches(amax_schema) || node->matches(amin_schema)) {
switch (offset) {

@ -2502,11 +2502,8 @@ def error_inputs_aminmax_amax_amin(op_info, device, **kwargs):
# Error Inputs for zero-dim tensors, when 'dim' arg is not provided.
shape = (S, 0, S)
err_msg_amax_amin = "reduction"
err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
elif op_info.name in ['aminmax']:
if op_info.name in ['aminmax']:
yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)
# Error Inputs for tensors with more than 64 dimensions