Update amax/amin/norm/count_nonzero signatures with int[*]? dim (#83300)
Changes the `dim` arg to use the `int[*]?` type for the following functions in `native_functions.yaml`:

* `amax`
* `amin`
* `norm`
* `frobenius_norm`
* `native_norm`
* `count_nonzero`

Part of #29137

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83300
Approved by: https://github.com/ngimel, https://github.com/albanD, https://github.com/kulinseth
Committed by: PyTorch MergeBot
Commit: 1c0f0b33a0
Parent: 80b8886223
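For context on what the optional `int[1]?`/`int[*]?` schemas mean at the Python level, the short sketch below (an illustration, not part of the diff) exercises `amax`/`amin` with `dim` omitted, passed as `None`, and passed as an explicit list; it assumes a build that already includes this change.

import torch

t = torch.randn(3, 4)

# With the optional-dim schema, omitting dim or passing dim=None both
# reduce over every dimension of the input.
full = torch.amax(t)
also_full = torch.amax(t, dim=None)
assert torch.equal(full, also_full)

# An explicit list of dimensions behaves exactly as before.
per_row = torch.amin(t, dim=[1])
print(full.item(), per_row.shape)  # scalar max, torch.Size([3])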
.github/ci_commit_pins/xla.txt (vendored, 2 lines changed)
@@ -1 +1 @@
-76c65b13280cd5782ace8050df45564ef17891f9
+ec3fd75b2d4bbcf62917d0818378aa3a7d0edcb7
@@ -79,6 +79,17 @@ inline void maybe_wrap_dims(Container& dims, int64_t dim_post_expr) {
   return maybe_wrap_dims_n(dims.data(), dims.size(), dim_post_expr);
 }
 
+inline void maybe_wrap_dims(
+    at::OptionalIntArrayRef opt_dims,
+    int64_t dim_post_expr) {
+  if (opt_dims.has_value()) {
+    at::IntArrayRef dims = opt_dims.value();
+    // TODO: This const_cast is probably not a good idea...
+    int64_t* dims_ptr = const_cast<int64_t*>(dims.data());
+    maybe_wrap_dims_n(dims_ptr, dims.size(), dim_post_expr);
+  }
+}
+
 // previously, size [0] tensors were the only possible empty tensors; thus, it
 // wasn't possible to cat empty tensors unless all the other tensors were
 // 1-dimensional, so we allowed these tensors to be "skipped" (both for wrap
@@ -412,7 +412,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
       &ADD_NS(native_layer_norm)>::type::call)));
   KERNEL(ADD_NS(group_norm), "group_norm", Tensor (const Tensor &, int64_t, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), fp32)
   KERNEL(ADD_NS(frobenius_norm), "frobenius_norm", Tensor (const Tensor &), fp32)
-  KERNEL(ADD_NS(frobenius_norm), "frobenius_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32)
+  KERNEL(ADD_NS(frobenius_norm), "frobenius_norm.dim", Tensor (const Tensor &, OptionalIntArrayRef, bool), fp32)
   KERNEL(ADD_NS(nuclear_norm), "nuclear_norm", Tensor (const Tensor &, bool), fp32)
   KERNEL(ADD_NS(nuclear_norm), "nuclear_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32)
   KERNEL(ADD_NS(cosine_similarity), "cosine_similarity", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32)
@@ -461,7 +461,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   // The fp32_append_dtype wrapper overrides implicit promotion behavior.
   // norm does not implicitly promote, but be aware when adding new ops to this policy.
   KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, const Scalar&), Tensor (const Tensor &, const c10::optional<Scalar>&, ScalarType), fp32_append_dtype)
-  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool, ScalarType), fp32_append_dtype)
+  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, OptionalIntArrayRef, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, OptionalIntArrayRef, bool, ScalarType), fp32_append_dtype)
   KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool, ScalarType), fp32_append_dtype)
   // promote
   KERNEL(ADD_NS(addcdiv), "addcdiv", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&), promote)
@@ -2672,19 +2672,19 @@ Tensor frobenius_norm(const Tensor& self) {
   return at::norm(self);
 }
 
-Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
+Tensor frobenius_norm(const Tensor& self, OptionalIntArrayRef dim, bool keepdim) {
   TORCH_CHECK(
-      dim.size() <= 2,
+      !dim.has_value() || dim.value().size() <= 2,
       "Expected at most 2 dimensions, but got ",
-      dim.size(),
+      dim.value().size(),
       " dimensions instead.");
   Tensor result;
-  if (dim.size() == 1 || dim.size() == 0) {
+  if (!dim.has_value() || dim.value().size() == 1 || dim.value().size() == 0) {
     result = at::norm(self, 2, dim, keepdim);
   } else {
-    auto dim_ = dim.vec();
+    auto dim_ = dim.value().vec();
     maybe_wrap_dims(dim_, self.dim());
-    TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
+    TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim.value(), " instead");
     if (self.is_complex()) {
       result = at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim));
     } else {
@@ -2697,7 +2697,7 @@ Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
 }
 
 Tensor &frobenius_norm_out(const Tensor& self,
-    IntArrayRef dim,
+    OptionalIntArrayRef dim,
     bool keepdim,
     Tensor& result) {
   auto result_ = at::native::frobenius_norm(self, dim, keepdim);
@@ -235,7 +235,7 @@ ScalarType get_result_or_self_value_dtype(
 }
 
 TORCH_META_FUNC2(norm, ScalarOpt_dim)
-(const Tensor& self, const OptionalScalarRef p, IntArrayRef dim, bool keepdim) {
+(const Tensor& self, const OptionalScalarRef p, at::OptionalIntArrayRef dim, bool keepdim) {
   TORCH_CHECK(
       at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
       "norm(): input dtype should be either floating point or complex. "
@@ -248,7 +248,7 @@ TORCH_META_FUNC2(norm, ScalarOpt_dim)
 TORCH_META_FUNC2(norm, ScalarOpt_dim_dtype)
 (const Tensor& self,
  const OptionalScalarRef p,
- IntArrayRef dim,
+ at::OptionalIntArrayRef dim,
  bool keepdim,
  ScalarType dtype) {
   TORCH_CHECK(
@@ -282,7 +282,7 @@ TORCH_META_FUNC(aminmax)
 }
 
 TORCH_META_FUNC(amax)
-(const Tensor& self, IntArrayRef dim, bool keepdim) {
+(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim) {
   auto maybe_result = maybe_get_output();
   if (maybe_result.defined()) {
     TORCH_CHECK(self.scalar_type() == maybe_result.scalar_type(), "Expected the dtype for input and out to match, but got ",
@@ -296,7 +296,7 @@ TORCH_META_FUNC(amax)
 }
 
 TORCH_META_FUNC(amin)
-(const Tensor& self, IntArrayRef dim, bool keepdim) {
+(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim) {
   auto maybe_result = maybe_get_output();
   if (maybe_result.defined()) {
     TORCH_CHECK(self.scalar_type() == maybe_result.scalar_type(), "Expected the dtype for input and out to match, but got ",
@@ -1359,7 +1359,7 @@ Tensor& special_logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim
 void impl_func_norm(
     const Tensor& self,
     const OptionalScalarRef& opt_p,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim,
     optional<ScalarType> opt_dtype,
     const Tensor& result) {
@@ -1396,7 +1396,7 @@ void impl_func_norm(
 TORCH_IMPL_FUNC(norm_out)
 (const Tensor& self,
  const OptionalScalarRef p,
- IntArrayRef dim,
+ at::OptionalIntArrayRef dim,
  bool keepdim,
  const Tensor& result) {
   impl_func_norm(self, p, dim, keepdim, c10::nullopt, result);
@@ -1405,7 +1405,7 @@ TORCH_IMPL_FUNC(norm_out)
 TORCH_IMPL_FUNC(norm_dtype_out)
 (const Tensor& self,
  const OptionalScalarRef p,
- IntArrayRef dim,
+ at::OptionalIntArrayRef dim,
  bool keepdim,
  ScalarType dtype,
  const Tensor& result) {
@@ -1415,7 +1415,7 @@ TORCH_IMPL_FUNC(norm_dtype_out)
 Tensor sparse_norm(
     const Tensor& self,
     const optional<Scalar>& p,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim) {
   return at::native_norm(self, p, dim, keepdim, c10::nullopt);
 }
@@ -1423,7 +1423,7 @@ Tensor sparse_norm(
 Tensor sparse_dtype_norm(
     const Tensor& self,
     const optional<Scalar>& p,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim,
     ScalarType dtype) {
   return at::native_norm(self, p, dim, keepdim, dtype);
@@ -1488,7 +1488,7 @@ TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
   allany_impl<0>(self, result, {}, false, or_stub);
 }
 
-TORCH_IMPL_FUNC(amin_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
+TORCH_IMPL_FUNC(amin_out) (const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
   auto iter =
       meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
   if (iter.numel() != 0) {
@@ -1496,7 +1496,7 @@ TORCH_IMPL_FUNC(amin_out) (const Tensor& self, IntArrayRef dim, bool keepdim, co
   }
 }
 
-TORCH_IMPL_FUNC(amax_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
+TORCH_IMPL_FUNC(amax_out) (const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
   auto iter =
       meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
   if (iter.numel() != 0) {
@@ -275,7 +275,9 @@ static void zero_numel_check_dims(const Tensor& self, const int64_t dim, const c
   }
 }
 
-static void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
+static void zero_numel_check_dims(const Tensor& self, const at::OptionalIntArrayRef opt_dim, const char *fn_name) {
+  if (opt_dim.has_value()) {
+    const IntArrayRef dim = opt_dim.value();
   TORCH_CHECK(
       !dim.empty(),
       fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
@@ -284,6 +286,7 @@ static void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, con
     zero_numel_check_dims(self, d, fn_name);
   }
 }
+}
 
 static std::vector<int64_t> get_zero_numel_tensor_size(
     const Tensor& self,
@@ -1956,12 +1956,12 @@ int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) {
   return num_nonzero;
 }
 
-Tensor count_nonzero_cuda(const Tensor& self, IntArrayRef dims){
+Tensor count_nonzero_cuda(const Tensor& self, OptionalIntArrayRef dims){
   return (self != 0).sum(dims);
 }
 
-Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){
-  if (dims.size() > 0) {
+Tensor count_nonzero_cpu(const Tensor& self, OptionalIntArrayRef dims){
+  if (dims.has_value() && dims.value().size() > 0) {
     return (self != 0).sum(dims);
   }
 
@@ -315,7 +315,7 @@ inline ScalarType get_dtype_from_self(
 
 TORCH_IMPL_FUNC(amax_out_mps)
 (const Tensor& input_t,
- IntArrayRef dim,
+ at::OptionalIntArrayRef dim,
  bool keepdim,
  const Tensor& output_t) {
 
@@ -324,7 +324,7 @@ TORCH_IMPL_FUNC(amax_out_mps)
 
 TORCH_IMPL_FUNC(amin_out_mps)
 (const Tensor& input_t,
- IntArrayRef dim,
+ at::OptionalIntArrayRef dim,
  bool keepdim,
  const Tensor& output_t) {
 
@@ -354,7 +354,7 @@ Tensor prod_mps(const Tensor &self, c10::optional<ScalarType> opt_dtype) {
 }
 
 
-Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){
+Tensor count_nonzero_mps(const Tensor& self, OptionalIntArrayRef dims){
   NSMutableArray<NSNumber*> *axes = nil;
   NSMutableArray<NSNumber*> *apparent_input_shape = nil;
   NSMutableArray<NSNumber*> *apparent_output_shape = nil;
@@ -395,7 +395,7 @@ TORCH_IMPL_FUNC(mean_out_mps)
 TORCH_IMPL_FUNC(norm_out_mps)
 (const Tensor& input_tensor,
  const OptionalScalarRef opt_p,
- IntArrayRef dim,
+ OptionalIntArrayRef opt_dim,
  bool keepdim,
  const Tensor& output_t)
 {
@@ -406,11 +406,14 @@ TORCH_IMPL_FUNC(norm_out_mps)
   IntArrayRef input_shape = input_t.sizes();
 
+  if (opt_dim.has_value()) {
+    IntArrayRef dim = opt_dim.value();
   for(int i = 0; i < dim.size(); i++) {
     auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
     TORCH_CHECK(wrap_dim < input_shape.size(),
     "norm_out_mps: reduction dim must be in the range of input shape")
   }
+  }
   namespace native_mps = at::native::mps;
 
   using CachedGraph = native_mps::MPSUnaryCachedGraph;
@@ -424,7 +427,7 @@ TORCH_IMPL_FUNC(norm_out_mps)
   bool pIsNegInf = (p == -numeric_limits<double>::infinity());
 
   int64_t num_input_dims = input_shape.size();
-  int64_t num_reduce_dims = dim.size();
+  int64_t num_reduce_dims = opt_dim.has_value() ? opt_dim.value().size() : 0;
   int64_t num_output_dims;
 
   // For output shape calculation, assume that keepdim is true
@@ -434,7 +437,7 @@ TORCH_IMPL_FUNC(norm_out_mps)
 
   // Reduction axes
   NSMutableArray<NSNumber *> *axes;
-  set_axes(axes, num_reduce_dims, dim, input_shape.size());
+  set_axes(axes, num_reduce_dims, opt_dim, input_shape.size());
 
   set_apparent_shapes(apparent_output_shape,
                       apparent_input_shape,
@@ -1574,7 +1574,7 @@
 
 - func: cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
 
-- func: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
+- func: count_nonzero.dim_IntList(Tensor self, int[]? dim) -> Tensor
   variants: function, method
   dispatch:
     CPU: count_nonzero_cpu
@@ -3311,11 +3311,11 @@
   device_check: NoCheck
   device_guard: False
 
-- func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+- func: amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
   variants: function, method
   structured_delegate: amax.out
 
-- func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: amax.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
     CPU, CUDA: amax_out
@@ -3485,11 +3485,11 @@
 - func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
   device_check: NoCheck # TensorIterator
 
-- func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+- func: amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
   variants: function, method
   structured_delegate: amin.out
 
-- func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: amin.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
     CPU, CUDA: amin_out
@@ -5671,7 +5671,7 @@
     SparseCPU, SparseCUDA: norm_sparse
   autogen: native_norm.out
 
-- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor
+- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1]? dim, bool keepdim, ScalarType? dtype) -> Tensor
   dispatch:
     SparseCPU, SparseCUDA: norm_sparse
   autogen: native_norm.ScalarOpt_dim_dtype_out
@@ -5768,27 +5768,27 @@
     CompositeExplicitAutograd: norm
   autogen: norm.Scalar_out
 
-- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1]? dim, bool keepdim, *, ScalarType dtype) -> Tensor
   structured_delegate: norm.dtype_out
   device_check: NoCheck # TensorIterator
   variants: function, method
   dispatch:
     SparseCPU, SparseCUDA: sparse_dtype_norm
 
-- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
+- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1]? dim, bool keepdim=False) -> Tensor
   structured_delegate: norm.out
   device_check: NoCheck # TensorIterator
   variants: function, method
   dispatch:
     SparseCPU, SparseCUDA: sparse_norm
 
-- func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
+- func: norm.dtype_out(Tensor self, Scalar? p, int[1]? dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
   structured: True
   device_check: NoCheck # TensorIterator
   dispatch:
     CPU, CUDA: norm_dtype_out
 
-- func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: norm.out(Tensor self, Scalar? p, int[1]? dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   device_check: NoCheck # TensorIterator
   dispatch:
@@ -5824,11 +5824,11 @@
   variants: function
 
 # Deprecated (v.1.12)
-- func: frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
+- func: frobenius_norm.dim(Tensor self, int[1]? dim, bool keepdim=False) -> Tensor
   variants: function
 
 # Deprecated (v.1.12)
-- func: frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: frobenius_norm.out(Tensor self, int[1]? dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
 
 # Deprecated (v.1.12)
@@ -372,14 +372,14 @@ Tensor norm_sparse(const SparseTensor& self, const Scalar& p) {
   return norm_sparse(self, p, IntArrayRef{}, false, c10::nullopt);
 }
 
-Tensor norm_sparse(const SparseTensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype) {
+Tensor norm_sparse(const SparseTensor& self, const optional<Scalar>& p, OptionalIntArrayRef dim, bool keepdim, optional<ScalarType> dtype) {
   AT_ASSERT(self.is_sparse());
-  if (dim.size() > 0) {
+  if (dim.has_value() && dim.value().size() > 0) {
     // Only full reductions are supported, so check if that is the case
     int64_t ndim = self.dim();
-    bool passed_full_reduction_check = static_cast<size_t>(ndim) == dim.size();
+    bool passed_full_reduction_check = static_cast<size_t>(ndim) == dim.value().size();
     if (passed_full_reduction_check) {
-      auto dim_ = dim.vec();
+      auto dim_ = dim.value().vec();
       maybe_wrap_dims(dim_, ndim);
       std::vector<bool> dims_check(ndim, false);
       // Need to check for duplicates, and fail if any are found
@@ -93,6 +93,10 @@ ALLOW_LIST = [
     ("aten::_linalg_inv_out_helper.out", datetime.date(2022, 10, 1)),
    ("aten::_linalg_inv_out_helper_", datetime.date(2022, 10, 1)),
     ("aten::_linalg_inv_out_helper", datetime.date(2022, 10, 1)),
+    ("aten::amax", datetime.date(2022, 10, 25)),
+    ("aten::amax.out", datetime.date(2022, 10, 25)),
+    ("aten::amin", datetime.date(2022, 10, 25)),
+    ("aten::amin.out", datetime.date(2022, 10, 25)),
     ("aten::solve", datetime.date(9999, 1, 1)),
     ("aten::solve.solution", datetime.date(9999, 1, 1)),
     ("aten::_solve_helper", datetime.date(9999, 1, 1)),
@@ -1190,6 +1190,8 @@ class TestNamedTensor(TestCase):
             'var_mean',
             'nanmean',
             'nansum',
+            'amax',
+            'amin',
         ]
         if op.__name__ in ops_support_dim_none:
             check_output(op(t, None), [])
@@ -465,9 +465,10 @@
   self: grad * self.sinh().conj()
   result: auto_element_wise
 
-- name: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
+- name: count_nonzero.dim_IntList(Tensor self, int[]? dim) -> Tensor
   output_differentiability: [False]
 
+# TODO: Should probably remove this overload
 - name: count_nonzero(Tensor self, int? dim=None) -> Tensor
   output_differentiability: [False]
 
@@ -1105,11 +1106,11 @@
   other: grad.masked_fill((self <= other).logical_or_(other.isnan()), 0)
   result: other_t + (self_p <= other_p).logical_or_(other_p.isnan()) * (self_t - other_t)
 
-- name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+- name: amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
   self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)
   result: amaxamin_jvp(self_p, self_t, result, dim, keepdim)
 
-- name: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+- name: amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
   self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)
   result: amaxamin_jvp(self_p, self_t, result, dim, keepdim)
 
@@ -1190,7 +1191,7 @@
   self: norm_backward(grad, self, p, result)
   result: norm_jvp(self_p, self_t, p, result)
 
-- name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
+- name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1]? dim, bool keepdim=False) -> Tensor
   self: norm_backward(grad, self, p, result, dim, keepdim)
   result: norm_jvp(self_p, self_t, p, result, dim, keepdim)
 
@@ -1198,7 +1199,7 @@
   self: norm_backward(grad, self.to(grad.scalar_type()), p, result)
   result: norm_jvp(self_p, self_t, p, result)
 
-- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1]? dim, bool keepdim, *, ScalarType dtype) -> Tensor
   self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim)
   result: norm_jvp(self_p, self_t, p, result, dim, keepdim)
 
@@ -6698,7 +6698,7 @@ dimension(s) :attr:`dim`.
 
 Args:
     {input}
-    {dim}
+    {opt_dim}
     {keepdim}
 
 Keyword args:
@@ -7320,7 +7320,7 @@ dimension(s) :attr:`dim`.
 
 Args:
     {input}
-    {dim}
+    {opt_dim}
     {keepdim}
 
 Keyword args:
@@ -160,13 +160,21 @@ Tensor handle_r_to_c(Tensor self, Tensor gradient_result) {
 
 Tensor restore_reduced_dims(
     const Tensor& output,
-    IntArrayRef dims,
+    at::OptionalIntArrayRef opt_dims,
     bool keepdim) {
   if (keepdim) {
     return output;
   }
-  int64_t total_dims = output.dim() + dims.size();
+  int64_t total_dims;
+
+  if (opt_dims.has_value()) {
+    total_dims = output.dim() + opt_dims.value().size();
+  } else {
+    total_dims = 0;
+  }
   std::vector<int64_t> target_shape(total_dims, 0);
+  if (opt_dims.has_value()) {
+    IntArrayRef dims = opt_dims.value();
   for (int64_t i : dims) {
     if (i < 0) {
       i = total_dims + i;
@@ -179,13 +187,14 @@ Tensor restore_reduced_dims(
       j++;
     target_shape[j++] = i;
   }
+  }
   return output.reshape(target_shape);
 }
 
 Tensor scale_grad_by_count(
     const Tensor& grad,
     const Tensor& mask,
-    IntArrayRef dims) {
+    at::OptionalIntArrayRef dims) {
   return (grad / mask.sum(dims, true)) * mask;
 }
 
@@ -193,7 +202,7 @@ Tensor amaxamin_jvp(
     const Tensor& x,
     const Tensor& dx,
     const Tensor& result,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim) {
   auto mask = x == restore_reduced_dims(result, dim, keepdim);
   return at::where(mask, dx, 0.).sum(dim, keepdim) / mask.sum(dim, keepdim);
@@ -228,7 +237,7 @@ Tensor norm_backward(
     const Tensor& self,
     const optional<Scalar>& p_,
     Tensor norm,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim) {
   // NB: We mask fill the NaNs in the output to be zero but still do float
   // division
@@ -282,7 +291,7 @@ Tensor norm_jvp(
     const Tensor& self_t,
     const optional<Scalar>& p_,
     Tensor norm,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim) {
   // NB: currently norm_jvp is also reused for dist's jvp (which haas two
   // differentiable inputs)
@@ -58,12 +58,12 @@ at::Tensor maybe_multiply(const at::Tensor& t, const at::Scalar& s);
 int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim);
 Tensor restore_reduced_dims(
     const Tensor& output,
-    IntArrayRef dims,
+    at::OptionalIntArrayRef dims,
     bool keepdim);
 Tensor scale_grad_by_count(
     const Tensor& grad,
     const Tensor& mask,
-    IntArrayRef dims);
+    at::OptionalIntArrayRef dims);
 at::Tensor norm_backward(
     const at::Tensor& grad,
     const at::Tensor& self,
@@ -74,14 +74,14 @@ at::Tensor norm_backward(
     const at::Tensor& self,
     const optional<at::Scalar>& p_,
     at::Tensor norm,
-    at::IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim);
 Tensor norm_jvp(
     const Tensor& self_p,
     const Tensor& self_t,
     const optional<Scalar>& p_,
     Tensor norm,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim);
 Tensor norm_jvp(
     const Tensor& grad,
@@ -754,7 +754,7 @@ Tensor amaxamin_jvp(
     const Tensor& x,
     const Tensor& dx,
     const Tensor& result,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim);
 std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
     const Tensor& input,
@@ -3043,8 +3043,8 @@ class IrParser {
 
   {
     std::array<const char*, kNumAminAmaxOps> BinaryFloatOp = {
-        "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor",
-        "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"};
+        "aten::amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor",
+        "aten::amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor"};
     for (auto signature : BinaryFloatOp) {
       auto ptr_op = getOperatorForLiteral(signature);
       REGISTER_PARSE_RULE(
@@ -4008,11 +4008,11 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
 
   static auto amax_schema =
       getOperatorForLiteral(
-          "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor")
+          "aten::amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor")
           ->schema();
   static auto amin_schema =
       getOperatorForLiteral(
-          "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor")
+          "aten::amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor")
           ->schema();
   if (node->matches(amax_schema) || node->matches(amin_schema)) {
     switch (offset) {
@@ -2502,11 +2502,8 @@ def error_inputs_aminmax_amax_amin(op_info, device, **kwargs):
 
     # Error Inputs for zero-dim tensors, when 'dim' arg is not provided.
     shape = (S, 0, S)
-    err_msg_amax_amin = "reduction"
     err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
-    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
-        yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
-    elif op_info.name in ['aminmax']:
+    if op_info.name in ['aminmax']:
         yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)
 
     # Error Inputs for tensors with more than 64 dimension