A follow-up of #165750 to clean up C casts. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165891 Approved by: https://github.com/Skylion007 Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
1774 lines | 64 KiB | C++
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/TensorIterator.h>
#undef TORCH_ASSERT_NO_OPERATORS

#include <ATen/core/Tensor.h>

#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorIteratorInternal.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#endif

#include <c10/util/irange.h>
#include <c10/util/SmallBuffer.h>

#include <array>
#include <algorithm>
#include <cmath>

namespace at {
|
|
|
|
using DimMask = TensorIteratorBase::DimMask;
|
|
using PtrVector = TensorIteratorBase::PtrVector;
|
|
using loop2d_t = TensorIteratorBase::loop2d_t;
|
|
using StrideVector = TensorIteratorBase::StrideVector;
|
|
|
|
namespace {
|
|
|
|
inline void get_base_ptrs(char** ptrs, ArrayRef<OperandInfo> operands) {
|
|
std::transform(operands.begin(), operands.end(), ptrs, [](const OperandInfo& op) {
|
|
return static_cast<char*>(op.data);
|
|
});
|
|
}
|
|
|
|
inline void get_strides(int64_t* strides, ArrayRef<OperandInfo> operands, int64_t ndim) {
|
|
for (const auto dim : c10::irange(ndim)) {
|
|
for (const auto arg : c10::irange(operands.size())) {
|
|
*strides++ = operands[arg].stride_bytes[dim];
|
|
}
|
|
}
|
|
// Always at least 2d strides to support 2d for_each loops
|
|
if (ndim < 2) {
|
|
auto ntensors = operands.size();
|
|
std::fill_n(strides, (2 - ndim) * ntensors, 0);
|
|
}
|
|
}
|
|
|
|
OptionalTensorRef make_otr(const TensorBase &tensor) {
|
|
if (tensor.defined()) {
|
|
return OptionalTensorRef(tensor);
|
|
} else {
|
|
return OptionalTensorRef();
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
namespace internal {
|
|
|
|
OpaqueOptionalTensorRef::OpaqueOptionalTensorRef() {
|
|
static_assert(alignof(OptionalTensorRef) == alignof(TensorBase));
|
|
static_assert(sizeof(OptionalTensorRef) == sizeof(TensorBase));
|
|
new (data_.data()) OptionalTensorRef();
|
|
}
|
|
|
|
OpaqueOptionalTensorRef::~OpaqueOptionalTensorRef() {
|
|
get()->~OptionalTensorRef();
|
|
}
|
|
|
|
const Tensor& OpaqueOptionalTensorRef::getTensor() const {
|
|
return get()->getTensorRef();
|
|
}
|
|
|
|
}
|
|
|
|
void OperandInfo::tensor(c10::MaybeOwned<TensorBase> &&tensor) {
|
|
tensor_base_ = std::move(tensor);
|
|
*tensor_storage_ = make_otr(*tensor_base_);
|
|
}
|
|
|
|
void OperandInfo::exchange_tensor(c10::MaybeOwned<TensorBase> &&new_tensor) {
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!original_tensor_base_->defined());
|
|
original_tensor_base_ = std::exchange(tensor_base_, std::move(new_tensor));
|
|
*original_tensor_storage_ = std::exchange(*tensor_storage_, make_otr(*tensor_base_));
|
|
}
|
|
|
|
void OperandInfo::restore_original_tensor() {
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(original_tensor_base_->defined());
|
|
tensor_base_ = std::move(original_tensor_base_);
|
|
*tensor_storage_ = std::exchange(*original_tensor_storage_, OptionalTensorRef{});
|
|
}
|
|
|
|
/// Construction
|
|
TensorIteratorConfig& TensorIteratorConfig::add_owned_output(const TensorBase& output) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
num_inputs_ == 0,
|
|
"Keep in mind that you have to add all outputs first before adding any input. "
|
|
"For more details, see https://github.com/pytorch/pytorch/wiki/How-to-use-TensorIterator.");
|
|
tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(std::in_place, output));
|
|
num_outputs_++;
|
|
return *this;
|
|
}
|
|
|
|
TensorIteratorConfig& TensorIteratorConfig::add_owned_input(const TensorBase& input) {
|
|
tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(std::in_place, input));
|
|
num_inputs_++;
|
|
return *this;
|
|
}
|
|
|
|
TensorIteratorConfig& TensorIteratorConfig::add_owned_const_input(const TensorBase& input) {
|
|
const_tensor_indices_.push_back(tensors_.size());
|
|
tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(std::in_place, input));
|
|
num_inputs_++;
|
|
return *this;
|
|
}
|
|
|
|
TensorIteratorConfig& TensorIteratorConfig::add_borrowed_output(const TensorBase& output) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
num_inputs_ == 0,
|
|
"Keep in mind that you have to add all outputs first before adding any input. "
|
|
"For more details, see https://github.com/pytorch/pytorch/wiki/How-to-use-TensorIterator.");
|
|
tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(output));
|
|
num_outputs_++;
|
|
return *this;
|
|
}
|
|
|
|
TensorIteratorConfig& TensorIteratorConfig::add_borrowed_input(const TensorBase& input) {
|
|
tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(input));
|
|
num_inputs_++;
|
|
return *this;
|
|
}
|
|
|
|
TensorIteratorConfig& TensorIteratorConfig::add_borrowed_const_input(const TensorBase& input) {
|
|
const_tensor_indices_.push_back(tensors_.size());
|
|
tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(input));
|
|
num_inputs_++;
|
|
return *this;
|
|
}
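// Illustrative usage sketch (not part of this translation unit): all outputs
// must be added before any input, and the owned/borrowed variants differ only
// in lifetime. add_owned_* copies the TensorBase handle into the config, while
// add_borrowed_* stores a reference, so the caller must keep the tensor alive
// until build() has run. Assuming tensors `out`, `a` and `b`:
//
//   auto iter = TensorIteratorConfig()
//       .add_owned_output(out)       // outputs first...
//       .add_owned_const_input(a)    // ...then inputs
//       .add_owned_const_input(b)
//       .build();
//
// Adding an output after an input would trip the
// TORCH_INTERNAL_ASSERT(num_inputs_ == 0) above.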
|
|
|
|
TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype_and_device(ScalarType dtype, Device device) {
|
|
TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
|
|
static_dtype_ = dtype;
|
|
static_device_ = device;
|
|
return *this;
|
|
}
|
|
|
|
TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype(ScalarType dtype) {
|
|
TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
|
|
static_dtype_ = dtype;
|
|
return *this;
|
|
}
|
|
|
|
TensorIteratorConfig& TensorIteratorConfig::declare_static_device(Device device) {
|
|
static_device_ = device;
|
|
return *this;
|
|
}
|
|
|
|
TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape) {
|
|
// WARNING:
|
|
// This will bypass all shape checking in the TensorIterator. Kernels which call this method
|
|
// are expected to check shapes before calling `add_owned_input` or `add_owned_output`.
|
|
TORCH_CHECK(!resize_outputs_, "resize_outputs() must be called before declare_static_shape(...)")
|
|
static_shape_ = DimVector(shape);
|
|
return *this;
|
|
}
|
|
|
|
TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape, IntArrayRef squash_dims) {
|
|
declare_static_shape(shape);
|
|
TORCH_INTERNAL_ASSERT(static_shape_.has_value());
|
|
if (static_shape_->empty()) return *this;
|
|
for (const auto& squash_dim : squash_dims) {
|
|
TORCH_CHECK(squash_dim >= 0 && squash_dim < static_cast<int64_t>(static_shape_->size()),
|
|
"squash_dim ", squash_dim, " must be in [0, ", static_shape_->size(), ").");
|
|
(*static_shape_)[squash_dim] = 1;
|
|
}
|
|
return *this;
|
|
}
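// Illustrative sketch (hypothetical shapes, not part of this file): squashed
// dimensions are forced to size 1 in the iterator's static shape, so the
// iterator does not step over them and the kernel can walk that dimension by
// itself (reduction- and softmax-style kernels use this). Note that
// resize_outputs(false) has to be called first, per the TORCH_CHECK above:
//
//   auto config = TensorIteratorConfig()
//       .check_all_same_dtype(false)
//       .resize_outputs(false)
//       // iterator shape becomes {4, 1, 6}; dim 1 is handled by the kernel
//       .declare_static_shape({4, 5, 6}, /*squash_dims=*/{1});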
|
|
|
|
bool TensorIteratorConfig::is_tensor_const(size_t idx) {
|
|
return std::find(const_tensor_indices_.begin(), const_tensor_indices_.end(), idx) != const_tensor_indices_.end();
|
|
}
|
|
|
|
// NOTE: [Computing output strides]
|
|
// We use the following algorithm to compute output strides
|
|
// If correctly sized output is provided, we respect its strides and don't change them
|
|
// Otherwise, if provided output is of incorrect size or no output is provided,
|
|
// we try to recover permutation that was applied to the inputs
|
|
// by sorting the strides of the inputs. Precedence is given to the inputs in the order they were added,
|
|
// and to permutations involving non-broadcasted dimensions
|
|
// 1. we loop over inputs starting from the first
|
|
// 2. for all inputs strides of broadcasted dimensions are set to 0, and 0 compares equal to anything. If one
|
|
// of the dimensions being compared has a stride of 0, we move on to the next tensor to determine if
|
|
// these dimensions need to be swapped.
|
|
// 3. strides of dimensions equal to 1 participate in sorting
|
|
// 4. if 2 strides are equal and neither is 0, we try to break the tie by looking at the corresponding dimensions
|
|
// of the tensor. Dimensions were permuted if, when iterating from the end, dimensions corresponding to the
|
|
// same strides are increasing. If dimensions are non-increasing, we move on to the next input to break the tie.
|
|
//
|
|
// Instead of applying rule 4 for tie breaking, we could move on to the next tensor directly. This would result in possibly
|
|
// losing the correct permutation of the first tensor if there are permuted trivial dimensions, but could potentially
|
|
// improve traversal order of the second tensor. We chose the former option to better propagate channels last layout
|
|
// for example, for a tensor with sizes N1H1.
|
|
// These rules result in the intuitive behavior that in most cases recovers permutation of either the first argument (if all
|
|
// arguments are of the same size) or the argument that is not broadcasted, regardless of its position.
|
|
// As a bonus, it also results in a reasonably well-behaved traversal order of the inputs and outputs - in the kernels
|
|
// output is traversed linearly, and since it closely follows input layouts, inputs are traversed linearly as well
|
|
//
|
|
// Examples:
|
|
// full size tensor + broadcasted tensor with 0 or 1 non-trivial dimensions => strides of output are same
|
|
// as strides of full size input regardless of the order
|
|
// 2 tensors of same size but different strides => output strides are the same as first argument
|
|
//
|
|
// We also have a fast path for memory-dense inputs with the same strides (or, trivially, a single memory-dense input)
|
|
// that outputs a tensor with the same strides as the inputs. The only difference in result from the algorithm described
|
|
// above is for strides for trivial (1) dimensions, where in ambiguous cases for performance reasons we default to
|
|
// contiguous strides.
|
|
// Example: tensor with sizes NC11 and strides C1CC will produce output with strides C111 (note differences are only
|
|
// in the strides of trivial dimensions, so physical layout is unaffected but permutation information is lost)
|
|
// We might change this behavior in future once performance considerations are resolved
|
|
|
|
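// Illustrative example of the note above (a sketch, not part of the build):
// a channels-last input of size {N, C, H, W} has strides {H*W*C, 1, W*C, C},
// so sorting dimensions by stride recovers its permutation and a freshly
// allocated output picks up channels-last strides as well. Per the "2 tensors
// of same size but different strides" rule above, the first argument wins:
//
//   auto a = at::rand({2, 3, 4, 5}).contiguous(at::MemoryFormat::ChannelsLast);
//   auto b = at::rand({2, 3, 4, 5});   // default contiguous
//   auto out = a + b;                  // out comes back channels-last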
void TensorIteratorBase::reorder_dimensions() {
|
|
// Sort the dimensions based on strides in ascending order with reduced dims
|
|
// at the front. NOTE: this inverts the order of C-contiguous tensors.
|
|
// strides[0] is the fastest moving dimension instead of strides[ndim - 1].
|
|
// See NOTE: [Computing output strides] and inline comments for more detailed description
|
|
|
|
perm_.resize(ndim());
|
|
if (ndim() == 1) {
|
|
perm_[0] = 0;
|
|
return;
|
|
}
|
|
|
|
// initialize perm with n-1, n-2, ..., 1, 0
|
|
std::iota(perm_.rbegin(), perm_.rend(), 0);
|
|
|
|
// Reordering dimensions changes iteration order
|
|
if (enforce_linear_iteration_) {
|
|
permute_dimensions(perm_);
|
|
return;
|
|
}
|
|
|
|
// returns 1 if dim0 should come after dim1, -1 if dim0 should come
|
|
// before dim1, and 0 if the comparison is ambiguous.
|
|
auto should_swap = [&](size_t dim0, size_t dim1) {
|
|
for (const auto arg : c10::irange(ntensors())) {
|
|
// ignore undefined or incorrectly sized tensors
|
|
if (operands_[arg].stride_bytes.empty() || operands_[arg].will_resize) {
|
|
continue;
|
|
}
|
|
int64_t stride0 = operands_[arg].stride_bytes[dim0];
|
|
int64_t stride1 = operands_[arg].stride_bytes[dim1];
|
|
if (is_reduction_ && operands_[arg].is_output) {
|
|
// move reduced dimensions to the front
|
|
// strides of reduced dimensions are always set to 0 by review_reduce_result
|
|
if ((stride0 == 0) != (stride1 == 0)) {
|
|
return stride1 == 0 ? 1 : -1;
|
|
}
|
|
}
|
|
// move on to the next input if one of the dimensions is broadcasted
|
|
if (stride0 == 0 || stride1 == 0) {
|
|
continue;
|
|
// it is important to return here only with strict comparisons, for equal strides we try to break the tie later
|
|
// by comparing corresponding dimensions or if that does not work, moving on to the next tensor
|
|
} else if (stride0 < stride1) {
|
|
return -1;
|
|
} else if (stride0 > stride1) {
|
|
return 1;
|
|
} else { //equal strides, use dimensions themselves as the tie-breaker.
|
|
//at this point, with zero strides out of the way, we are guaranteed that operand dimensions are equal to shape_
|
|
auto t_dim0 = shape_[dim0];
|
|
auto t_dim1 = shape_[dim1];
|
|
//return only if dimensions should be swapped, otherwise move on to the next tensor
|
|
if (t_dim0 > t_dim1) {
|
|
return 1;
|
|
}
|
|
}
|
|
}
|
|
return 0;
|
|
};
|
|
|
|
// insertion sort with support for ambiguous comparisons
|
|
for (const auto i : c10::irange(1, ndim())) {
|
|
int dim1 = i;
|
|
for (int dim0 = i - 1; dim0 >= 0; dim0--) {
|
|
int comparison = should_swap(perm_[dim0], perm_[dim1]);
|
|
if (comparison > 0) {
|
|
std::swap(perm_[dim0], perm_[dim1]);
|
|
dim1 = dim0;
|
|
} else if (comparison < 0) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// perform re-ordering of shape and strides
|
|
permute_dimensions(perm_);
|
|
}
|
|
|
|
// Computes a common dtype using type promotion
|
|
// See the [Common Dtype Computation] note
|
|
ScalarType TensorIteratorBase::compute_common_dtype() {
|
|
at::native::ResultTypeState state = {};
|
|
for (const auto& op : operands_) {
|
|
if (op.is_output) {
|
|
continue;
|
|
}
|
|
|
|
state = at::native::update_result_type_state(op.tensor(), state);
|
|
}
|
|
|
|
common_dtype_ = at::native::result_type(state);
|
|
TORCH_INTERNAL_ASSERT(common_dtype_ != ScalarType::Undefined);
|
|
|
|
return common_dtype_;
|
|
}
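// Illustrative sketch (not part of this file): the state machinery follows
// the usual type-promotion lattice over the inputs only, e.g.
//
//   at::native::ResultTypeState s = {};
//   s = at::native::update_result_type_state(at::ones({2}, at::kInt), s);
//   s = at::native::update_result_type_state(at::ones({2}, at::kFloat), s);
//   // at::native::result_type(s) == at::kFloat
//
// Zero-dim ("wrapped scalar") operands participate with lower precedence, so
// a Long tensor combined with a Float zero-dim tensor still promotes to Float.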
|
|
|
|
static TensorOptions original_options(const OperandInfo& op) {
|
|
if (op.original_tensor_base().defined()) {
|
|
return op.original_tensor_base().options();
|
|
} else {
|
|
return op.options();
|
|
}
|
|
}
|
|
|
|
// Implements the behavior of the following flags:
|
|
// - check_all_same_dtype_
|
|
// - check_all_same_device_
|
|
// - enforce_safe_casting_to_output_
|
|
// - promote_inputs_to_common_dtype_
|
|
// - cast_common_dtype_to_outputs_
|
|
//
|
|
// See their descriptions in TensorIterator.h for details.
|
|
// NOTE: Checks for more specific behaviors (e.g. the first and second
|
|
// inputs must share a dtype, but the third must have the long dtype)
|
|
// should be implemented directly and outside of TensorIterator.
|
|
void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
|
|
// Reviews operands (1/2)
|
|
// - validates that all input tensors are defined
|
|
// - computes common device
|
|
// - determines if there are undefined outputs
|
|
// - determines if there are different dtypes and attempts
|
|
// to quickly acquire a common dtype
|
|
Device common_device = kCPU;
|
|
common_dtype_ = ScalarType::Undefined;
|
|
// NB: despite output_dtype's generic-sounding name, it is only
// used in a nontrivial way if check_all_same_dtype is true
|
|
ScalarType output_dtype = ScalarType::Undefined;
|
|
bool has_different_input_dtypes = false;
|
|
bool has_different_output_dtypes = false;
|
|
bool has_undefined_outputs = false;
|
|
|
|
for (auto& op : operands_) {
|
|
// Validates that all inputs have type information, and that
|
|
// if an output is missing type information that we can infer
|
|
// the device it should be allocated on.
|
|
if (!op.is_type_defined()) {
|
|
TORCH_INTERNAL_ASSERT(op.is_output, "Found type undefined input tensor!");
|
|
|
|
if (config.static_dtype_.has_value()) {
|
|
op.target_dtype = config.static_dtype_.value();
|
|
} else {
|
|
has_undefined_outputs = true;
|
|
}
|
|
|
|
if (config.static_device_.has_value()) {
|
|
op.device = config.static_device_.value();
|
|
} else {
|
|
TORCH_INTERNAL_ASSERT(config.check_all_same_device_);
|
|
}
|
|
|
|
if (has_undefined_outputs || !op.device.has_value()) {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
// Validates input tensors are defined
|
|
if (!op.tensor_base().defined()) {
|
|
TORCH_INTERNAL_ASSERT(op.is_output, "Found undefined input tensor!");
|
|
continue;
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(op.target_dtype == op.current_dtype);
|
|
|
|
// Acquires the first non-CPU device (if any) as the common device
|
|
if (common_device == kCPU && !op.tensor_base().is_cpu()) {
|
|
common_device = op.tensor_base().device();
|
|
}
|
|
|
|
if (!op.is_output) {
|
|
// Determines if there are varying input dtypes
|
|
// NOTE: the common dtype is set to the first defined input dtype observed
|
|
if (op.target_dtype != common_dtype_) {
|
|
if (common_dtype_ == ScalarType::Undefined) {
|
|
common_dtype_ = op.target_dtype;
|
|
} else {
|
|
has_different_input_dtypes = true;
|
|
}
|
|
}
|
|
} else { // op.is_output
|
|
// Determines if there are varying output dtypes
|
|
// NOTE: the output dtype is set to the first defined output dtype observed
|
|
if (op.target_dtype != output_dtype) {
|
|
if (output_dtype == ScalarType::Undefined) {
|
|
output_dtype = op.target_dtype;
|
|
} else {
|
|
has_different_output_dtypes = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Checks that either the computation type is computable or unneeded
|
|
TORCH_INTERNAL_ASSERT(!(has_different_input_dtypes && !config.promote_inputs_to_common_dtype_ &&
|
|
(has_undefined_outputs || config.enforce_safe_casting_to_output_ ||
|
|
config.cast_common_dtype_to_outputs_)));
|
|
|
|
// Checks that all inputs and defined outputs are the same dtype, if requested
|
|
if (config.check_all_same_dtype_ &&
|
|
(has_different_input_dtypes || has_different_output_dtypes ||
|
|
(common_dtype_ != output_dtype && output_dtype != ScalarType::Undefined))) {
|
|
// Throws an informative error message
|
|
for (auto& op : operands_) {
|
|
if (!op.tensor_base().defined()) {
|
|
continue;
|
|
}
|
|
|
|
TORCH_CHECK(op.target_dtype == common_dtype_,
|
|
"Found dtype ", op.target_dtype, " but expected ", common_dtype_);
|
|
}
|
|
}
|
|
|
|
// Short-circuits if no additional work required
|
|
if (!has_undefined_outputs && !config.check_all_same_device_ &&
|
|
!config.promote_inputs_to_common_dtype_ && !config.cast_common_dtype_to_outputs_ &&
|
|
!config.enforce_safe_casting_to_output_) {
|
|
// Invalidates common_dtype_ if it could not be inferred
|
|
common_dtype_ = has_different_input_dtypes ? ScalarType::Undefined : common_dtype_;
|
|
return;
|
|
}
|
|
|
|
// Computes a common dtype, if needed
|
|
if ((has_different_input_dtypes || all_ops_are_scalars_) && config.promote_inputs_to_common_dtype_) {
|
|
common_dtype_ = compute_common_dtype();
|
|
}
|
|
|
|
// Promotes common dtype to the default float scalar type, if needed
|
|
if (config.promote_integer_inputs_to_float_ &&
|
|
c10::isIntegralType(common_dtype_, /*includeBool=*/true)) {
|
|
common_dtype_ = c10::typeMetaToScalarType(c10::get_default_dtype());
|
|
}
|
|
|
|
// Reviews operands (2/2)
|
|
// - sets metadata for undefined outputs
|
|
// - checks that all tensors are on the same device, if requested
|
|
// - checks that the common dtype can safely cast to each output, if requested
|
|
// - creates temporaries for CPU operations, if needed and requested
|
|
common_device_ = common_device;
|
|
int max_cpu_scalars_on_non_cpu = config.allow_cpu_scalars_ ? 1 : 0;
|
|
int current_cpu_scalars_on_non_cpu = 0;
|
|
for (auto& op : operands_) {
|
|
bool is_type_defined = op.is_type_defined();
|
|
bool is_device_defined = op.is_device_defined();
|
|
|
|
if (!is_type_defined) {
|
|
op.target_dtype = common_dtype_;
|
|
}
|
|
if (!is_device_defined) {
|
|
op.device = common_device;
|
|
}
|
|
|
|
if (!is_type_defined && !is_device_defined) {
|
|
continue;
|
|
}
|
|
|
|
// Skips undefined tensors
|
|
if (!op.tensor_base().defined()) {
|
|
continue;
|
|
}
|
|
|
|
// Checks all tensors are on the same device, if requested
|
|
if (config.check_all_same_device_) {
|
|
// Handles CPU scalars on CUDA kernels that support them
|
|
if (!common_device.is_cpu() &&
|
|
config.allow_cpu_scalars_ && !op.is_output && op.tensor_base().dim() == 0 &&
|
|
op.tensor_base().is_cpu()) {
|
|
TORCH_CHECK(current_cpu_scalars_on_non_cpu < max_cpu_scalars_on_non_cpu,
|
|
"Trying to pass too many CPU scalars to non-CPU kernel!");
|
|
++current_cpu_scalars_on_non_cpu;
|
|
} else if (op.device != common_device) {
|
|
TORCH_CHECK(false,
|
|
"Expected all tensors to be on the same device, but "
|
|
"found at least two devices, ", common_device, " and ", op.device, "!");
|
|
}
|
|
}
|
|
|
|
// Checks safe casting, if requested
|
|
if (config.enforce_safe_casting_to_output_ && op.is_output && op.current_dtype != common_dtype_) {
|
|
TORCH_CHECK(canCast(common_dtype_, op.current_dtype),
|
|
"result type ", common_dtype_, " can't be cast to the "
|
|
"desired output type ", op.current_dtype);
|
|
}
|
|
|
|
// Creates temporaries for CPU operations, if needed and requested
|
|
// TODO: reuse temporaries when possible (e.g. for inplace operations)
|
|
if (common_device == kCPU) {
|
|
// Casts to outputs by creating temporaries of the correct dtype (if needed)
|
|
// NB: we skip this on is_meta_, because the temporary allocation here is
|
|
// unnecessary if we aren't going to actually do the compute
|
|
if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_ && !is_meta_) {
|
|
TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
|
|
// Marker [Output original_tensor is set]
|
|
// NB: do NOT use set_output here, as the temporary is NOT a true output;
|
|
// op.tensor is the true output and it was pre-provided for us.
|
|
// TODO: The logic for cast_outputs will need to be handled by the
|
|
// structured kernels implementation. What probably should happen
|
|
// is that we pass in the inferred dtype into the out kernel, and
|
|
// then after calling the out kernel, do the conversion (which
|
|
// is cast_outputs here), but integrating this with existing
|
|
// TensorIterator will take a little doing
|
|
op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(
|
|
at::empty_like(op.tensor(),
|
|
op.tensor_base().options().dtype(common_dtype_),
|
|
LEGACY_CONTIGUOUS_MEMORY_FORMAT)));
|
|
if (!names_.empty()) {
|
|
namedinference::propagate_names(op.tensor_base(), names_);
|
|
}
|
|
op.current_dtype = common_dtype_;
|
|
op.target_dtype = common_dtype_;
|
|
}
|
|
|
|
// Promotes inputs by creating temporaries of the correct dtype
|
|
if (config.promote_inputs_to_common_dtype_ && !op.is_output && op.current_dtype != common_dtype_) {
|
|
op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(op.tensor().to(common_dtype_)));
|
|
op.current_dtype = common_dtype_;
|
|
op.target_dtype = common_dtype_;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
StrideVector TensorIteratorBase::compatible_stride(int64_t element_size) const {
|
|
auto stride = StrideVector();
|
|
int64_t next_stride = element_size;
|
|
for (const auto dim : c10::irange(ndim())) {
|
|
stride.push_back(next_stride);
|
|
next_stride *= shape_[dim];
|
|
}
|
|
return stride;
|
|
}
|
|
|
|
DimVector TensorIteratorBase::invert_perm(IntArrayRef input) const {
|
|
// Invert the permutation caused by reorder_dimensions. This is not valid
|
|
// after coalesce_dimensions is called.
|
|
TORCH_INTERNAL_ASSERT(!has_coalesced_dimensions_);
|
|
TORCH_INTERNAL_ASSERT(input.size()==perm_.size());
|
|
auto res = DimVector(input.size()); //no initialization needed, every value in res should be written to.
|
|
for (const auto dim : c10::irange(ndim())) {
|
|
res[perm_[dim]] = input[dim];
|
|
}
|
|
return res;
|
|
}
|
|
|
|
void TensorIteratorBase::allocate_or_resize_outputs() {
|
|
for (const auto i : c10::irange(num_outputs_)) {
|
|
auto& op = operands_[i];
|
|
if (!op.tensor_base().defined() || op.will_resize) {
|
|
TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
|
|
auto element_size = elementSize(op.target_dtype);
|
|
op.stride_bytes = compatible_stride(static_cast<int64_t>(element_size));
|
|
// check if permutation is just an inverted order
|
|
bool inverted = true;
|
|
for (const auto j : c10::irange(ndim())) {
|
|
if (perm_[j] != ndim() - j - 1) {
|
|
inverted = false;
|
|
break;
|
|
}
|
|
}
|
|
auto tensor_shape = invert_perm(shape_);
|
|
if (inverted) {
|
|
// can just return contiguous output
|
|
// it is faster because it avoids allocating a zero-sized tensor and then
// resizing and restriding it
|
|
set_output_raw_strided(i, tensor_shape, {}, original_options(op), names_);
|
|
} else {
|
|
auto tensor_stride = invert_perm(op.stride_bytes);
|
|
for (const auto dim : c10::irange(ndim())) {
|
|
tensor_stride[dim] /= static_cast<int64_t>(element_size);
|
|
}
|
|
set_output_raw_strided(i, tensor_shape, tensor_stride, original_options(op), names_);
|
|
}
|
|
op.current_dtype = op.target_dtype;
|
|
} else if (op.tensor_base().defined()) {
|
|
// Even if we don't resize, we still need to tell set_output about
|
|
// the output, so that we properly set guard and propagate names
|
|
set_output_raw_strided(i, op.tensor_base().sizes(), {}, original_options(op), names_);
|
|
}
|
|
}
|
|
}
|
|
|
|
void TensorIteratorBase::compute_names(const TensorIteratorConfig& config) {
|
|
bool should_infer_names = std::any_of(
|
|
operands_.begin(),
|
|
operands_.end(),
|
|
[](const OperandInfo& op) {
|
|
return op.tensor_base().defined() && op.tensor_base().has_names();
|
|
});
|
|
if (!should_infer_names) {
|
|
return;
|
|
}
|
|
|
|
for (auto& op : operands_) {
|
|
if (!op.tensor_base().defined()) continue;
|
|
// Don't include output tensors if we are resizing, since we will
|
|
// clobber their names in any case. (If the output tensor was
|
|
// also an input tensor, we'll pick it up when it shows up again
|
|
// in operands).
|
|
if (config.resize_outputs_ && op.is_output) continue;
|
|
// perform name inference
|
|
if (names_.empty()) {
|
|
names_ = op.tensor_base().names();
|
|
} else {
|
|
names_ = NameVector(unify_from_right(names_, op.tensor_base().names()));
|
|
}
|
|
}
|
|
}
|
|
|
|
void TensorIteratorBase::coalesce_dimensions() {
|
|
if (ndim() <= 1) {
|
|
return;
|
|
}
|
|
|
|
// We can coalesce two adjacent dimensions if either dim has size 1 or if:
|
|
// shape[n] * stride[n] == stride[n + 1].
|
|
auto can_coalesce = [&](int dim0, int dim1) {
|
|
auto shape0 = shape_[dim0];
|
|
auto shape1 = shape_[dim1];
|
|
if (shape0 == 1 || shape1 == 1) {
|
|
return true;
|
|
}
|
|
for (const auto i : c10::irange(ntensors())) {
|
|
auto& stride = operands_[i].stride_bytes;
|
|
if (shape0 * stride[dim0] != stride[dim1]) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
};
|
|
|
|
// replace each operand's stride at dim0 with its stride at dim1
|
|
auto replace_stride = [&](int dim0, int dim1) {
|
|
for (const auto i : c10::irange(ntensors())) {
|
|
auto& stride = operands_[i].stride_bytes;
|
|
stride[dim0] = stride[dim1];
|
|
}
|
|
};
|
|
|
|
int prev_dim = 0;
|
|
for (const auto dim : c10::irange(1, ndim())) {
|
|
if (can_coalesce(prev_dim, dim)) {
|
|
if (shape_[prev_dim] == 1) {
|
|
replace_stride(prev_dim, dim);
|
|
}
|
|
shape_[prev_dim] *= shape_[dim];
|
|
} else {
|
|
prev_dim++;
|
|
if (prev_dim != dim) {
|
|
replace_stride(prev_dim, dim);
|
|
shape_[prev_dim] = shape_[dim];
|
|
}
|
|
}
|
|
}
|
|
|
|
shape_.resize(prev_dim + 1);
|
|
for (const auto i : c10::irange(ntensors())) {
|
|
operands_[i].stride_bytes.resize(ndim());
|
|
}
|
|
has_coalesced_dimensions_ = true;
|
|
}
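// Illustrative example (a sketch): for a contiguous float tensor of shape
// {2, 3, 4}, reorder_dimensions first reverses the dims so that dim 0 is the
// fastest moving one, giving shape {4, 3, 2} and byte strides {4, 16, 48}.
// Every adjacent pair then satisfies shape[n] * stride[n] == stride[n + 1],
// so the loop above collapses everything to shape_ = {24}, stride_bytes = {4},
// and the inner loop can run over all 24 elements in one go. For example:
//
//   auto t = at::empty({2, 3, 4});
//   auto it = at::TensorIterator::unary_op(t, t);
//   // it.ndim() == 1 and it.numel() == 24 (fully contiguous operands reach
//   // the same single-dimension form through fast_set_up as well)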
|
|
|
|
int64_t TensorIteratorBase::numel() const {
|
|
int64_t numel = 1;
|
|
for (int64_t size : shape_) {
|
|
numel *= size;
|
|
}
|
|
return numel;
|
|
}
|
|
|
|
StrideVector TensorIteratorBase::get_dim_strides(int dim) const {
|
|
auto dims = ndim();
|
|
auto inner_strides = StrideVector();
|
|
for (auto& op : operands_) {
|
|
inner_strides.push_back(dims == 0 ? 0 : op.stride_bytes[dim]);
|
|
}
|
|
return inner_strides;
|
|
}
|
|
|
|
SmallVector<char*, 4> TensorIteratorBase::get_base_ptrs() const {
|
|
auto ptrs = SmallVector<char*, 4>(ntensors());
|
|
at::get_base_ptrs(ptrs.data(), operands_);
|
|
return ptrs;
|
|
}
|
|
|
|
bool TensorIteratorBase::is_dim_reduced(int dim) const {
|
|
for (auto& op : operands_) {
|
|
if (op.is_output && op.stride_bytes[dim] == 0 && shape_[dim] > 1) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
void TensorIteratorBase::permute_dimensions(IntArrayRef perm) {
|
|
TORCH_INTERNAL_ASSERT(perm.size() == static_cast<unsigned>(ndim()));
|
|
|
|
auto reorder = [perm](IntArrayRef data) {
|
|
auto res = DimVector(data.size(), 0);
|
|
for (const auto i : c10::irange(perm.size())) {
|
|
res[i] = data[perm[i]];
|
|
}
|
|
return res;
|
|
};
|
|
|
|
// Update shape and strides
|
|
shape_ = reorder(shape_);
|
|
for (auto& op : operands_) {
|
|
if (!op.stride_bytes.empty()) {
|
|
op.stride_bytes = reorder(op.stride_bytes);
|
|
}
|
|
}
|
|
}
|
|
|
|
int64_t TensorIteratorBase::num_output_elements() const {
|
|
int64_t elem = 1;
|
|
for (const auto dim : c10::irange(ndim())) {
|
|
if (operands_[0].stride_bytes[dim] != 0 || shape_[dim] == 0) {
|
|
elem *= shape_[dim];
|
|
}
|
|
}
|
|
return elem;
|
|
}
|
|
|
|
int TensorIteratorBase::num_reduce_dims() const {
|
|
int count = 0;
|
|
for (const auto dim : c10::irange(ndim())) {
|
|
if (operands_[0].stride_bytes[dim] == 0) {
|
|
count++;
|
|
}
|
|
}
|
|
return count;
|
|
}
|
|
|
|
void TensorIteratorBase::for_each(loop2d_t loop, int64_t grain_size) {
|
|
int64_t numel = this->numel();
|
|
if (numel == 0) {
|
|
return;
|
|
} else if (numel < grain_size || at::get_num_threads() == 1) {
|
|
serial_for_each(loop, {0, numel});
|
|
return;
|
|
} else {
|
|
at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
|
|
serial_for_each(loop, {begin, end});
|
|
});
|
|
}
|
|
}
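// Illustrative sketch (assumes float operands; real kernels normally go
// through helpers such as at::native::cpu_kernel instead): loop2d_t receives
// the per-chunk base pointers, the byte strides of the two innermost
// dimensions laid out as [dim0 stride of each operand, then dim1 stride of
// each operand], and the two loop extents. A hand-rolled add over a
// 3-operand iterator (one output, two inputs) could look like:
//
//   iter.for_each([](char** data, const int64_t* strides, int64_t n0, int64_t n1) {
//     for (int64_t i = 0; i < n1; i++) {
//       char* out = data[0] + i * strides[3];
//       char* a   = data[1] + i * strides[4];
//       char* b   = data[2] + i * strides[5];
//       for (int64_t j = 0; j < n0; j++) {
//         *reinterpret_cast<float*>(out + j * strides[0]) =
//             *reinterpret_cast<float*>(a + j * strides[1]) +
//             *reinterpret_cast<float*>(b + j * strides[2]);
//       }
//     }
//   });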
|
|
|
|
StrideVector TensorIteratorBase::get_strides() const {
|
|
const auto dim = ndim();
|
|
StrideVector strides(static_cast<size_t>(std::max(dim, 2)) * ntensors());
|
|
at::get_strides(strides.data(), operands_, dim);
|
|
return strides;
|
|
}
|
|
|
|
void TensorIteratorBase::serial_for_each(loop2d_t loop, Range range) const {
|
|
if (range.size() == 0) {
|
|
return;
|
|
}
|
|
|
|
const auto ntensors = this->ntensors();
|
|
const auto ndim = this->ndim();
|
|
|
|
c10::SmallBuffer<char*, 4> ptrs(ntensors);
|
|
c10::SmallBuffer<int64_t, 8> strides(ntensors * static_cast<size_t>(std::max(ndim, 2)));
|
|
|
|
at::get_base_ptrs(ptrs.data(), operands_);
|
|
at::get_strides(strides.data(), operands_, ndim);
|
|
at::internal::serial_for_each(
|
|
shape_, strides, ptrs.data(), ptrs.size(), loop, range);
|
|
}
|
|
|
|
bool TensorIteratorBase::is_trivial_1d() const {
|
|
// TODO: check for casting once it's supported
|
|
return ndim() == 1;
|
|
}
|
|
|
|
bool TensorIteratorBase::is_contiguous() const {
|
|
if (numel() == 1) {
|
|
return true;
|
|
}
|
|
if (ndim() != 1) {
|
|
return false;
|
|
}
|
|
return has_contiguous_first_dim();
|
|
}
|
|
|
|
|
|
bool TensorIteratorBase::is_scalar(int64_t arg) const {
|
|
const auto& stride = operands_[arg].stride_bytes;
|
|
for (const auto i : c10::irange(ndim())) {
|
|
if (stride[i] != 0 && shape_[i] != 1) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool TensorIteratorBase::is_cpu_scalar(int64_t arg) const {
|
|
return is_scalar(arg) && device(arg).is_cpu();
|
|
}
|
|
|
|
void TensorIteratorBase::cast_outputs() {
|
|
for (auto& op : operands_) {
|
|
if (op.is_output && op.original_tensor_base().defined() &&
|
|
op.original_tensor_base().scalar_type() != op.current_dtype) {
|
|
// TODO: Now that set_output resizes both the original_tensor
|
|
// and tensor, this condition should no longer ever be true
|
|
const auto &original_tensor = op.original_tensor();
|
|
const auto &tensor = op.tensor();
|
|
if (original_tensor.sizes() != tensor.sizes()) {
|
|
original_tensor.resize_as_(tensor).as_strided_(tensor.sizes(), tensor.strides());
|
|
}
|
|
original_tensor.copy_(tensor);
|
|
op.restore_original_tensor();
|
|
}
|
|
}
|
|
}
|
|
|
|
void* TensorIteratorBase::data_ptr(int64_t arg) const {
|
|
return operands_[arg].data;
|
|
}
|
|
|
|
void TensorIteratorBase::remove_operand(int64_t arg) {
|
|
operands_.erase(operands_.begin() + arg);
|
|
}
|
|
|
|
void TensorIteratorBase::unsafe_replace_operand(int64_t arg, void* data) {
|
|
operands_[arg].data = data;
|
|
}
|
|
|
|
void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) {
|
|
TORCH_INTERNAL_ASSERT(dim < ndim() && size >= 1);
|
|
shape_[dim] = size;
|
|
view_offsets_[dim] += start;
|
|
for (auto& op : operands_) {
|
|
op.data = (static_cast<char*>(op.data)) + op.stride_bytes[dim] * start;
|
|
}
|
|
if (size == 1 && !is_reduction_) {
|
|
coalesce_dimensions();
|
|
}
|
|
}
|
|
|
|
void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indices) {
|
|
TORCH_INTERNAL_ASSERT(start_dim <= ndim());
|
|
for (const auto i : c10::irange(start_dim, ndim())) {
|
|
for (auto& op : operands_) {
|
|
op.data = (static_cast<char*>(op.data)) + op.stride_bytes[i] * indices[i - start_dim];
|
|
}
|
|
shape_[i] = 1;
|
|
}
|
|
}
|
|
|
|
#define BINARY_FLOAT_OP_CONFIG() \
|
|
TensorIteratorConfig() \
|
|
.set_check_mem_overlap(true) \
|
|
.allow_cpu_scalars(true) \
|
|
.promote_inputs_to_common_dtype(true) \
|
|
.cast_common_dtype_to_outputs(true) \
|
|
.enforce_safe_casting_to_output(true) \
|
|
.promote_integer_inputs_to_float(true)
|
|
|
|
// Helper to construct a binary op that promotes integer inputs to float.
|
|
void TensorIteratorBase::build_binary_float_op(
|
|
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
|
build(BINARY_FLOAT_OP_CONFIG()
|
|
.add_owned_output(out)
|
|
.add_owned_const_input(a)
|
|
.add_owned_const_input(b));
|
|
}
|
|
|
|
void TensorIteratorBase::build_borrowing_binary_float_op(
|
|
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
|
build(BINARY_FLOAT_OP_CONFIG()
|
|
.add_output(out)
|
|
.add_const_input(a)
|
|
.add_const_input(b));
|
|
}
|
|
|
|
static void set_up_comparison_op_config(TensorIteratorConfig& config, const TensorBase& out) {
|
|
config.set_check_mem_overlap(true);
|
|
config.allow_cpu_scalars(true);
|
|
config.promote_inputs_to_common_dtype(true);
|
|
|
|
// When 'out' isn't defined (e.g. for the functional operator 'a == b'), we
|
|
// want the output to be bool. Otherwise (e.g. 'torch.eq(a, b, out=c)') we
|
|
// don't coerce the output.
|
|
if (!out.defined()) {
|
|
config.declare_static_dtype(kBool);
|
|
}
|
|
|
|
// Note [special-case bool outputs]
|
|
// We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor
|
|
// has `bool` dtype. This is a performance optimization: the functional
|
|
// version of all comparison/logical ops uses a bool output tensor, and we'd like to
|
|
// avoid creating a temporary copy of the output.
|
|
// However, note that all kernels using this TensorIterator will need to special-case when
|
|
// the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool).
|
|
if (out.defined() && out.scalar_type() != kBool) {
|
|
config.cast_common_dtype_to_outputs(true);
|
|
}
|
|
}
|
|
|
|
void TensorIteratorBase::build_comparison_op(
|
|
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
|
TensorIteratorConfig config;
|
|
set_up_comparison_op_config(config, out);
|
|
|
|
config.add_owned_output(out);
|
|
config.add_owned_const_input(a);
|
|
config.add_owned_const_input(b);
|
|
build(config);
|
|
}
|
|
|
|
void TensorIteratorBase::build_borrowing_comparison_op(
|
|
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
|
TensorIteratorConfig config;
|
|
set_up_comparison_op_config(config, out);
|
|
|
|
config.add_borrowed_output(out);
|
|
config.add_borrowed_const_input(a);
|
|
config.add_borrowed_const_input(b);
|
|
build(config);
|
|
}
|
|
|
|
void TensorIteratorBase::build_borrowing_except_last_argument_comparison_op(
|
|
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
|
TensorIteratorConfig config;
|
|
set_up_comparison_op_config(config, out);
|
|
|
|
config.add_borrowed_output(out);
|
|
config.add_borrowed_const_input(a);
|
|
config.add_owned_const_input(b);
|
|
build(config);
|
|
}
|
|
|
|
void TensorIteratorBase::build_ternary_op(
|
|
const TensorBase& out, const TensorBase& a,
|
|
const TensorBase& b, const TensorBase& c) {
|
|
build(TensorIteratorConfig()
|
|
.promote_inputs_to_common_dtype(true)
|
|
.cast_common_dtype_to_outputs(true)
|
|
.enforce_safe_casting_to_output(true)
|
|
.add_owned_output(out)
|
|
.add_owned_const_input(a)
|
|
.add_owned_const_input(b)
|
|
.add_owned_const_input(c));
|
|
}
|
|
|
|
// This cannot be a function because TensorIteratorConfig is not
|
|
// copyable or movable, so it can't be returned from the function.
|
|
#define BINARY_OP_CONFIG() \
|
|
TensorIteratorConfig() \
|
|
.set_check_mem_overlap(true) \
|
|
.allow_cpu_scalars(true) \
|
|
.promote_inputs_to_common_dtype(true) \
|
|
.cast_common_dtype_to_outputs(true) \
|
|
.enforce_safe_casting_to_output(true) \
|
|
|
|
void TensorIteratorBase::build_binary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
|
build(BINARY_OP_CONFIG()
|
|
.add_owned_output(out)
|
|
.add_owned_const_input(a)
|
|
.add_owned_const_input(b));
|
|
}
|
|
|
|
void TensorIteratorBase::build_borrowing_binary_op(
|
|
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
|
build(BINARY_OP_CONFIG()
|
|
.add_output(out)
|
|
.add_const_input(a)
|
|
.add_const_input(b));
|
|
}
|
|
|
|
// This cannot be a function because TensorIteratorConfig is not
|
|
// copyable or movable, so it can't be returned from the function.
|
|
#define UNARY_FLOAT_OP_CONFIG() \
|
|
TensorIteratorConfig() \
|
|
.set_check_mem_overlap(true) \
|
|
.promote_inputs_to_common_dtype(true) \
|
|
.cast_common_dtype_to_outputs(true) \
|
|
.enforce_safe_casting_to_output(true) \
|
|
.promote_integer_inputs_to_float(true)
|
|
|
|
void TensorIteratorBase::build_unary_float_op(const TensorBase& out, const TensorBase& a) {
|
|
build(UNARY_FLOAT_OP_CONFIG()
|
|
.add_owned_output(out)
|
|
.add_owned_const_input(a));
|
|
}
|
|
|
|
void TensorIteratorBase::build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a) {
|
|
build(UNARY_FLOAT_OP_CONFIG()
|
|
.add_output(out)
|
|
.add_const_input(a));
|
|
}
|
|
|
|
// This cannot be a function because TensorIteratorConfig is not
|
|
// copyable or movable, so it can't be returned from the function.
|
|
#define UNARY_OP_CONFIG() \
|
|
TensorIteratorConfig() \
|
|
.set_check_mem_overlap(true) \
|
|
.cast_common_dtype_to_outputs(false) \
|
|
.enforce_safe_casting_to_output(false) \
|
|
.check_all_same_dtype(true)
|
|
|
|
void TensorIteratorBase::build_unary_op(const TensorBase& out, const TensorBase& a) {
|
|
build(UNARY_OP_CONFIG()
|
|
.add_owned_output(out)
|
|
.add_owned_const_input(a));
|
|
}
|
|
|
|
void TensorIteratorBase::build_borrowing_unary_op(const TensorBase& out, const TensorBase& a) {
|
|
build(UNARY_OP_CONFIG()
|
|
.add_output(out)
|
|
.add_const_input(a));
|
|
}
|
|
|
|
void TensorIteratorBase::build_output_borrowing_argument_owning_unary_op(
|
|
const TensorBase& out,
|
|
const TensorBase& a) {
|
|
build(TensorIteratorConfig()
|
|
.set_check_mem_overlap(true)
|
|
.cast_common_dtype_to_outputs(true)
|
|
.enforce_safe_casting_to_output(true)
|
|
.add_output(out)
|
|
.add_owned_const_input(a));
|
|
}
|
|
|
|
// Helper to construct a unary op that forcibly promotes output to boolean.
|
|
// Only to be used when the output tensor must have boolean type.
|
|
void TensorIteratorBase::build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
|
|
build(TensorIteratorConfig()
|
|
.set_check_mem_overlap(true)
|
|
.check_all_same_dtype(false)
|
|
.declare_static_dtype(at::kBool)
|
|
.declare_static_device(a.device())
|
|
.add_output(out)
|
|
.add_const_input(a));
|
|
}
|
|
|
|
TensorIterator TensorIterator::binary_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
|
TensorIterator iter;
|
|
iter.build_binary_op(out, a, b);
|
|
return iter;
|
|
}
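// Illustrative usage sketch (not part of this file): a typical CPU kernel
// builds the iterator with one of these factories and hands it to a loop
// helper such as at::native::cpu_kernel (ATen/native/cpu/Loops.h), e.g.
// assuming float inputs a and b:
//
//   at::Tensor out = at::empty_like(a);
//   auto iter = at::TensorIterator::binary_op(out, a, b);
//   at::native::cpu_kernel(iter, [](float x, float y) { return x + y; });
//
// The iterator handles broadcasting, dtype/device checks, and output
// allocation or resizing; the lambda only ever sees already-promoted scalars.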
|
|
|
|
TensorIterator TensorIterator::borrowing_binary_op(
|
|
const TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
|
TensorIterator iter;
|
|
iter.build_borrowing_binary_op(out, a, b);
|
|
return iter;
|
|
}
|
|
|
|
TensorIterator TensorIterator::binary_float_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
|
|
TensorIterator iter;
|
|
iter.build_binary_float_op(out, a, b);
|
|
return iter;
|
|
}
|
|
|
|
TensorIterator TensorIterator::comparison_op(TensorBase& out, const TensorBase& a,
|
|
const TensorBase& b) {
|
|
TensorIterator iter;
|
|
iter.build_comparison_op(out, a, b);
|
|
return iter;
|
|
}
|
|
|
|
TensorIterator TensorIterator::unary_op(TensorBase& out, const TensorBase& a) {
|
|
TensorIterator iter;
|
|
iter.build_unary_op(out, a);
|
|
return iter;
|
|
}
|
|
|
|
TensorIterator TensorIterator::unary_float_op(TensorBase& out, const TensorBase& a) {
|
|
TensorIterator iter;
|
|
iter.build_unary_float_op(out, a);
|
|
return iter;
|
|
}
|
|
|
|
#define NULLARY_OP_CONFIG() \
|
|
TensorIteratorConfig() \
|
|
.set_check_mem_overlap(true) \
|
|
.check_all_same_dtype(false) \
|
|
/* FIXME: workaround for bug: https://github.com/pytorch/pytorch/issues/20342 */ \
|
|
.resize_outputs(false)
|
|
|
|
TensorIterator TensorIterator::nullary_op(TensorBase& out) {
|
|
return NULLARY_OP_CONFIG()
|
|
.add_owned_output(out)
|
|
.build();
|
|
}
|
|
|
|
TensorIterator TensorIterator::borrowing_nullary_op(const TensorBase& out) {
|
|
return NULLARY_OP_CONFIG()
|
|
.add_output(out)
|
|
.build();
|
|
}
|
|
|
|
TensorIterator TensorIterator::reduce_op(TensorBase& out, const TensorBase& a) {
|
|
TORCH_INTERNAL_ASSERT(out.defined());
|
|
return TensorIteratorConfig()
|
|
.set_check_mem_overlap(false)
|
|
.add_owned_output(out)
|
|
.add_owned_const_input(a)
|
|
.resize_outputs(false)
|
|
.is_reduction(true)
|
|
// TODO: not supporting casting to outputs is only really necessary for arg{min,max}
|
|
.promote_inputs_to_common_dtype(true)
|
|
.build();
|
|
}
|
|
|
|
TensorIterator TensorIterator::reduce_op(TensorBase& out1, TensorBase& out2, const TensorBase& a) {
|
|
TORCH_INTERNAL_ASSERT(out1.defined());
|
|
TORCH_INTERNAL_ASSERT(out2.defined());
|
|
TORCH_CHECK(a.device() == out1.device() && out1.device() == out2.device(),
|
|
"reduce_op(): expected input and both outputs to be on same device, but input is on ", a.device(),
|
|
", output1 is on ", out1.device(), " and output2 is on ", out2.device());
|
|
TORCH_CHECK(out1.dim() == out2.dim(), "reduce_op(): expected both outputs to have same number of dims, but output1 has ", out1.dim(),
|
|
" and output2 has ", out2.dim());
|
|
TORCH_CHECK(out1.sizes() == out2.sizes(), "reduce_op(): expected both outputs to have same sizes, but output1 has ", out1.sizes(),
|
|
" and output2 has ", out2.sizes());
|
|
TORCH_CHECK(out1.strides() == out2.strides(), "reduce_op(): expected both outputs to have same strides, but output1 has ", out1.strides(),
|
|
" and output2 has ", out2.strides());
|
|
return TensorIteratorConfig()
|
|
.set_check_mem_overlap(false)
|
|
.add_owned_output(out1)
|
|
.add_owned_output(out2)
|
|
.add_owned_const_input(a)
|
|
.resize_outputs(false)
|
|
.is_reduction(true)
|
|
.check_all_same_dtype(false)
|
|
.build();
|
|
}
|
|
|
|
void TensorIteratorBase::populate_operands(TensorIteratorConfig& config) {
|
|
for (const auto idx : c10::irange(config.tensors_.size())) {
|
|
auto& tensor = config.tensors_[idx];
|
|
// If *any* of the arguments is a meta tensor, the overall
|
|
// computation is a meta computation (don't do any work,
|
|
// just compute output information). This aligns with
|
|
// our multiple dispatch semantics.
|
|
if (tensor->is_meta()) {
|
|
is_meta_ = true;
|
|
}
|
|
operands_.emplace_back(std::move(tensor));
|
|
operands_[idx].is_const = config.is_tensor_const(idx);
|
|
}
|
|
num_outputs_ = config.num_outputs_;
|
|
}
|
|
|
|
void TensorIteratorBase::mark_outputs() {
|
|
// TODO: merge this into populate_operands
|
|
for (const auto i : c10::irange(num_outputs_)) {
|
|
operands_[i].is_output = true;
|
|
const auto& output = tensor(i);
|
|
if (!output.defined()) continue;
|
|
|
|
// check if output is also an input
|
|
for (const auto arg : c10::irange(num_outputs_, ntensors())) {
|
|
const auto& input = tensor(arg);
|
|
if (output.is_same(input)) {
|
|
operands_[i].is_read_write = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void TensorIteratorBase::mark_resize_outputs(const TensorIteratorConfig& config) {
|
|
// Outputs cannot be broadcasted. Check that the shape of the outputs matches
|
|
// the inferred shape. There's an exception for write-only tensors to support
|
|
// our legacy behavior that functions with `out=` arguments resize their
|
|
// outputs.
|
|
if (config.static_shape_.has_value()) {
|
|
return;
|
|
}
|
|
for (const auto i : c10::irange(num_outputs_)) {
|
|
const auto& output = tensor(i);
|
|
if (!output.defined()) {
|
|
operands_[i].will_resize = true;
|
|
}
|
|
if (output.defined() && !output.sizes().equals(shape_)) {
|
|
if (config.resize_outputs_ && !operands_[i].is_read_write) {
|
|
operands_[i].will_resize = true;
|
|
continue;
|
|
}
|
|
// for reductions, the output size does not match shape_, as the output has the reduced size while shape_ is the size of the input
|
|
TORCH_CHECK(is_reduction_, "output with shape ", output.sizes(), " doesn't match the broadcast shape ",
|
|
shape_);
|
|
}
|
|
}
|
|
}
|
|
|
|
void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config) {
|
|
if (!config.check_mem_overlap_) {
|
|
return;
|
|
}
|
|
for (const auto i : c10::irange(num_outputs_)) {
|
|
const auto& output = tensor_base(i);
|
|
if (!output.defined()) continue;
|
|
assert_no_internal_overlap(output);
|
|
for (const auto j : c10::irange(num_outputs_, ntensors())) {
|
|
const auto& input = tensor_base(j);
|
|
if (!input.is_same(output)) {
|
|
assert_no_partial_overlap(output, input);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void TensorIteratorBase::compute_shape(const TensorIteratorConfig& config) {
|
|
if (config.static_shape_.has_value()) {
|
|
shape_ = *config.static_shape_;
|
|
return;
|
|
}
|
|
|
|
all_ops_same_shape_ = true;
|
|
bool has_scalars = false;
|
|
bool has_tensors = false;
|
|
for (auto& op : operands_) {
|
|
if (!op.tensor_base().defined()) continue;
|
|
|
|
// For now, don't include output tensors when we're resizing outputs.
|
|
// These shapes don't participate in shape computation.
|
|
// This preserves the legacy behavior where torch.add(..., out=dst) resizes
|
|
// the destination tensor. If the output tensor is also an input, we'll
|
|
// pick it up later in the operands.
|
|
if (config.resize_outputs_ && op.is_output) continue;
|
|
TORCH_CHECK(!op.tensor_base().unsafeGetTensorImpl()->has_symbolic_sizes_strides(),
|
|
"TensorIterator does not support symbolic shapes; please implement this operator in torch/_refs "
|
|
"using the elementwise or reduction helpers (look at backtrace to find out what operator this is)");
|
|
auto shape = op.tensor_base().sizes();
|
|
if (shape.empty()) {
|
|
has_scalars = true;
|
|
} else {
|
|
has_tensors = true;
|
|
}
|
|
if (has_scalars && has_tensors) {
|
|
all_ops_same_shape_ = false;
|
|
}
|
|
if (shape_.empty()) {
|
|
shape_ = shape;
|
|
} else if (!shape.equals(shape_)) {
|
|
all_ops_same_shape_ = false;
|
|
shape_ = infer_size_dimvector(shape_, shape);
|
|
}
|
|
}
|
|
all_ops_are_scalars_ = !has_tensors;
|
|
}
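// Illustrative example (a sketch): shapes are folded in pairwise with
// infer_size_dimvector using standard broadcasting rules. Operands of sizes
// {4, 1, 3} and {5, 1} produce shape_ = {4, 5, 3}; sizes {2, 3} and {4, 3}
// fail inside infer_size with the usual "sizes must match at non-singleton
// dimension" style error.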
|
|
|
|
void TensorIteratorBase::compute_strides(const TensorIteratorConfig& config) {
|
|
for (auto& op : operands_) {
|
|
if (op.tensor_base().defined() && !op.will_resize) {
|
|
IntArrayRef original_shape = config.static_shape_ ? shape_ : op.tensor_base().sizes();
|
|
auto original_stride = op.tensor_base().strides();
|
|
auto element_size_in_bytes = op.tensor_base().element_size();
|
|
auto offset = ndim() - original_shape.size();
|
|
if (offset > 0)
|
|
op.stride_bytes.resize(ndim(), 0);
|
|
else
|
|
op.stride_bytes.resize(ndim());
|
|
for (const auto i : c10::irange(original_shape.size())) {
|
|
// see NOTE: [Computing output strides]
|
|
if (original_shape[i] == 1 && shape_[offset + i] !=1) {
|
|
op.stride_bytes[offset + i] = 0;
|
|
} else {
|
|
op.stride_bytes[offset + i] = original_stride[i] * element_size_in_bytes;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
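// Illustrative example (a sketch): strides are tracked in bytes and broadcast
// dimensions get stride 0. A float operand of size {3} broadcast against
// shape_ = {2, 3} ends up with stride_bytes = {0, 4}: the missing leading
// dimension is padded with 0 and the surviving dimension keeps its element
// stride (1) times sizeof(float).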
|
|
|
|
bool TensorIteratorBase::can_use_32bit_indexing() const {
|
|
int64_t max_value = std::numeric_limits<int32_t>::max();
|
|
if (numel() > max_value) {
|
|
return false;
|
|
}
|
|
for (auto& op : operands_) {
|
|
int64_t max_offset = 1;
|
|
for (const auto dim : c10::irange(ndim())) {
|
|
max_offset += (shape_[dim] - 1) * op.stride_bytes[dim];
|
|
}
|
|
if (max_offset > max_value) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
std::unique_ptr<TensorIterator> TensorIteratorBase::split(int dim) {
|
|
TORCH_INTERNAL_ASSERT(dim >= 0 && dim < ndim() && shape()[dim] >= 2);
|
|
auto copy = std::make_unique<TensorIterator>(*this);
|
|
|
|
bool overlaps = is_dim_reduced(dim);
|
|
auto copy_size = shape_[dim] / 2;
|
|
auto this_size = shape_[dim] - copy_size;
|
|
copy->narrow(dim, 0, copy_size);
|
|
copy->final_output_ &= !overlaps;
|
|
this->narrow(dim, copy_size, this_size);
|
|
this->accumulate_ |= overlaps;
|
|
|
|
return copy;
|
|
}
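// Illustrative usage sketch (not part of this file): CUDA-style kernels that
// need 32-bit indexing keep splitting until every piece fits, typically via
// the with_32bit_indexing() helper declared in TensorIterator.h:
//
//   void launch(at::TensorIteratorBase& iter) {
//     if (!iter.can_use_32bit_indexing()) {
//       for (auto& sub_iter : iter.with_32bit_indexing()) {
//         launch(sub_iter);
//       }
//       return;
//     }
//     // ... enqueue the kernel on `iter` using int32 offsets ...
//   }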
|
|
|
|
|
|
int TensorIteratorBase::get_dim_to_split() const {
|
|
TORCH_INTERNAL_ASSERT(ndim() >= 1);
|
|
int64_t max_extent = -1;
|
|
int dim_to_split = -1;
|
|
for (int dim = ndim() - 1; dim >= 0; dim--) {
|
|
const int64_t size = shape_[dim];
|
|
if (size == 0) {
|
|
continue;
|
|
}
|
|
for (auto& op : operands_) {
|
|
// std::abs is necessary to handle some special cases where we support negative strides
|
|
// see the CUDA backend of at::flip
|
|
const int64_t extent = (size - 1) * std::abs(op.stride_bytes[dim]);
|
|
if (extent > max_extent) {
|
|
max_extent = extent;
|
|
dim_to_split = dim;
|
|
}
|
|
}
|
|
}
|
|
TORCH_INTERNAL_ASSERT(max_extent >= 0);
|
|
return dim_to_split;
|
|
}
|
|
|
|
bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) {
|
|
// This function tries to do a fast setup to avoid needless reordering of dimensions and tracking output strides
|
|
// Returns true if it can do fast setup, false otherwise
|
|
// TODO enable fast handling for reductions
|
|
FastSetupType setup_type = compute_fast_setup_type(config);
|
|
if (setup_type == FastSetupType::NONE) {
|
|
return false;
|
|
}
|
|
|
|
// allocate memory for output, memory format depends on setup_type
|
|
switch (setup_type) {
|
|
case FastSetupType::CONTIGUOUS:
|
|
{
|
|
for (const auto i : c10::irange(num_outputs_)) {
|
|
auto& op = operands_[i];
|
|
if (!op.tensor_base().defined()) {
|
|
TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
|
|
}
|
|
set_output_raw_strided(i, shape_, {}, original_options(op).memory_format(MemoryFormat::Contiguous), names_);
|
|
}
|
|
break;
|
|
}
|
|
case FastSetupType::CHANNELS_LAST:
|
|
{
|
|
for (const auto i : c10::irange(num_outputs_)) {
|
|
auto& op = operands_[i];
|
|
if (!op.tensor_base().defined()) {
|
|
TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
|
|
}
|
|
set_output_raw_strided(i, shape_, {}, original_options(op).memory_format(MemoryFormat::ChannelsLast), names_);
|
|
}
|
|
break;
|
|
}
|
|
case FastSetupType::NON_OVERLAPPING_DENSE:
|
|
{
|
|
// find the index of a defined tensor in operands_, starting from the input tensors
|
|
int i_defined = -1;
|
|
for (i_defined = ntensors() - 1; i_defined >= 0; --i_defined) {
|
|
if (tensor(i_defined).defined()) break;
|
|
}
|
|
TORCH_CHECK(i_defined >= 0, "Can not find a defined tensor when fast allocating memory to outputs");
|
|
for (const auto i : c10::irange(num_outputs_)) {
|
|
auto& op = operands_[i];
|
|
if (!op.tensor_base().defined()) {
|
|
TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
|
|
}
|
|
set_output_raw_strided(i, shape_, tensor_base(i_defined).strides(), original_options(op), names_);
|
|
}
|
|
break;
|
|
}
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type ", std::to_string(static_cast<int>(setup_type)));
|
|
}
|
|
// coalescing dimensions here consists of collapsing all dimensions into one (we are limited to contiguous, no-broadcast cases)
|
|
if (ndim() > 1){
|
|
has_coalesced_dimensions_ = true;
|
|
}
|
|
if (ndim() >= 1) {
|
|
shape_[0] = numel();
|
|
shape_.resize(1);
|
|
}
|
|
for (auto& op : operands_ ) {
|
|
auto element_size_in_bytes = op.tensor_base().element_size();
|
|
op.stride_bytes.resize(ndim());
|
|
if (ndim()>0) {
|
|
op.stride_bytes[0] = element_size_in_bytes;
|
|
}
|
|
}
|
|
return true;
|
|
}
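// Illustrative example (a sketch): for two contiguous float operands of shape
// {2, 3, 4}, compute_fast_setup_type() returns CONTIGUOUS, the output is
// allocated with MemoryFormat::Contiguous, and the iterator degenerates to
// shape_ = {24} with stride_bytes = {4} for every operand, skipping the
// reorder/allocate/coalesce path entirely.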
|
|
|
|
FastSetupType TensorIteratorBase::compute_fast_setup_type(const TensorIteratorConfig& config) {
|
|
if (is_reduction_ || !all_ops_same_shape_) {
|
|
return FastSetupType::NONE;
|
|
}
|
|
|
|
// For linear iteration, only contiguous tensors can be coalesced
|
|
// Fast setup of any other format requires changing iteration order
|
|
if (enforce_linear_iteration_) {
|
|
for (const auto& op : operands_) {
|
|
if (op.tensor_base().defined() && !op.will_resize) {
|
|
auto is_contiguous = op.tensor_base().is_contiguous(at::MemoryFormat::Contiguous);
|
|
if (!is_contiguous) {
|
|
return FastSetupType::NONE;
|
|
}
|
|
}
|
|
}
|
|
return FastSetupType::CONTIGUOUS;
|
|
}
|
|
|
|
bool is_contiguous = true;
|
|
bool is_channels_last = true;
|
|
bool is_non_overlapping_and_dense = true;
|
|
for (const auto& op : operands_) {
|
|
if (op.tensor_base().defined() && !op.will_resize) {
|
|
is_contiguous &= op.tensor_base().is_contiguous(at::MemoryFormat::Contiguous);
|
|
is_channels_last &= op.tensor_base().is_contiguous(at::MemoryFormat::ChannelsLast);
|
|
is_non_overlapping_and_dense &= op.tensor_base().is_non_overlapping_and_dense();
|
|
}
|
|
}
|
|
// TODO: this leads to ambiguous cases (NC11) always being treated as contiguous
|
|
if (is_contiguous) {
|
|
return FastSetupType::CONTIGUOUS;
|
|
}
|
|
if (is_channels_last) {
|
|
return FastSetupType::CHANNELS_LAST;
|
|
}
|
|
if (is_non_overlapping_and_dense) {
|
|
int64_t prev = -1;
|
|
// Fast setup is allowed only when all the defined tensors have the same shape and strides.
// Iterate from the back to check input tensors' strides first, then output tensors'.
|
|
for (int64_t i = ntensors() - 1; i >= 0; --i) {
|
|
const auto& op = operands_[i];
|
|
if (op.tensor_base().defined() && !op.will_resize) {
|
|
if (prev < 0) {
|
|
prev = i;
|
|
continue;
|
|
}
|
|
if (!tensor_base(prev).strides().equals(op.tensor_base().strides())) {
|
|
// [Note: stride check for non contiguous tensors in fast setup]
|
|
// We prevent 3 cases doing fast setup here:
|
|
// 1. input tensors have different strides.
|
|
// 2. output tensors won't be resized and have different strides.
|
|
// 3. input tensors have the same strides, but output tensors have different strides with input tensors.
|
|
// We don't allow re-striding output tensors in this case since it is not compatible with
// numpy. The behavior in numpy is that if the output tensor has the same shape as the input
// tensor but different strides, the strides of the output tensor are preserved, so we do
// the same in TensorIterator.
|
|
return FastSetupType::NONE;
|
|
}
|
|
}
|
|
}
|
|
return FastSetupType::NON_OVERLAPPING_DENSE;
|
|
}
|
|
return FastSetupType::NONE;
|
|
}
|
|
|
|
void TensorIteratorBase::build(TensorIteratorConfig& config) {
|
|
// populate some persistent configuration fields
|
|
is_reduction_ = config.is_reduction_;
|
|
enforce_linear_iteration_ = config.enforce_linear_iteration_;
|
|
|
|
// fill in operands_ based on configuration
|
|
populate_operands(config);
|
|
// set is_output and is_read_write flags on appropriate tensors
|
|
mark_outputs();
|
|
// Check that the outputs have no internal overlap
|
|
// and do not share memory with inputs.
|
|
compute_mem_overlaps(config);
|
|
// Check that input dimensions are aligned correctly & compute outnames.
|
|
compute_names(config);
|
|
// compute the broadcasted shape
|
|
compute_shape(config);
|
|
// mark outputs for resizing if necessary
|
|
mark_resize_outputs(config);
|
|
// compute the result dtype and device
|
|
compute_types(config);
|
|
// try fast setup of the output tensors; if that fails, fall back to the normal setup
|
|
if (!fast_set_up(config)) {
|
|
// compute each tensor's stride after broadcasting
|
|
compute_strides(config);
|
|
// re-order dimensions to improve coalescing
|
|
reorder_dimensions();
|
|
// allocate the output tensor if it's not provided
|
|
allocate_or_resize_outputs();
|
|
// coalesce adjacent dimensions when possible
|
|
if (!is_meta_) coalesce_dimensions();
|
|
}
|
|
|
|
if (is_meta_) return;
|
|
|
|
auto has_storage = true;
|
|
for (auto& op : operands_) {
|
|
has_storage &= op.tensor_base().has_storage();
|
|
}
|
|
auto privateuse1_without_storage =
|
|
common_device_.type() == DeviceType::PrivateUse1 &&
|
|
!has_storage;
|
|
|
|
// XLA and lazy tensors don't have storage, so they don't have an underlying data pointer.
|
|
// Nothing beyond this point is important for meta functions, so it's fine to exit early here.
|
|
// Extend the condition to MAIA tensors as MAIA tensors also don't have storage.
|
|
if (privateuse1_without_storage ||
|
|
common_device_.type() == DeviceType::XLA ||
|
|
common_device_.type() == DeviceType::IPU ||
|
|
common_device_.type() == DeviceType::Lazy ||
|
|
common_device_.type() == DeviceType::MAIA ||
|
|
common_device_.type() == DeviceType::HPU) return;
|
|
|
|
for (auto& op : operands_) {
|
|
TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
|
|
if (op.is_const) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
|
op.data = const_cast<void*>(op.tensor_base().const_data_ptr());
|
|
} else {
|
|
op.data = op.tensor_base().mutable_data_ptr();
|
|
}
|
|
}
|
|
|
|
// zero out offsets
|
|
// If the tensor is a scalar, we leave room for it
|
|
// So index translations in reduction can access
|
|
// a valid value for the offset
|
|
int64_t ndim_offsets = (ndim() ? ndim() : 1);
|
|
view_offsets_ = DimVector(ndim_offsets, 0);
|
|
}
|
|
|
|
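
// A typical way build() gets driven (a sketch of the builder pattern used
// throughout ATen; see the usage examples in TensorIterator.h for the
// authoritative form):
//
//   auto iter = TensorIteratorConfig()
//       .add_output(out)
//       .add_input(a)
//       .add_input(b)
//       .build();
//
// After build() returns, the iterator holds the common dtype/device, the
// broadcast shape, and per-operand strides, ready for for_each and friends.
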
// This is the structured kernels' implementation of set_output. It is
// NEVER actually called directly; instead, a subclass of TensorIteratorBase
// will override set_output to actually do the operation, and then call
// set_output on the TensorIteratorBase to set up TI's metadata.
// The precondition for this function is that maybe_get_output() now
// unconditionally returns a real Tensor (prior to output setting,
// this function may return an undefined tensor.)
void TensorIteratorBase::set_output_raw_strided(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
  auto& op = operands_[output_idx];
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
  const auto& t = maybe_get_output(output_idx);
  TORCH_INTERNAL_ASSERT(t.defined());
  if (!op.tensor_base().defined()) {
    op.tensor(c10::MaybeOwned<TensorBase>::borrowed(t));
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.target_dtype == t.scalar_type());
  } else if (op.will_resize) {
    if (op.original_tensor_base().defined()) {
      // OK, so this is pretty weird. To understand how we can end up in
      // this situation, first look at Marker [Output original_tensor is set].
      // That is the sole site where original_tensor may be set on an
      // output operand. Essentially, when we are given an explicit output
      // tensor whose dtype doesn't match the computed common dtype from
      // the input operands, we do a switcheroo: we replace the (incorrectly
      // typed) output tensor with a correctly typed, *temporary* tensor,
      // and remember the original tensor in original_tensor (which will
      // then get written back to when we cast_outputs).
      //
      // Now, what if the given output tensor also happened to be zero
      // size (meaning that we will_resize it)? Well, at the call site
      // above, we don't necessarily(*) know what the correct shape should
      // be, so we give the temporary tensor the same shape as the original.
      // By the time set_output is called, we DO know what the correct size
      // is, and the structured subclass's implementation of set_output is
      // responsible for resizing original_tensor. But we still have this
      // incorrectly sized temporary output which the structured subclass
      // knows nothing about, so we are obligated to also resize it here.
      //
      // This is a slight memory pessimization, because previously
      // original_tensor only got resized at the end of the computation, rather
      // than at the beginning (as happens here). However, the peak memory
      // usage is the same, since you need to materialize both original tensor
      // and temporary tensor to do the copy.
      //
      // (*) Actually, technically, we probably do know what the shape
      // should be, since we do shape computation before dtype computation.
      // So hypothetically we could figure out what the correct shape is
      // at that point in time and directly allocate the temporary at
      // the right size.
      //
      // But a better solution is to delay allocation of temporaries until
      // after TensorIterator builder, waiting until we actually want
      // to do the computation. That would also remove the necessity
      // for the is_meta_ test.
      TORCH_INTERNAL_ASSERT(op.original_tensor_base().is_same(t));
      TORCH_INTERNAL_ASSERT(!op.tensor_base().is_same(t));
      OptionalTensorRef tensor(op.tensor());
      at::native::resize_output(*tensor, sizes);
      auto const memory_format_opt = options.memory_format_opt();
      if (!strides.empty()) {
        TORCH_INTERNAL_ASSERT(!memory_format_opt.has_value());
        tensor->as_strided_(sizes, strides);
      } else {
        if (memory_format_opt.has_value()) {
          tensor->unsafeGetTensorImpl()->empty_tensor_restride(*memory_format_opt);
        }
      }
    }
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      op.tensor_base().is_same(t) || op.current_dtype == op.tensor_base().scalar_type());
  // For simplicity, just always update the cached current_dtype.
  op.current_dtype = op.tensor_base().scalar_type();
}
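
// A hedged, purely illustrative example of when the resize branch above can
// fire: an out= op such as add(float_a, float_b, out=double_out) where
// double_out starts out zero-sized. The common dtype is float, so a float
// temporary stands in for double_out (the "switcheroo" described above); once
// the broadcast size is known here, both the temporary and the original output
// are resized, and cast_outputs() later copies the result into double_out.
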
// This is the "traditional" implementation of set_output. On TensorIterator
|
|
// instances, it is invoked directly from various call sites in this file. No
|
|
// funny business.
|
|
void TensorIterator::set_output_raw_strided(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
|
|
// NB: intentionally no superclass call
|
|
auto& op = operands_[output_idx];
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
|
|
if (!op.tensor_base().defined()) {
|
|
if (strides.empty()) {
|
|
op.tensor(c10::MaybeOwned<TensorBase>::owned(at::empty(sizes, options)));
|
|
} else {
|
|
op.tensor(c10::MaybeOwned<TensorBase>::owned(at::empty_strided(sizes, strides, options)));
|
|
}
|
|
op.current_dtype = op.target_dtype;
|
|
} else if (op.will_resize) {
|
|
at::native::resize_output(op.tensor(), sizes);
|
|
if (!strides.empty()) {
|
|
TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
|
|
op.tensor().as_strided_(sizes, strides);
|
|
} else {
|
|
auto const memory_format = options.memory_format_opt();
|
|
if (memory_format.has_value()) {
|
|
op.tensor_base().unsafeGetTensorImpl()->empty_tensor_restride(*memory_format);
|
|
}
|
|
}
|
|
}
|
|
if (!names.empty()) {
|
|
TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
|
|
namedinference::propagate_names(op.tensor_base(), names);
|
|
}
|
|
}
|
|
|
|

// Not actually used by anything (TensorIterator subclass calls
// its own implementation of set_output which knows exactly where
// all the outputs are), but we have to provide all pure virtual methods
// for MetaBase
const Tensor& TensorIterator::maybe_get_output(int64_t output_idx) {
  return output(output_idx);
}

SplitUntil32Bit TensorIteratorBase::with_32bit_indexing() const {
  return SplitUntil32Bit(*this);
}
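
// How callers typically consume this (a sketch; the real call sites live in
// the CUDA/accelerator loop implementations, and launch_my_kernel is a
// hypothetical stand-in for whatever the caller runs per sub-iterator):
//
//   if (!iter.can_use_32bit_indexing()) {
//     for (auto& sub_iter : iter.with_32bit_indexing()) {
//       launch_my_kernel(sub_iter);
//     }
//     return;
//   }
//   launch_my_kernel(iter);  // every operand fits in 32-bit index arithmetic
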
/// SplitUntil32Bit. Recursively splits an iterator into sub-iterators that
/// can use 32-bit indexing.

SplitUntil32Bit::iterator::iterator(const TensorIteratorBase& iter) {
  vec.emplace_back(new TensorIterator(iter));
  vec.emplace_back(nullptr); // ++ first pops the last element
  ++(*this);
}

SplitUntil32Bit::iterator& SplitUntil32Bit::iterator::operator++() {
  vec.pop_back();
  while (!vec.empty() && !vec.back()->can_use_32bit_indexing()) {
    auto& iter = *vec.back();
    auto split_dim = iter.get_dim_to_split();
    vec.emplace_back(iter.split(split_dim));
  }
  return *this;
}

TensorIterator& SplitUntil32Bit::iterator::operator*() const {
  return *vec.back();
}

SplitUntil32Bit::iterator SplitUntil32Bit::begin() const {
  return SplitUntil32Bit::iterator(iter);
}

SplitUntil32Bit::iterator SplitUntil32Bit::end() const {
  return SplitUntil32Bit::iterator();
}

DimCounter::DimCounter(IntArrayRef shape, Range range)
  : shape(shape)
  , range(range)
  , values(shape.size())
  , offset(range.begin) {
  std::fill(values.begin(), values.end(), 0);
  if (range.begin == 0) {
    return;
  }

  int64_t linear_offset = range.begin;
  auto ndim = values.size();
  for (const auto dim : c10::irange(ndim)) {
    int64_t size = shape[dim];
    if (size > 0) {
      values[dim] = linear_offset % size;
      linear_offset /= size;
    }
  }
  TORCH_INTERNAL_ASSERT(linear_offset == 0);
}
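
// A worked example of the decomposition above (illustrative only): with
// shape = {3, 4} and range.begin = 5,
//   values[0] = 5 % 3 = 2, then linear_offset = 5 / 3 = 1,
//   values[1] = 1 % 4 = 1, then linear_offset = 1 / 4 = 0,
// so the counter starts at coordinate (2, 1), i.e. linear element 5.
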
bool DimCounter::is_done() const {
  return offset >= range.end;
}

void DimCounter::increment(const std::array<int64_t, 2>& step) {
  offset += step[0] * step[1];
  auto ndim = values.size();
  int64_t overflow = step[0];
  size_t i = 0;
  if (step[1] != 1) {
    TORCH_INTERNAL_ASSERT(step[0] == shape[0] && values[0] == 0);
    i = 1;
    overflow = step[1];
  }
  for (; i < ndim && overflow > 0; i++) {
    auto size = shape[i];
    auto prev = values[i];
    auto value = prev + overflow;
    if (value >= size) {
      overflow = 1;
      value -= size;
      TORCH_INTERNAL_ASSERT(value < size);
    } else {
      overflow = 0;
    }
    values[i] = static_cast<int64_t>(value);
  }
  TORCH_INTERNAL_ASSERT(overflow == 0 || overflow == 1);
}

std::array<int64_t, 2> DimCounter::max_2d_step() const {
  int64_t step0 = std::min(shape[0] - values[0], range.end - offset);
  int64_t step1 = 1;
  if (step0 == shape[0] && !shape.empty()) {
    step1 = std::min(shape[1] - values[1], (range.end - offset) / shape[0]);
  }
  return {step0, step1};
}
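
// Illustrative example (informal): with shape = {3, 4}, values = {0, 1} and
// range.end - offset == 10, step0 = min(3 - 0, 10) = 3 finishes dimension 0,
// and since step0 == shape[0], step1 = min(4 - 1, 10 / 3) = 3, so the next
// pass of the 2-D loop covers a 3 x 3 block of elements.
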
} // namespace at