diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt
index cb2f7a4ea1d2..bb826539c0fd 100644
--- a/.github/ci_commit_pins/xla.txt
+++ b/.github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-76c65b13280cd5782ace8050df45564ef17891f9
+ec3fd75b2d4bbcf62917d0818378aa3a7d0edcb7
diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h
index 1d4f45c6345e..cc0183fd6997 100644
--- a/aten/src/ATen/WrapDimUtils.h
+++ b/aten/src/ATen/WrapDimUtils.h
@@ -79,6 +79,17 @@ inline void maybe_wrap_dims(Container& dims, int64_t dim_post_expr) {
   return maybe_wrap_dims_n(dims.data(), dims.size(), dim_post_expr);
 }
 
+inline void maybe_wrap_dims(
+    at::OptionalIntArrayRef opt_dims,
+    int64_t dim_post_expr) {
+  if (opt_dims.has_value()) {
+    at::IntArrayRef dims = opt_dims.value();
+    // TODO: This const_cast is probably not a good idea...
+    int64_t* dims_ptr = const_cast<int64_t*>(dims.data());
+    maybe_wrap_dims_n(dims_ptr, dims.size(), dim_post_expr);
+  }
+}
+
 // previously, size [0] tensors were the only possible empty tensors; thus, it
 // wasn't possible to cat empty tensors unless all the other tensors were
 // 1-dimensional, so we allowed these tensors to be "skipped" (both for wrap
diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index 1247cf31a40a..d1deb17b46e7 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -412,7 +412,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
        &ADD_NS(native_layer_norm)>::type::call)));
   KERNEL(ADD_NS(group_norm), "group_norm", Tensor (const Tensor &, int64_t, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), fp32)
   KERNEL(ADD_NS(frobenius_norm), "frobenius_norm", Tensor (const Tensor &), fp32)
-  KERNEL(ADD_NS(frobenius_norm), "frobenius_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32)
+  KERNEL(ADD_NS(frobenius_norm), "frobenius_norm.dim", Tensor (const Tensor &, OptionalIntArrayRef, bool), fp32)
   KERNEL(ADD_NS(nuclear_norm), "nuclear_norm", Tensor (const Tensor &, bool), fp32)
   KERNEL(ADD_NS(nuclear_norm), "nuclear_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32)
   KERNEL(ADD_NS(cosine_similarity), "cosine_similarity", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32)
@@ -461,7 +461,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   // The fp32_append_dtype wrapper overrides implicit promotion behavior.
   // norm does not implicitly promote, but be aware when adding new ops to this policy.
   KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, const Scalar&), Tensor (const Tensor &, const c10::optional<Scalar>&, ScalarType), fp32_append_dtype)
-  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, IntArrayRef, bool, ScalarType), fp32_append_dtype)
+  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, OptionalIntArrayRef, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, OptionalIntArrayRef, bool, ScalarType), fp32_append_dtype)
   KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool), Tensor (const Tensor &, const c10::optional<Scalar>&, DimnameList, bool, ScalarType), fp32_append_dtype)
   // promote
   KERNEL(ADD_NS(addcdiv), "addcdiv", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&), promote)
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index c658d4427c97..0aaddd84a092 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -2672,19 +2672,19 @@ Tensor frobenius_norm(const Tensor& self) {
   return at::norm(self);
 }
 
-Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
+Tensor frobenius_norm(const Tensor& self, OptionalIntArrayRef dim, bool keepdim) {
   TORCH_CHECK(
-      dim.size() <= 2,
+      !dim.has_value() || dim.value().size() <= 2,
       "Expected at most 2 dimensions, but got ",
-      dim.size(),
+      dim.value().size(),
       " dimensions instead.");
   Tensor result;
-  if (dim.size() == 1 || dim.size() == 0) {
+  if (!dim.has_value() || dim.value().size() == 1 || dim.value().size() == 0) {
     result = at::norm(self, 2, dim, keepdim);
   } else {
-    auto dim_ = dim.vec();
+    auto dim_ = dim.value().vec();
     maybe_wrap_dims(dim_, self.dim());
-    TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
+    TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim.value(), " instead");
     if (self.is_complex()) {
       result = at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim));
     } else {
@@ -2697,7 +2697,7 @@ Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
 }
 
 Tensor &frobenius_norm_out(const Tensor& self,
-    IntArrayRef dim,
+    OptionalIntArrayRef dim,
     bool keepdim,
     Tensor& result) {
   auto result_ = at::native::frobenius_norm(self, dim, keepdim);
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 97aeb44951d0..3167dcaa2cdb 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -235,7 +235,7 @@ ScalarType get_result_or_self_value_dtype(
 }
 
 TORCH_META_FUNC2(norm, ScalarOpt_dim)
-(const Tensor& self, const OptionalScalarRef p, IntArrayRef dim, bool keepdim) {
+(const Tensor& self, const OptionalScalarRef p, at::OptionalIntArrayRef dim, bool keepdim) {
   TORCH_CHECK(
       at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
       "norm(): input dtype should be either floating point or complex. "
@@ -248,7 +248,7 @@ TORCH_META_FUNC2(norm, ScalarOpt_dim)
 TORCH_META_FUNC2(norm, ScalarOpt_dim_dtype)
 (const Tensor& self,
  const OptionalScalarRef p,
- IntArrayRef dim,
+ at::OptionalIntArrayRef dim,
  bool keepdim,
  ScalarType dtype) {
   TORCH_CHECK(
@@ -282,7 +282,7 @@ TORCH_META_FUNC(aminmax)
 }
 
 TORCH_META_FUNC(amax)
-(const Tensor& self, IntArrayRef dim, bool keepdim) {
+(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim) {
   auto maybe_result = maybe_get_output();
   if (maybe_result.defined()) {
     TORCH_CHECK(self.scalar_type() == maybe_result.scalar_type(), "Expected the dtype for input and out to match, but got ",
@@ -296,7 +296,7 @@ TORCH_META_FUNC(amax)
 }
 
 TORCH_META_FUNC(amin)
-(const Tensor& self, IntArrayRef dim, bool keepdim) {
+(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim) {
   auto maybe_result = maybe_get_output();
   if (maybe_result.defined()) {
     TORCH_CHECK(self.scalar_type() == maybe_result.scalar_type(), "Expected the dtype for input and out to match, but got ",
@@ -1359,7 +1359,7 @@ Tensor& special_logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim
 void impl_func_norm(
     const Tensor& self,
     const OptionalScalarRef& opt_p,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim,
     optional<ScalarType> opt_dtype,
     const Tensor& result) {
@@ -1396,7 +1396,7 @@ void impl_func_norm(
 TORCH_IMPL_FUNC(norm_out)
 (const Tensor& self,
  const OptionalScalarRef p,
- IntArrayRef dim,
+ at::OptionalIntArrayRef dim,
  bool keepdim,
  const Tensor& result) {
   impl_func_norm(self, p, dim, keepdim, c10::nullopt, result);
@@ -1405,7 +1405,7 @@ TORCH_IMPL_FUNC(norm_out)
 TORCH_IMPL_FUNC(norm_dtype_out)
 (const Tensor& self,
  const OptionalScalarRef p,
- IntArrayRef dim,
+ at::OptionalIntArrayRef dim,
  bool keepdim,
  ScalarType dtype,
  const Tensor& result) {
@@ -1415,7 +1415,7 @@ TORCH_IMPL_FUNC(norm_dtype_out)
 Tensor sparse_norm(
     const Tensor& self,
     const optional<Scalar>& p,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim) {
   return at::native_norm(self, p, dim, keepdim, c10::nullopt);
 }
@@ -1423,7 +1423,7 @@ Tensor sparse_norm(
 Tensor sparse_dtype_norm(
     const Tensor& self,
     const optional<Scalar>& p,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim,
     ScalarType dtype) {
   return at::native_norm(self, p, dim, keepdim, dtype);
 }
@@ -1488,7 +1488,7 @@ TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
   allany_impl<0>(self, result, {}, false, or_stub);
 }
 
-TORCH_IMPL_FUNC(amin_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
+TORCH_IMPL_FUNC(amin_out) (const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
   auto iter =
       meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
   if (iter.numel() != 0) {
@@ -1496,7 +1496,7 @@ TORCH_IMPL_FUNC(amin_out) (const Tensor& self, IntArrayRef dim, bool keepdim, co
   }
 }
 
-TORCH_IMPL_FUNC(amax_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
+TORCH_IMPL_FUNC(amax_out) (const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
   auto iter =
       meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
   if (iter.numel() != 0) {
diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h
index 9db9802ea788..3daf9c6c14bc 100644
--- a/aten/src/ATen/native/ReduceOpsUtils.h
+++ b/aten/src/ATen/native/ReduceOpsUtils.h
@@ -275,13 +275,16 @@ static void zero_numel_check_dims(const Tensor& self, const int64_t dim, const c
   }
 }
 
-static void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
-  TORCH_CHECK(
-      !dim.empty(),
-      fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
-      "Specify the reduction dim with the 'dim' argument.");
-  for (const int64_t d : dim) {
-    zero_numel_check_dims(self, d, fn_name);
+static void zero_numel_check_dims(const Tensor& self, const at::OptionalIntArrayRef opt_dim, const char *fn_name) {
+  if (opt_dim.has_value()) {
+    const IntArrayRef dim = opt_dim.value();
+    TORCH_CHECK(
+        !dim.empty(),
+        fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
+        "Specify the reduction dim with the 'dim' argument.");
+    for (const int64_t d : dim) {
+      zero_numel_check_dims(self, d, fn_name);
+    }
   }
 }
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index d14a03c384f6..ceb8301c3dbb 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -1956,12 +1956,12 @@ int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) {
   return num_nonzero;
 }
 
-Tensor count_nonzero_cuda(const Tensor& self, IntArrayRef dims){
+Tensor count_nonzero_cuda(const Tensor& self, OptionalIntArrayRef dims){
   return (self != 0).sum(dims);
 }
 
-Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){
-  if (dims.size() > 0) {
+Tensor count_nonzero_cpu(const Tensor& self, OptionalIntArrayRef dims){
+  if (dims.has_value() && dims.value().size() > 0) {
     return (self != 0).sum(dims);
   }
 
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 56d1e0fbbaf6..43f8e6e6734c 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -315,7 +315,7 @@ inline ScalarType get_dtype_from_self(
 
 TORCH_IMPL_FUNC(amax_out_mps)
 (const Tensor& input_t,
- IntArrayRef dim,
+ at::OptionalIntArrayRef dim,
  bool keepdim,
  const Tensor& output_t) {
 
@@ -324,7 +324,7 @@ TORCH_IMPL_FUNC(amax_out_mps)
 
 TORCH_IMPL_FUNC(amin_out_mps)
 (const Tensor& input_t,
- IntArrayRef dim,
+ at::OptionalIntArrayRef dim,
  bool keepdim,
  const Tensor& output_t) {
 
@@ -354,7 +354,7 @@ Tensor prod_mps(const Tensor &self, c10::optional<ScalarType> opt_dtype) {
 }
 
 
-Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){
+Tensor count_nonzero_mps(const Tensor& self, OptionalIntArrayRef dims){
   NSMutableArray *axes = nil;
   NSMutableArray *apparent_input_shape = nil;
   NSMutableArray *apparent_output_shape = nil;
@@ -395,7 +395,7 @@ TORCH_IMPL_FUNC(mean_out_mps)
 TORCH_IMPL_FUNC(norm_out_mps)
 (const Tensor& input_tensor,
  const OptionalScalarRef opt_p,
- IntArrayRef dim,
+ OptionalIntArrayRef opt_dim,
  bool keepdim,
  const Tensor& output_t) {
 
@@ -406,10 +406,13 @@ TORCH_IMPL_FUNC(norm_out_mps)
 
   IntArrayRef input_shape = input_t.sizes();
 
-  for(int i = 0; i < dim.size(); i++) {
-    auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
-    TORCH_CHECK(wrap_dim < input_shape.size(),
-    "norm_out_mps: reduction dim must be in the range of input shape")
+  if (opt_dim.has_value()) {
+    IntArrayRef dim = opt_dim.value();
+    for(int i = 0; i < dim.size(); i++) {
+      auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
+      TORCH_CHECK(wrap_dim < input_shape.size(),
+      "norm_out_mps: reduction dim must be in the range of input shape")
+    }
   }
 
   namespace native_mps = at::native::mps;
@@ -424,7 +427,7 @@ TORCH_IMPL_FUNC(norm_out_mps)
   bool pIsNegInf = (p == -numeric_limits::infinity());
 
   int64_t num_input_dims = input_shape.size();
-  int64_t num_reduce_dims = dim.size();
+  int64_t num_reduce_dims = opt_dim.has_value() ? opt_dim.value().size() : 0;
   int64_t num_output_dims;
 
   // For output shape calculation, assume that keepdim is true
@@ -434,7 +437,7 @@ TORCH_IMPL_FUNC(norm_out_mps)
 
   // Reduction axes
   NSMutableArray *axes;
-  set_axes(axes, num_reduce_dims, dim, input_shape.size());
+  set_axes(axes, num_reduce_dims, opt_dim, input_shape.size());
 
   set_apparent_shapes(apparent_output_shape,
                       apparent_input_shape,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index dea3f3ce3be4..63e6c7baffe6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1574,7 +1574,7 @@
 
 - func: cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
 
-- func: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
+- func: count_nonzero.dim_IntList(Tensor self, int[]? dim) -> Tensor
   variants: function, method
   dispatch:
     CPU: count_nonzero_cpu
@@ -3311,11 +3311,11 @@
   device_check: NoCheck
   device_guard: False
 
-- func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+- func: amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
   variants: function, method
   structured_delegate: amax.out
 
-- func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: amax.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
     CPU, CUDA: amax_out
@@ -3485,11 +3485,11 @@
 - func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
   device_check: NoCheck # TensorIterator
 
-- func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+- func: amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
   variants: function, method
   structured_delegate: amin.out
 
-- func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: amin.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
     CPU, CUDA: amin_out
@@ -5671,7 +5671,7 @@
     SparseCPU, SparseCUDA: norm_sparse
   autogen: native_norm.out
 
-- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor
+- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1]? dim, bool keepdim, ScalarType? dtype) -> Tensor
   dispatch:
     SparseCPU, SparseCUDA: norm_sparse
   autogen: native_norm.ScalarOpt_dim_dtype_out
@@ -5768,27 +5768,27 @@
     CompositeExplicitAutograd: norm
   autogen: norm.Scalar_out
 
-- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1]? dim, bool keepdim, *, ScalarType dtype) -> Tensor
   structured_delegate: norm.dtype_out
   device_check: NoCheck # TensorIterator
   variants: function, method
   dispatch:
     SparseCPU, SparseCUDA: sparse_dtype_norm
 
-- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
+- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1]? dim, bool keepdim=False) -> Tensor
   structured_delegate: norm.out
   device_check: NoCheck # TensorIterator
   variants: function, method
   dispatch:
     SparseCPU, SparseCUDA: sparse_norm
 
-- func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
+- func: norm.dtype_out(Tensor self, Scalar? p, int[1]? dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
   structured: True
   device_check: NoCheck # TensorIterator
   dispatch:
     CPU, CUDA: norm_dtype_out
 
-- func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: norm.out(Tensor self, Scalar? p, int[1]? dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   device_check: NoCheck # TensorIterator
   dispatch:
@@ -5824,11 +5824,11 @@
   variants: function
 
 # Deprecated (v.1.12)
-- func: frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
+- func: frobenius_norm.dim(Tensor self, int[1]? dim, bool keepdim=False) -> Tensor
   variants: function
 
 # Deprecated (v.1.12)
-- func: frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: frobenius_norm.out(Tensor self, int[1]? dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
 
 # Deprecated (v.1.12)
diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
index d49ee0bab406..1f2d31d6437c 100644
--- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -372,14 +372,14 @@ Tensor norm_sparse(const SparseTensor& self, const Scalar& p) {
   return norm_sparse(self, p, IntArrayRef{}, false, c10::nullopt);
 }
 
-Tensor norm_sparse(const SparseTensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype) {
+Tensor norm_sparse(const SparseTensor& self, const optional<Scalar>& p, OptionalIntArrayRef dim, bool keepdim, optional<ScalarType> dtype) {
   AT_ASSERT(self.is_sparse());
-  if (dim.size() > 0) {
+  if (dim.has_value() && dim.value().size() > 0) {
     // Only full reductions are supported, so check if that is the case
     int64_t ndim = self.dim();
-    bool passed_full_reduction_check = static_cast<size_t>(ndim) == dim.size();
+    bool passed_full_reduction_check = static_cast<size_t>(ndim) == dim.value().size();
     if (passed_full_reduction_check) {
-      auto dim_ = dim.vec();
+      auto dim_ = dim.value().vec();
       maybe_wrap_dims(dim_, ndim);
       std::vector<bool> dims_check(ndim, false);
       // Need to check for duplicates, and fail if any are found
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index 1978aee14ca1..9c54b19542f2 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -93,6 +93,10 @@ ALLOW_LIST = [
     ("aten::_linalg_inv_out_helper.out", datetime.date(2022, 10, 1)),
     ("aten::_linalg_inv_out_helper_", datetime.date(2022, 10, 1)),
     ("aten::_linalg_inv_out_helper", datetime.date(2022, 10, 1)),
+    ("aten::amax", datetime.date(2022, 10, 25)),
+    ("aten::amax.out", datetime.date(2022, 10, 25)),
+    ("aten::amin", datetime.date(2022, 10, 25)),
+    ("aten::amin.out", datetime.date(2022, 10, 25)),
     ("aten::solve", datetime.date(9999, 1, 1)),
     ("aten::solve.solution", datetime.date(9999, 1, 1)),
     ("aten::_solve_helper", datetime.date(9999, 1, 1)),
diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py
index 751a56f168e7..a0c11fcef31a 100644
--- a/test/test_namedtensor.py
+++ b/test/test_namedtensor.py
@@ -1190,6 +1190,8 @@ class TestNamedTensor(TestCase):
             'var_mean',
             'nanmean',
             'nansum',
+            'amax',
+            'amin',
         ]
         if op.__name__ in ops_support_dim_none:
             check_output(op(t, None), [])
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index ada3762c3f47..47a7c80c64a0 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -465,9 +465,10 @@
   self: grad * self.sinh().conj()
   result: auto_element_wise
 
-- name: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
+- name: count_nonzero.dim_IntList(Tensor self, int[]? dim) -> Tensor
   output_differentiability: [False]
 
+# TODO: Should probably remove this overload
 - name: count_nonzero(Tensor self, int? dim=None) -> Tensor
   output_differentiability: [False]
 
@@ -1105,11 +1106,11 @@
   other: grad.masked_fill((self <= other).logical_or_(other.isnan()), 0)
   result: other_t + (self_p <= other_p).logical_or_(other_p.isnan()) * (self_t - other_t)
 
-- name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+- name: amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
   self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)
   result: amaxamin_jvp(self_p, self_t, result, dim, keepdim)
 
-- name: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+- name: amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
   self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)
   result: amaxamin_jvp(self_p, self_t, result, dim, keepdim)
 
@@ -1190,7 +1191,7 @@
   self: norm_backward(grad, self, p, result)
   result: norm_jvp(self_p, self_t, p, result)
 
-- name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
+- name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1]? dim, bool keepdim=False) -> Tensor
   self: norm_backward(grad, self, p, result, dim, keepdim)
   result: norm_jvp(self_p, self_t, p, result, dim, keepdim)
 
@@ -1198,7 +1199,7 @@
   self: norm_backward(grad, self.to(grad.scalar_type()), p, result)
   result: norm_jvp(self_p, self_t, p, result)
 
-- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1]? dim, bool keepdim, *, ScalarType dtype) -> Tensor
   self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim)
   result: norm_jvp(self_p, self_t, p, result, dim, keepdim)
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index f24be12995e1..85de969a75e6 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -6698,7 +6698,7 @@ dimension(s) :attr:`dim`.
 
 Args:
     {input}
-    {dim}
+    {opt_dim}
     {keepdim}
 
 Keyword args:
@@ -7320,7 +7320,7 @@ dimension(s) :attr:`dim`.
 
 Args:
     {input}
-    {dim}
+    {opt_dim}
     {keepdim}
 
 Keyword args:
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 5fc723b2bdb7..5cec4b9516cd 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -160,24 +160,33 @@ Tensor handle_r_to_c(Tensor self, Tensor gradient_result) {
 
 Tensor restore_reduced_dims(
     const Tensor& output,
-    IntArrayRef dims,
+    at::OptionalIntArrayRef opt_dims,
     bool keepdim) {
   if (keepdim) {
     return output;
   }
-  int64_t total_dims = output.dim() + dims.size();
-  std::vector<int64_t> target_shape(total_dims, 0);
-  for (int64_t i : dims) {
-    if (i < 0) {
-      i = total_dims + i;
-    }
-    target_shape[i] = 1;
+  int64_t total_dims;
+
+  if (opt_dims.has_value()) {
+    total_dims = output.dim() + opt_dims.value().size();
+  } else {
+    total_dims = 0;
   }
-  int64_t j = 0;
-  for (int64_t i : output.sizes()) {
-    while (target_shape[j] > 0)
-      j++;
-    target_shape[j++] = i;
+  std::vector<int64_t> target_shape(total_dims, 0);
+  if (opt_dims.has_value()) {
+    IntArrayRef dims = opt_dims.value();
+    for (int64_t i : dims) {
+      if (i < 0) {
+        i = total_dims + i;
+      }
+      target_shape[i] = 1;
+    }
+    int64_t j = 0;
+    for (int64_t i : output.sizes()) {
+      while (target_shape[j] > 0)
+        j++;
+      target_shape[j++] = i;
+    }
   }
   return output.reshape(target_shape);
 }
@@ -185,7 +194,7 @@ Tensor restore_reduced_dims(
 Tensor scale_grad_by_count(
     const Tensor& grad,
     const Tensor& mask,
-    IntArrayRef dims) {
+    at::OptionalIntArrayRef dims) {
   return (grad / mask.sum(dims, true)) * mask;
 }
 
@@ -193,7 +202,7 @@ Tensor amaxamin_jvp(
     const Tensor& x,
     const Tensor& dx,
     const Tensor& result,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim) {
   auto mask = x == restore_reduced_dims(result, dim, keepdim);
   return at::where(mask, dx, 0.).sum(dim, keepdim) / mask.sum(dim, keepdim);
@@ -228,7 +237,7 @@ Tensor norm_backward(
     const Tensor& self,
     const optional<Scalar>& p_,
     Tensor norm,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim) {
   // NB: We mask fill the NaNs in the output to be zero but still do float
   // division
@@ -282,7 +291,7 @@ Tensor norm_jvp(
     const Tensor& self_t,
     const optional<Scalar>& p_,
     Tensor norm,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim) {
   // NB: currently norm_jvp is also reused for dist's jvp (which haas two
   // differentiable inputs)
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 9e046c8c9c80..31e023e868d3 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -58,12 +58,12 @@ at::Tensor maybe_multiply(const at::Tensor& t, const at::Scalar& s);
 int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim);
 Tensor restore_reduced_dims(
     const Tensor& output,
-    IntArrayRef dims,
+    at::OptionalIntArrayRef dims,
     bool keepdim);
 Tensor scale_grad_by_count(
     const Tensor& grad,
     const Tensor& mask,
-    IntArrayRef dims);
+    at::OptionalIntArrayRef dims);
 at::Tensor norm_backward(
     const at::Tensor& grad,
     const at::Tensor& self,
@@ -74,14 +74,14 @@ at::Tensor norm_backward(
     const at::Tensor& self,
     const optional<Scalar>& p_,
     at::Tensor norm,
-    at::IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim);
 Tensor norm_jvp(
     const Tensor& self_p,
     const Tensor& self_t,
     const optional<Scalar>& p_,
     Tensor norm,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim);
 Tensor norm_jvp(
     const Tensor& grad,
@@ -754,7 +754,7 @@ Tensor amaxamin_jvp(
     const Tensor& x,
     const Tensor& dx,
     const Tensor& result,
-    IntArrayRef dim,
+    at::OptionalIntArrayRef dim,
     bool keepdim);
 std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
     const Tensor& input,
diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp
index 251e1e6f11a2..00b570f4fd23 100644
--- a/torch/csrc/jit/codegen/cuda/parser.cpp
+++ b/torch/csrc/jit/codegen/cuda/parser.cpp
@@ -3043,8 +3043,8 @@ class IrParser {
 
     {
       std::array BinaryFloatOp = {
-          "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor",
-          "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor"};
+          "aten::amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor",
+          "aten::amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor"};
       for (auto signature : BinaryFloatOp) {
         auto ptr_op = getOperatorForLiteral(signature);
         REGISTER_PARSE_RULE(
@@ -4008,11 +4008,11 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
 
   static auto amax_schema =
       getOperatorForLiteral(
-          "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor")
+          "aten::amax(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor")
           ->schema();
   static auto amin_schema =
       getOperatorForLiteral(
-          "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor")
+          "aten::amin(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor")
           ->schema();
   if (node->matches(amax_schema) || node->matches(amin_schema)) {
     switch (offset) {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 81489e5ebfe1..e86680071a10 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2502,11 +2502,8 @@ def error_inputs_aminmax_amax_amin(op_info, device, **kwargs):
 
     # Error Inputs for zero-dim tensors, when 'dim' arg is not provided.
     shape = (S, 0, S)
-    err_msg_amax_amin = "reduction"
     err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
-    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
-        yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
-    elif op_info.name in ['aminmax']:
+    if op_info.name in ['aminmax']:
         yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)
 
     # Error Inputs for tensors with more than 64 dimension
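
For reference, a minimal usage sketch of the behavior this patch enables, assuming a PyTorch build that includes these schema changes (the tensor values below are arbitrary and only illustrative): amax/amin now accept dim=None, which is treated as a reduction over all dimensions, matching what the old default dim=[] produced.

import torch

t = torch.arange(6.0).reshape(2, 3)

# dim=None is now a valid argument (schema: int[1]? dim=None) and means
# "reduce over every dimension".
assert torch.amax(t, dim=None) == t.max()   # full reduction -> tensor(5.)
assert torch.amin(t, dim=None) == t.min()   # full reduction -> tensor(0.)

# Per-dimension reductions behave as before.
assert torch.equal(torch.amax(t, dim=1), torch.tensor([2.0, 5.0]))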