Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-03 07:24:58 +08:00)
Revert D23753711: [pytorch][PR] Add foreach APIs for binary ops with ScalarList

Test Plan: revert-hammer

Differential Revision: D23753711 (71d1b5b0e2)

Original commit changeset: bf3e8c54bc07

fbshipit-source-id: 192692e0d3fff4cade9983db0a1760fedfc9674c

This commit is contained in:
Committed by: Facebook GitHub Bot
Parent: c79d493096
Commit: 26001a2334
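
The reverted PR added per-tensor-scalar variants of the foreach binary ops. As a rough usage sketch of the _foreach_*.ScalarList(Tensor[] tensors, float[] scalars) overloads that the hunks below remove (illustrative only, not part of this commit):

    import torch

    tensors = [torch.ones(2, 2) for _ in range(3)]
    scalars = [1.0, 2.0, 3.0]                     # one scalar per tensor

    out = torch._foreach_add(tensors, scalars)    # out-of-place: returns a new list of tensors
    torch._foreach_add_(tensors, scalars)         # in-place variant mutates the input tensors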
@@ -24,26 +24,6 @@ std::vector<Tensor> foreach_tensor_##NAME##_scalar_kernel_slow(TensorList tensors
return result; \
}

#define FOREACH_BINARY_OP_SCALARLIST(NAME) \
void foreach_tensor_##NAME##_scalarlist_kernel_slow_(TensorList tensors, at::ArrayRef<double> scalars) { \
check_foreach_api_restrictions(tensors, scalars); \
\
for (int i = 0; i < tensors.size(); i++) { \
tensors[i].NAME##_(scalars[i]); \
} \
} \
\
std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_slow(TensorList tensors, at::ArrayRef<double> scalars) { \
check_foreach_api_restrictions(tensors, scalars); \
std::vector<Tensor> result; \
result.reserve(tensors.size()); \
for (int i = 0; i < tensors.size(); i++) { \
result.emplace_back(tensors[i].NAME(scalars[i])); \
} \
\
return result; \
}

#define FOREACH_BINARY_OP_LIST(NAME) \
std::vector<Tensor> foreach_tensor_##NAME##_list_kernel_slow(TensorList tensors1, TensorList tensors2) { \
check_foreach_api_restrictions(tensors1, tensors2); \

@@ -137,10 +117,6 @@ FOREACH_BINARY_OP_SCALAR(add);
FOREACH_BINARY_OP_SCALAR(sub);
FOREACH_BINARY_OP_SCALAR(mul);
FOREACH_BINARY_OP_SCALAR(div);
FOREACH_BINARY_OP_SCALARLIST(add);
FOREACH_BINARY_OP_SCALARLIST(sub);
FOREACH_BINARY_OP_SCALARLIST(mul);
FOREACH_BINARY_OP_SCALARLIST(div);
FOREACH_BINARY_OP_LIST(mul);
FOREACH_BINARY_OP_LIST(div);
FOREACH_UNARY_OP(sqrt);
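
The FOREACH_BINARY_OP_SCALARLIST macro above is the CPU "slow" path being removed: a plain loop that applies one scalar to each tensor. A rough Python equivalent of that loop, shown for the add instantiation (a sketch for illustration, not code from this commit):

    def foreach_scalarlist_slow(tensors, scalars):
        # mirrors result.emplace_back(tensors[i].add(scalars[i]))
        return [t.add(s) for t, s in zip(tensors, scalars)]

    def foreach_scalarlist_slow_(tensors, scalars):
        # in-place variant, mirrors tensors[i].add_(scalars[i])
        for t, s in zip(tensors, scalars):
            t.add_(s)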
@@ -31,12 +31,6 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) {
}
}

void check_foreach_api_restrictions(TensorList tensors, ArrayRef<double> scalars) {
TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
TORCH_CHECK(scalars.size() > 0, "Scalars list must have at least one value.");
TORCH_CHECK(tensors.size() == scalars.size(), "Tensor list must have same number of elements as scalar list.");
}

// To go via 'fast' path, several conditions must be satisfied
// - All tensors must be on the same device
// - All tensors must have strided layout

@@ -138,13 +132,5 @@ bool can_use_fast_route(TensorList tensors) {
return true;
}

bool can_use_fast_route(TensorList tensors, ArrayRef<double> scalars) {
TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
TORCH_CHECK(scalars.size() > 0, "Scalars list must have at least one value.");
TORCH_CHECK(tensors.size() == scalars.size(), "Tensor list must have same number of elements as scalar list.");

return can_use_fast_route(tensors);
}

}
}} // at::native
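
check_foreach_api_restrictions and can_use_fast_route above are the ScalarList validation helpers being removed. In Python terms the list checks amount to (a sketch, not code from this commit):

    def check_foreach_api_restrictions(tensors, scalars):
        assert len(tensors) > 0, "Tensor list must have at least one tensor."
        assert len(scalars) > 0, "Scalars list must have at least one value."
        assert len(tensors) == len(scalars), \
            "Tensor list must have same number of elements as scalar list."

can_use_fast_route then additionally requires the fast-path conditions listed in the comment (same device, strided layout, and so on) before the fused kernel is used.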
@@ -1,60 +0,0 @@
#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>

namespace at { namespace native {

template<template<class> class Op>
std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<double> scalars) {
std::vector<std::vector<at::Tensor>> tensor_lists;
std::vector<at::Tensor> vec_res;
for (const auto& t: tensors) {
vec_res.emplace_back(at::native::empty_like(t));
}

tensor_lists.emplace_back(tensors.vec());
tensor_lists.emplace_back(vec_res);

AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda", [&]() {
multi_tensor_apply<2>(tensor_lists, scalars, BinaryOpScalarListFunctor<scalar_t, Op>());
});
return tensor_lists[1];
}

template<template<class> class Op>
void foreach_binary_op_(TensorList tensors, at::ArrayRef<double> scalars) {
std::vector<std::vector<at::Tensor>> tensor_lists;
tensor_lists.emplace_back(tensors.vec());

AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda_", [&]() {
multi_tensor_apply<1>(tensor_lists, scalars, BinaryOpScalarListFunctor_<scalar_t, Op>());
});
}

#define FOREACH_BINARY_OP_SCALARLIST(NAME, OP) \
void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<double> scalars) { \
check_foreach_api_restrictions(tensors); \
\
if (!can_use_fast_route(tensors, scalars)) { \
return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow_(tensors, scalars); \
} \
\
foreach_binary_op_<OP>(tensors, scalars); \
} \
\
std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<double> scalars) { \
check_foreach_api_restrictions(tensors); \
\
if (!can_use_fast_route(tensors, scalars)) { \
return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow(tensors, scalars); \
} \
\
return foreach_binary_op<OP>(tensors, scalars); \
}

FOREACH_BINARY_OP_SCALARLIST(add, std::plus);
FOREACH_BINARY_OP_SCALARLIST(sub, std::minus);
FOREACH_BINARY_OP_SCALARLIST(mul, std::multiplies);
FOREACH_BINARY_OP_SCALARLIST(div, std::divides);

}} // namespace at::native
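
The CUDA entry points above simply route between the fused kernel and the per-tensor fallback. The control flow in Python pseudocode (a sketch; the helper names are stand-ins, not real functions):

    def foreach_add_scalarlist_cuda(tensors, scalars):
        check_foreach_api_restrictions(tensors)
        if not can_use_fast_route(tensors, scalars):
            # slow fallback: one op per tensor
            return [t.add(s) for t, s in zip(tensors, scalars)]
        # fast path: a single fused launch over all tensors via multi_tensor_apply
        return fused_multi_tensor_add(tensors, scalars)  # hypothetical stand-in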
@@ -118,121 +118,6 @@ struct BinaryOpScalarFunctor {
}
};

template<typename T, template<class> class Op>
struct BinaryOpScalarListFunctor_ {
__device__ void operator() (
int chunk_size,
TensorListScalarListMetadata<1>& tl) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

T* x = (T*)tl.addresses[0][tensor_loc];
x += chunk_idx * chunk_size;

double y = tl.scalar_vals[tensor_loc];

n -= chunk_idx * chunk_size;

T r_x[kILP];

// to make things simple, we put aligned case in a different code path
if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x)) {
for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
// load
load_store(r_x, x, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = Op<T>()(static_cast<T>(r_x[ii]), y);
}
// store
load_store(x, r_x, i_start, 0);
}
}
else {
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if(i < n && i < chunk_size) {
r_x[ii] = x[i];
}
}
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = Op<T>()(static_cast<T>(r_x[ii]), y);
}
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if(i < n && i < chunk_size)
x[i] = r_x[ii];
}
}
}
}
};

template<typename T, template<class> class Op>
struct BinaryOpScalarListFunctor {
__device__ void operator() (
int chunk_size,
TensorListScalarListMetadata<2>& tl) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

T* x = (T*)tl.addresses[0][tensor_loc];
x += chunk_idx * chunk_size;

T* out = (T*)tl.addresses[1][tensor_loc];
out += chunk_idx * chunk_size;

double y = tl.scalar_vals[tensor_loc];

n -= chunk_idx * chunk_size;

T r_x[kILP];

// to make things simple, we put aligned case in a different code path
if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(out)) {
for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
// load
load_store(r_x, x, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = Op<T>()(static_cast<T>(r_x[ii]), y);
}
// store
load_store(out, r_x, i_start, 0);
}
}
else {
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if(i < n && i < chunk_size) {
r_x[ii] = x[i];
}
}
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = Op<T>()(static_cast<T>(r_x[ii]), y);
}
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if(i < n && i < chunk_size)
out[i] = r_x[ii];
}
}
}
}
};

template<typename T, template<class> class Op>
struct BinaryOpListAlphaFunctor_ {
__device__ void operator() (
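
Each functor above runs once per CUDA block; block_to_tensor and block_to_chunk tell the block which tensor and which chunk of it to process, and every loop is clamped by both the remaining element count and the chunk size. The chunk arithmetic as a small Python sketch (the chunk size value here is purely illustrative; the real kChunkSize lives in the multi-tensor-apply header):

    def chunk_bounds(numel, chunk_idx, chunk_size=65536):
        start = chunk_idx * chunk_size          # offset of this chunk into the tensor
        remaining = numel - start               # matches n -= chunk_idx * chunk_size above
        return start, min(remaining, chunk_size)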
@@ -26,7 +26,6 @@ __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int s
// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};

template<int n> struct TensorListMetadata
{

@@ -36,15 +35,6 @@ template<int n> struct TensorListMetadata
int block_to_chunk[depth_to_max_blocks[n-1]];
};

template<int n> struct TensorListScalarListMetadata
{
void* addresses[n][depth_to_max_tensors_scalarlist[n-1]];
int sizes[depth_to_max_tensors_scalarlist[n-1]];
double scalar_vals[depth_to_max_tensors_scalarlist[n-1]];
unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
int block_to_chunk[depth_to_max_blocks[n-1]];
};

template<typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ void

@@ -59,71 +49,11 @@ multi_tensor_apply_kernel(
template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
std::vector<std::vector<at::Tensor>>& tensor_lists,
at::ArrayRef<double> scalars,
T callable,
ArgTypes... args) {
TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth.");
const cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
size_t n_tensors = tensor_lists[0].size();
TensorListScalarListMetadata<depth> tensorListMeta;

int loc_block_info = 0;
int loc_tensor_info = 0;
for(size_t t = 0; t < n_tensors; t++) {

tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t];

tensorListMeta.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++) {
tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
}
loc_tensor_info++;

int chunks = (tensor_lists[0][t].numel() + kChunkSize - 1)/kChunkSize;
for (int chunk = 0; chunk < chunks; chunk++) {
tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tensorListMeta.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;

bool tensors_full = (loc_tensor_info == depth_to_max_tensors_scalarlist[depth-1] &&
chunk == chunks - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
bool last_chunk = (t == n_tensors - 1 && chunk == chunks - 1);

if (tensors_full || blocks_full || last_chunk) {
multi_tensor_apply_kernel<<<loc_block_info, kBlockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
tensorListMeta,
callable,
args...);

AT_CUDA_CHECK(cudaGetLastError());

// Reset.
loc_block_info = 0;
if(chunk == chunks - 1) {
loc_tensor_info = 0;
}
else {
tensorListMeta.sizes[0] = tensorListMeta.sizes[loc_tensor_info-1];
tensorListMeta.scalar_vals[0] = tensorListMeta.scalar_vals[loc_tensor_info-1];
for(int d = 0; d < depth; d++) {
tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1];
}
loc_tensor_info = 1;
}
}
}
}
}

template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args) {
TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth.");
const cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
size_t n_tensors = tensor_lists[0].size();
TensorListMetadata<depth> tensorListMeta;
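
The scalar-list overload of multi_tensor_apply above packs tensor pointers, sizes, and per-tensor scalars into a fixed-size metadata struct (kernel arguments must stay under 4KB, hence the depth_to_max_* limits) and launches the kernel whenever the tensor slots or block slots fill up. A compact Python rendering of that packing loop (a sketch; launch stands in for the CUDA kernel launch):

    def multi_tensor_apply_sketch(tensors, scalars, max_tensors, max_blocks, chunk_size, launch):
        meta, blocks = [], []                       # mirror loc_tensor_info / loc_block_info
        for t_idx, (t, s) in enumerate(zip(tensors, scalars)):
            meta.append((t.data_ptr(), t.numel(), s))
            chunks = (t.numel() + chunk_size - 1) // chunk_size
            for chunk in range(chunks):
                blocks.append((len(meta) - 1, chunk))
                tensors_full = len(meta) == max_tensors and chunk == chunks - 1
                blocks_full = len(blocks) == max_blocks
                last_chunk = t_idx == len(tensors) - 1 and chunk == chunks - 1
                if tensors_full or blocks_full or last_chunk:
                    launch(meta, blocks)            # one kernel launch per filled batch
                    blocks = []
                    # keep the current tensor if it still has chunks left
                    meta = [] if chunk == chunks - 1 else [meta[-1]]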
@@ -6187,7 +6187,6 @@
CUDA: foreach_tensor_add_scalar_kernel_cuda

- func: _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6195,7 +6194,6 @@
CUDA: foreach_tensor_add_scalar_kernel_cuda_

- func: _foreach_sub.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6203,7 +6201,6 @@
CUDA: foreach_tensor_sub_scalar_kernel_cuda

- func: _foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6211,7 +6208,6 @@
CUDA: foreach_tensor_sub_scalar_kernel_cuda_

- func: _foreach_mul.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6219,7 +6215,6 @@
CUDA: foreach_tensor_mul_scalar_kernel_cuda

- func: _foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6227,7 +6222,6 @@
CUDA: foreach_tensor_mul_scalar_kernel_cuda_

- func: _foreach_div.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6235,39 +6229,34 @@
CUDA: foreach_tensor_div_scalar_kernel_cuda

- func: _foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_div_scalar_kernel_slow_
CUDA: foreach_tensor_div_scalar_kernel_cuda_

- func: _foreach_add.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[]
use_c10_dispatcher: full
- func: _foreach_add.List(Tensor[] tensors1, Tensor[] tensors2, Scalar alpha=1) -> Tensor[]
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_add_list_kernel_slow
CUDA: foreach_tensor_add_list_kernel_cuda

- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()
use_c10_dispatcher: full
- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, Scalar alpha=1) -> ()
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_add_list_kernel_slow_
CUDA: foreach_tensor_add_list_kernel_cuda_

- func: _foreach_sub.List(Tensor[] tensors1, Tensor[] tensors2, *, Scalar alpha=1) -> Tensor[]
use_c10_dispatcher: full
- func: _foreach_sub.List(Tensor[] tensors1, Tensor[] tensors2, Scalar alpha=1) -> Tensor[]
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_sub_list_kernel_slow
CUDA: foreach_tensor_sub_list_kernel_cuda

- func: _foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()
use_c10_dispatcher: full
- func: _foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, Scalar alpha=1) -> ()
device_guard: False
variants: function
dispatch:

@@ -6275,7 +6264,6 @@
CUDA: foreach_tensor_sub_list_kernel_cuda_

- func: _foreach_mul.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6283,15 +6271,13 @@
CUDA: foreach_tensor_mul_list_kernel_cuda

- func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_mul_list_kernel_slow_
CUDA: foreach_tensor_mul_list_kernel_cuda_

- func: _foreach_div.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
use_c10_dispatcher: full
- func: _foreach_div.List(Tensor(a!)[] self, Tensor[] other) -> Tensor[]
device_guard: False
variants: function
dispatch:

@@ -6299,79 +6285,13 @@
CUDA: foreach_tensor_div_list_kernel_cuda

- func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_div_list_kernel_slow_
CUDA: foreach_tensor_div_list_kernel_cuda_

- func: _foreach_add.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_add_scalarlist_kernel_slow
CUDA: foreach_tensor_add_scalarlist_kernel_cuda

- func: _foreach_add_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_add_scalarlist_kernel_slow_
CUDA: foreach_tensor_add_scalarlist_kernel_cuda_

- func: _foreach_sub.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_sub_scalarlist_kernel_slow
CUDA: foreach_tensor_sub_scalarlist_kernel_cuda

- func: _foreach_sub_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_sub_scalarlist_kernel_slow_
CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_

- func: _foreach_div.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_div_scalarlist_kernel_slow
CUDA: foreach_tensor_div_scalarlist_kernel_cuda

- func: _foreach_div_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_div_scalarlist_kernel_slow_
CUDA: foreach_tensor_div_scalarlist_kernel_cuda_

- func: _foreach_mul.ScalarList(Tensor[] tensors, float[] scalars) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_mul_scalarlist_kernel_slow
CUDA: foreach_tensor_mul_scalarlist_kernel_cuda

- func: _foreach_mul_.ScalarList(Tensor(a!)[] self, float[] scalars) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_mul_scalarlist_kernel_slow_
CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_

- func: _foreach_exp(Tensor[] tensors) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6379,7 +6299,6 @@
CUDA: foreach_tensor_exp_cuda

- func: _foreach_exp_(Tensor(a!)[] self) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6387,7 +6306,6 @@
CUDA: foreach_tensor_exp_cuda_

- func: _foreach_sqrt(Tensor[] tensors) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6395,7 +6313,6 @@
CUDA: foreach_tensor_sqrt_cuda

- func: _foreach_sqrt_(Tensor(a!)[] self) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6403,7 +6320,6 @@
CUDA: foreach_tensor_sqrt_cuda_

- func: _foreach_addcdiv_(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6411,7 +6327,6 @@
CUDA: foreach_tensor_addcdiv_cuda_

- func: _foreach_addcmul_(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6419,7 +6334,6 @@
CUDA: foreach_tensor_addcmul_cuda_

- func: _foreach_addcdiv(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:

@@ -6427,7 +6341,6 @@
CUDA: foreach_tensor_addcdiv_cuda

- func: _foreach_addcmul(Tensor[] input, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
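
The native_functions.yaml hunks above show the schema split this revert undoes: the surviving .Scalar overloads apply one scalar to every tensor, while the removed .ScalarList overloads took a float[] with one value per tensor. At the Python level the difference is only the second argument (illustrative):

    torch._foreach_mul(tensors, 2.0)               # .Scalar: same scalar for every tensor
    torch._foreach_mul(tensors, [2.0, 3.0, 4.0])   # .ScalarList: per-tensor scalars (removed here)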
@@ -99,10 +99,6 @@ allow_list = [
("preprocess", datetime.date(2020, 10, 1)),
("compile", datetime.date(2020, 10, 1)),
("execute", datetime.date(2020, 10, 1)),
("aten::_foreach_add", datetime.date(2020, 10, 1)),
("aten::_foreach_sub_", datetime.date(2020, 10, 1)),
("aten::_foreach_div", datetime.date(2020, 10, 1)),
("aten::_foreach_sub", datetime.date(2020, 10, 1)),
]
@@ -4,30 +4,21 @@ from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, skipCUDAIfRocm

class TestForeach(TestCase):
foreach_bin_ops = [
bin_ops = [
torch._foreach_add,
torch._foreach_sub,
torch._foreach_mul,
torch._foreach_div,
]

foreach_bin_ops_ = [
torch._foreach_add_,
torch._foreach_sub,
torch._foreach_sub_,
torch._foreach_mul,
torch._foreach_mul_,
torch._foreach_div,
torch._foreach_div_,
]

torch_bin_ops = [
torch.add,
torch.sub,
torch.mul,
torch.div,
]

def _get_test_data(self, device, dtype, N):
if dtype in [torch.bfloat16, torch.bool, torch.float16]:
tensors = [torch.randn(N, N, device=device).to(dtype) for _ in range(N)]

elif dtype in torch.testing.get_all_int_dtypes():
tensors = [torch.randint(1, 100, (N, N), device=device, dtype=dtype) for _ in range(N)]
else:

@@ -35,8 +26,7 @@ class TestForeach(TestCase):

return tensors

def _test_bin_op_list(self, device, dtype, foreach_op, foreach_op_, torch_op):
for N in [30, 300]:
def _test_bin_op_list(self, device, dtype, foreach_op, foreach_op_, torch_op, N=20):
tensors1 = self._get_test_data(device, dtype, N)
tensors2 = self._get_test_data(device, dtype, N)

@@ -44,10 +34,9 @@ class TestForeach(TestCase):
res = foreach_op(tensors1, tensors2)
foreach_op_(tensors1, tensors2)
self.assertEqual(res, tensors1)
self.assertEqual(tensors1, res)
self.assertEqual(tensors1, expected)

def _test_unary_op(self, device, dtype, foreach_op, foreach_op_, torch_op):
for N in [30, 300]:
def _test_unary_op(self, device, dtype, foreach_op, foreach_op_, torch_op, N=20):
tensors1 = self._get_test_data(device, dtype, N)
expected = [torch_op(tensors1[i]) for i in range(N)]
res = foreach_op(tensors1)

@@ -55,8 +44,7 @@ class TestForeach(TestCase):
self.assertEqual(res, tensors1)
self.assertEqual(tensors1, expected)

def _test_pointwise_op(self, device, dtype, foreach_op, foreach_op_, torch_op):
for N in [30, 300]:
def _test_pointwise_op(self, device, dtype, foreach_op, foreach_op_, torch_op, N=20):
tensors = self._get_test_data(device, dtype, N)
tensors1 = self._get_test_data(device, dtype, N)
tensors2 = self._get_test_data(device, dtype, N)

@@ -75,8 +63,8 @@ class TestForeach(TestCase):
alpha = 2

expected = [torch_op(tensors1[i], torch.mul(tensors2[i], alpha)) for i in range(N)]
res = foreach_op(tensors1, tensors2, alpha=alpha)
foreach_op_(tensors1, tensors2, alpha=alpha)
res = foreach_op(tensors1, tensors2, alpha)
foreach_op_(tensors1, tensors2, alpha)
self.assertEqual(res, tensors1)

if dtype == torch.bool:

@@ -100,7 +88,7 @@ class TestForeach(TestCase):
@skipCUDAIfRocm
@dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False))
def test_addcmul(self, device, dtype):
if self.device_type == 'cpu':
if device == 'cpu':
if dtype == torch.half:
with self.assertRaisesRegex(RuntimeError, r"\"addcmul_cpu_out\" not implemented for \'Half\'"):
self._test_pointwise_op(device, dtype, torch._foreach_addcmul,

@@ -117,7 +105,7 @@ class TestForeach(TestCase):
self._test_pointwise_op(device, dtype, torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv)
return

if self.device_type == 'cpu':
if device == 'cpu':
if dtype == torch.half:
with self.assertRaisesRegex(RuntimeError, r"\"addcdiv_cpu_out\" not implemented for \'Half\'"):
self._test_pointwise_op(device, dtype, torch._foreach_addcdiv,

@@ -130,371 +118,82 @@ class TestForeach(TestCase):
#
@dtypes(*torch.testing.get_all_dtypes())
def test_int_scalar(self, device, dtype):
for N in [30, 300]:
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
self.foreach_bin_ops_,
self.torch_bin_ops):
tensors = self._get_test_data(device, dtype, N)
scalar = 3
expected = [torch_bin_op(t, scalar) for t in tensors]

res = foreach_bin_op(tensors, scalar)
tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
int_scalar = 1

# bool tensor + 1 will result in int64 tensor
if dtype == torch.bool:
expected = [torch.ones(10, 10, device=device, dtype=torch.int64) for _ in range(10)]
else:
expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)]

res = torch._foreach_add(tensors, int_scalar)
self.assertEqual(res, expected)

with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
foreach_bin_op_(tensors, scalar)
return

if foreach_bin_op_ == torch._foreach_div_ and dtype in torch.testing.integral_types() and self.device_type == "cpu":
if dtype in [torch.bool]:
with self.assertRaisesRegex(RuntimeError,
"can't be cast to the desired output type"):
foreach_bin_op_(tensors, scalar)
return

# TODO[type promotion]: Fix once type promotion is enabled.
if dtype in torch.testing.integral_types() and self.device_type == 'cuda':
self.assertEqual(res, [e.to(dtype) for e in expected])

foreach_bin_op_(tensors, scalar)
self.assertEqual(tensors, [e.to(dtype) for e in expected])
"result type Long can't be cast to the desired output type Bool"):
torch._foreach_add_(tensors, int_scalar)
else:
self.assertEqual(res, expected)
foreach_bin_op_(tensors, scalar)
self.assertEqual(tensors, expected)

# TODO[Fix scalar list]:
# We need to update codegen to correctly handle function overloads with float[] and int[].
# As optimizers work with float tensors, the result will always be torch.float32 for now.
# Current schema is using 'float[]' as scalar list type.
@dtypes(*torch.testing.get_all_dtypes())
def test_int_scalarlist(self, device, dtype):
for N in [30, 300]:
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
self.foreach_bin_ops_,
self.torch_bin_ops):
tensors = self._get_test_data(device, dtype, N)
scalars = [1 for _ in range(N)]
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]

# we dont support bool and complex types on CUDA for now
if (dtype in torch.testing.get_all_complex_dtypes() or dtype == torch.bool) and self.device_type == 'cuda':
with self.assertRaisesRegex(RuntimeError, "not implemented for"):
foreach_bin_op_(tensors, scalars)

with self.assertRaisesRegex(RuntimeError, "not implemented for"):
foreach_bin_op(tensors, scalars)
return

res = foreach_bin_op(tensors, scalars)

if dtype == torch.bool:
self.assertEqual(res, [torch_bin_op(t.to(torch.float32), s) for t, s in zip(tensors, scalars)])

with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
foreach_bin_op_(tensors, scalars)
return

if dtype in torch.testing.integral_types():
if self.device_type == 'cpu':
self.assertEqual(res, [e.to(torch.float32) for e in expected])
else:
# TODO[type promotion]: Fix once type promotion is enabled.
self.assertEqual(res, [e.to(dtype) for e in expected])
else:
self.assertEqual(res, expected)

if dtype in torch.testing.integral_types() and self.device_type == 'cpu':
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
foreach_bin_op_(tensors, scalars)
return
else:
foreach_bin_op_(tensors, scalars)
torch._foreach_add_(tensors, int_scalar)
self.assertEqual(res, tensors)

@dtypes(*torch.testing.get_all_dtypes())
def test_float_scalar(self, device, dtype):
for N in [30, 300]:
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
self.foreach_bin_ops_,
self.torch_bin_ops):
tensors = self._get_test_data(device, dtype, N)
scalar = 3.3
expected = [torch_bin_op(t, scalar) for t in tensors]
tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
float_scalar = 1.

if dtype == torch.bool:
if foreach_bin_op == torch._foreach_sub:
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
foreach_bin_op_(tensors, scalar)

with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
foreach_bin_op(tensors, scalar)
return

res = foreach_bin_op(tensors, scalar)
self.assertEqual(res, expected)

if dtype in torch.testing.integral_types():
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
foreach_bin_op_(tensors, scalar)
return

foreach_bin_op_(tensors, scalar)
self.assertEqual(tensors, expected)

@dtypes(*torch.testing.get_all_dtypes())
def test_float_scalarlist(self, device, dtype):
for N in [30, 300]:
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
self.foreach_bin_ops_,
self.torch_bin_ops):
tensors = self._get_test_data(device, dtype, N)
scalars = [1.1 for _ in range(N)]
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]

# we dont support bool and complex types on CUDA for now
if (dtype in torch.testing.get_all_complex_dtypes() or dtype == torch.bool) and self.device_type == 'cuda':
with self.assertRaisesRegex(RuntimeError, "not implemented for"):
foreach_bin_op_(tensors, scalars)

with self.assertRaisesRegex(RuntimeError, "not implemented for"):
foreach_bin_op(tensors, scalars)
return

res = foreach_bin_op(tensors, scalars)

if dtype == torch.bool:
# see TODO[Fix scalar list]
self.assertEqual(res, [torch_bin_op(t.to(torch.float32), s) for t, s in zip(tensors, scalars)])

with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
foreach_bin_op_(tensors, scalars)
return

if dtype in torch.testing.integral_types() and self.device_type == 'cuda':
# see TODO[Fix scalar list]
self.assertEqual(res, [e.to(dtype) for e in expected])

foreach_bin_op_(tensors, scalars)
self.assertEqual(tensors, res)
return
# float scalar + integral tensor will result in float tensor
if dtype in [torch.uint8, torch.int8, torch.int16,
torch.int32, torch.int64, torch.bool]:
expected = [torch.ones(10, 10, device=device, dtype=torch.float32) for _ in range(10)]
else:
expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)]

res = torch._foreach_add(tensors, float_scalar)
self.assertEqual(res, expected)

if dtype in torch.testing.integral_types() and self.device_type == "cpu":
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
foreach_bin_op_(tensors, scalars)
return

foreach_bin_op_(tensors, scalars)
self.assertEqual(tensors, expected)
if dtype in [torch.uint8, torch.int8, torch.int16,
torch.int32, torch.int64, torch.bool]:
self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, float_scalar))
else:
torch._foreach_add_(tensors, float_scalar)
self.assertEqual(res, tensors)

@dtypes(*torch.testing.get_all_dtypes())
def test_complex_scalar(self, device, dtype):
for N in [30, 300]:
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
self.foreach_bin_ops_,
self.torch_bin_ops):
tensors = self._get_test_data(device, dtype, N)
scalar = 3 + 5j
expected = [torch_bin_op(t, scalar) for t in tensors]
tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
complex_scalar = 3 + 5j

if dtype == torch.bool:
if foreach_bin_op == torch._foreach_sub:
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
foreach_bin_op_(tensors, scalar)
# bool tensor + 1 will result in int64 tensor
expected = [torch.add(complex_scalar, torch.zeros(10, 10, device=device, dtype=dtype)) for _ in range(10)]

with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
foreach_bin_op(tensors, scalar)
if dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16] and device == 'cuda:0':
# value cannot be converted to dtype without overflow:
self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, complex_scalar))
self.assertRaises(RuntimeError, lambda: torch._foreach_add(tensors, complex_scalar))
return

if dtype in torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=True) and \
self.device_type == 'cuda':
with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"):
foreach_bin_op_(tensors, scalar)

with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"):
foreach_bin_op(tensors, scalar)
return

res = foreach_bin_op(tensors, scalar)
res = torch._foreach_add(tensors, complex_scalar)
self.assertEqual(res, expected)

if dtype not in [torch.complex64, torch.complex128]:
with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
foreach_bin_op_(tensors, scalar)
self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, complex_scalar))
else:
foreach_bin_op_(tensors, scalar)
torch._foreach_add_(tensors, complex_scalar)
self.assertEqual(res, tensors)

@dtypes(*torch.testing.get_all_dtypes())
def test_complex_scalarlist(self, device, dtype):
for N in [30, 300]:
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
self.foreach_bin_ops_,
self.torch_bin_ops):
tensors = self._get_test_data(device, dtype, N)
scalars = [3 + 5j for _ in range(N)]
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]

if dtype == torch.bool:
if foreach_bin_op == torch._foreach_sub:
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
foreach_bin_op_(tensors, scalar)

with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with two bool"):
foreach_bin_op(tensors, scalar)
return

with self.assertRaisesRegex(TypeError, "argument 'scalars' must be tuple of floats"):
res = foreach_bin_op(tensors, scalars)

with self.assertRaisesRegex(TypeError, "argument 'scalars' must be tuple of floats"):
foreach_bin_op_(tensors, scalars)

@dtypes(*torch.testing.get_all_dtypes())
def test_bool_scalar(self, device, dtype):
for N in [30, 300]:
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
self.foreach_bin_ops_,
self.torch_bin_ops):
tensors = self._get_test_data(device, dtype, N)
scalar = True
tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
bool_scalar = True

if dtype == torch.bool:
expected = [torch_bin_op(t, scalar) for t in tensors]
res = foreach_bin_op(tensors, scalar)
expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)]

foreach_bin_op_(tensors, scalar)
self.assertEqual(tensors, res)
return

if foreach_bin_op == torch._foreach_sub and self.device_type == "cpu":
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator"):
res = foreach_bin_op(tensors, scalar)

with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator"):
foreach_bin_op_(tensors, scalar)
elif foreach_bin_op == torch._foreach_sub and self.device_type == 'cuda':
res = foreach_bin_op(tensors, scalar)
self.assertEqual(res, foreach_bin_op(tensors, 1))

foreach_bin_op_(tensors, scalar)
self.assertEqual(tensors, res)
else:
expected = [torch_bin_op(t, scalar) for t in tensors]
res = foreach_bin_op(tensors, scalar)

# TODO[type promotion]: Fix once type promotion is enabled.
if dtype in torch.testing.integral_types() and self.device_type == 'cuda':
self.assertEqual(res, [e.to(dtype) for e in expected])
else:
res = torch._foreach_add(tensors, bool_scalar)
self.assertEqual(res, expected)

if dtype in torch.testing.integral_types():
if foreach_bin_op == torch._foreach_div and self.device_type == "cpu":
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired "):
foreach_bin_op_(tensors, scalar)
else:
foreach_bin_op_(tensors, scalar)
self.assertEqual(tensors, res)
else:
foreach_bin_op_(tensors, scalar)
self.assertEqual(tensors, expected)

@dtypes(*torch.testing.get_all_dtypes())
def test_bool_scalarlist(self, device, dtype):
for N in [30, 300]:
for foreach_bin_op, foreach_bin_op_, torch_bin_op in zip(self.foreach_bin_ops,
self.foreach_bin_ops_,
self.torch_bin_ops):
tensors = self._get_test_data(device, dtype, N)
scalars = [True for _ in range(N)]

if dtype == torch.bool:
if self.device_type == 'cuda':
with self.assertRaisesRegex(RuntimeError, "not implemented for"):
foreach_bin_op(tensors, scalars)

with self.assertRaisesRegex(RuntimeError, "not implemented for"):
foreach_bin_op_(tensors, scalars)
return
else:
if foreach_bin_op == torch._foreach_sub:
with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
foreach_bin_op_(tensors, scalars)

with self.assertRaisesRegex(RuntimeError, "Subtraction, the `-` operator, with a bool tensor"):
foreach_bin_op(tensors, scalars)
else:
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired"):
foreach_bin_op_(tensors, scalars)

res = foreach_bin_op(tensors, scalars)
for r in res:
self.assertTrue(r.dtype == torch.float32)
else:
# we dont support bool and complex types on CUDA for now
if (dtype in torch.testing.get_all_complex_dtypes()) and self.device_type == 'cuda':
with self.assertRaisesRegex(RuntimeError, "not implemented for"):
foreach_bin_op_(tensors, scalars)

with self.assertRaisesRegex(RuntimeError, "not implemented for"):
foreach_bin_op(tensors, scalars)
return

if foreach_bin_op == torch._foreach_sub:
if self.device_type == "cpu":
# see TODO[Fix scalar list]
res = foreach_bin_op(tensors, scalars)
if dtype in torch.testing.integral_types():
self.assertEqual(res, [r.to(torch.float32) for r in foreach_bin_op(tensors, 1)])

with self.assertRaisesRegex(RuntimeError, "esult type Float can't be cast to the "):
foreach_bin_op_(tensors, scalars)
else:
self.assertEqual(res, foreach_bin_op(tensors, 1))
foreach_bin_op_(tensors, scalars)
self.assertEqual(res, tensors)
else:
# see TODO[Fix scalar list]
res = foreach_bin_op(tensors, scalars)
if dtype in torch.testing.integral_types():
self.assertEqual(res, [r.to(dtype) for r in foreach_bin_op(tensors, 1)])
else:
self.assertEqual(res, foreach_bin_op(tensors, 1))

foreach_bin_op_(tensors, scalars)
self.assertEqual(res, tensors)
else:
if self.device_type == "cpu":
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
res = foreach_bin_op(tensors, scalars)

# see TODO[Fix scalar list]
if dtype in torch.testing.integral_types():
self.assertEqual(res, [e.to(torch.float32) for e in expected])
else:
self.assertEqual(res, expected)

if dtype in torch.testing.integral_types():
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired "):
foreach_bin_op_(tensors, scalars)
else:
foreach_bin_op_(tensors, scalars)
self.assertEqual(tensors, expected)
else:
expected = [torch_bin_op(t, s) for t, s in zip(tensors, scalars)]
res = foreach_bin_op(tensors, scalars)

if dtype in torch.testing.integral_types():
self.assertEqual(res, [e.to(dtype) for e in expected])
else:
self.assertEqual(res, expected)

foreach_bin_op_(tensors, scalars)
torch._foreach_add_(tensors, bool_scalar)
self.assertEqual(res, tensors)

@dtypes(*torch.testing.get_all_dtypes())

@@ -549,9 +248,9 @@ class TestForeach(TestCase):

# One empty list
tensors1.append(torch.tensor([1], device=device))
with self.assertRaisesRegex(RuntimeError, "Scalars list must have at least one value."):
with self.assertRaisesRegex(RuntimeError, "Tensor list must have at least one tensor."):
torch._foreach_add(tensors1, tensors2)
with self.assertRaisesRegex(RuntimeError, "Scalars list must have at least one value."):
with self.assertRaisesRegex(RuntimeError, "Tensor list must have at least one tensor."):
torch._foreach_add_(tensors1, tensors2)

# Lists have different amount of tensors

@@ -619,25 +318,13 @@ class TestForeach(TestCase):
self.skipTest("Skipped! See https://github.com/pytorch/pytorch/issues/44489")
return

for N in [30, 300]:
tensors1 = self._get_test_data(device, dtype, N)

if dtype in [torch.bfloat16, torch.bool, torch.float16]:
tensors2 = [torch.zeros(N, N, device=device, dtype=dtype).add(2) for _ in range(N)]
else:
tensors2 = self._get_test_data(device, dtype, N)

expected = [torch.div(tensors1[i], tensors2[i]) for i in range(N)]
res = torch._foreach_div(tensors1, tensors2)
torch._foreach_div_(tensors1, tensors2)
self.assertEqual(res, tensors1)
self.assertEqual(tensors1, res)
self._test_bin_op_list(device, dtype, torch._foreach_div, torch._foreach_div_, torch.div)

def test_bin_op_list_error_cases(self, device):
tensors1 = []
tensors2 = []

for bin_op in self.foreach_bin_ops + self.foreach_bin_ops_:
for bin_op in self.bin_ops:
# Empty lists
with self.assertRaises(RuntimeError):
bin_op(tensors1, tensors2)
@@ -58,7 +58,7 @@ class TestNativeFunctions(TestCase):
self.do_test_optional_floatlist_with_module(fake_module)

def test_optional_floatlist_invalid(self):
with self.assertRaisesRegex(TypeError, "must be tuple of floats, not list"):
with self.assertRaisesRegex(TypeError, "must be .* but found"):
FloatListWrapperModule()(torch.zeros(1), ["hi"])

with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):

@@ -281,7 +281,6 @@ UNPACK_METHODS = {
'c10::optional<bool>': 'toBoolOptional',
'c10::optional<double>': 'toDoubleOptional',
'c10::optional<ArrayRef<double>>': 'doublelistOptional',
'ArrayRef<double>': 'doublelist',
'IntArrayRef': 'intlist',
'Scalar': 'scalar',
'ScalarType': 'scalartype',

@@ -44,7 +44,6 @@ using at::Generator;
using at::TensorList;
using at::Dimname;
using at::DimnameList;
using at::ArrayRef;

using namespace torch::autograd::utils;

@@ -304,10 +304,6 @@ class FunctionSchema:
# TODO: fixme
if str(self.name) not in [
'_amp_non_finite_check_and_unscale_',
'_foreach_add_.ScalarList',
'_foreach_sub_.ScalarList',
'_foreach_mul_.ScalarList',
'_foreach_div_.ScalarList',
'_foreach_add_.Scalar',
'_foreach_sub_.Scalar',
'_foreach_mul_.Scalar',

@@ -146,7 +146,6 @@ def type_to_python(typename, size=None):
'Dimname': 'Union[str, ellipsis, None]',
'DimnameList': 'Sequence[Union[str, ellipsis, None]]',
'QScheme': '_qscheme',
'ArrayRef<double>' : 'Sequence[float]'
}[typename]

return typename

@@ -366,23 +366,6 @@ bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>
return true;
}

bool is_float_list(PyObject* obj) {
auto tuple = six::isTuple(obj);
if (!(tuple || PyList_Check(obj))) {
return false;
}

auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
if (size > 0) {
PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) {
return false;
}
}

return true;
}

// argnum is needed for raising the TypeError, it's used in the error message.
auto FunctionParameter::check(PyObject* obj, std::vector<py::handle> &overloaded_args, int argnum) -> bool
{

@@ -437,9 +420,7 @@ auto FunctionParameter::check(PyObject* obj, std::vector<py::handle> &overloaded
// if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single int
return size > 0 && THPUtils_checkLong(obj);
}
case ParameterType::FLOAT_LIST: {
return is_float_list(obj);
}
case ParameterType::FLOAT_LIST: return (PyTuple_Check(obj) || PyList_Check(obj));
case ParameterType::GENERATOR: return THPGenerator_Check(obj);
case ParameterType::BOOL: return PyBool_Check(obj);
case ParameterType::STORAGE: return isStorage(obj);

@@ -920,7 +901,6 @@ PythonArgs PythonArgParser::raw_parse(PyObject* self, PyObject* args, PyObject*
print_error(self, args, kwargs, parsed_args);
}

void PythonArgParser::print_error(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { // NOLINT
auto num_args = PyTuple_GET_SIZE(args) + (kwargs ? PyDict_Size(kwargs) : 0);
std::vector<int> plausible_idxs;

@@ -173,8 +173,6 @@ struct PythonArgs {
inline c10::optional<bool> toBoolOptional(int i);
inline c10::optional<double> toDoubleOptional(int i);
inline c10::OptionalArray<double> doublelistOptional(int i);
inline std::vector<double> doublelist(int i);
inline std::vector<double> getDoublelist(int i);
inline at::Layout layout(int i);
inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
inline c10::optional<at::Layout> layoutOptional(int i);

@@ -371,7 +369,10 @@ inline c10::OptionalArray<int64_t> PythonArgs::intlistOptional(int i) {
return intlist(i);
}

inline std::vector<double> PythonArgs::getDoublelist(int i) {
inline c10::OptionalArray<double> PythonArgs::doublelistOptional(int i) {
if (!args[i]) {
return {};
}
PyObject* arg = args[i];
auto tuple = PyTuple_Check(arg);
auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);

@@ -389,17 +390,6 @@ inline std::vector<double> PythonArgs::getDoublelist(int i) {
return res;
}

inline c10::OptionalArray<double> PythonArgs::doublelistOptional(int i) {
if (!args[i]) {
return {};
}
return this->getDoublelist(i);
}

inline std::vector<double> PythonArgs::doublelist(int i) {
return this->getDoublelist(i);
}

inline at::ScalarType PythonArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
if (!args[i]) return default_scalartype;
return scalartype(i);
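
The python_arg_parser changes above are what let FLOAT_LIST arguments (the float[] scalars) be passed as either a Python list or a tuple of numbers, with doublelist/doublelistOptional doing the conversion. Roughly, at call time (illustrative, based on the tests earlier in this diff):

    torch._foreach_add_(tensors, [1.0, 2.0, 3.0])   # a list of floats parses as float[]
    torch._foreach_add_(tensors, (1.0, 2.0, 3.0))   # a tuple works as well
    # a list of complex values is rejected by the parser, e.g.
    # "argument 'scalars' must be tuple of floats" in test_complex_scalarlist above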