Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add ScalarList overload to _foreach_lerp (#134482)

Related: https://github.com/pytorch/pytorch/issues/133367
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134482
Approved by: https://github.com/janeyx99
Committed by: PyTorch MergeBot
Parent: 7624d625c0
Commit: 6a368b3fc5
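For orientation, a minimal sketch of what the new overload enables from Python. `_foreach_lerp` already accepted a single scalar weight or a list of weight tensors; this change adds a per-tensor list of scalar weights. The tensors and weight values below are illustrative, not taken from the PR.

import torch

# Illustrative inputs: four pairs of tensors and one scalar weight per pair.
starts = [torch.zeros(3) for _ in range(4)]
ends = [torch.ones(3) for _ in range(4)]
weights = [0.1, 0.2, 0.3, 0.4]

# Out-of-place: out[i] = starts[i] + weights[i] * (ends[i] - starts[i])
out = torch._foreach_lerp(starts, ends, weights)

# In-place variant updates starts[i] directly.
torch._foreach_lerp_(starts, ends, weights)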
@@ -411,26 +411,49 @@ FOREACH_POINTWISE_OP_SCALARLIST(addcmul)
FOREACH_POINTWISE_OP_TENSOR(addcdiv)
FOREACH_POINTWISE_OP_TENSOR(addcmul)

#define FOREACH_TERNARY_OP(OP) \
  std::vector<Tensor> foreach_tensor_ternary_##OP##_slow( \
      TensorList tensors1, TensorList tensors2, TensorList tensors3) { \
    check_foreach_api_restrictions(tensors1, tensors2, tensors3); \
    std::vector<Tensor> result; \
    for (const auto i : c10::irange(tensors1.size())) { \
      result.emplace_back(tensors1[i].OP(tensors2[i], tensors3[i])); \
    } \
    return result; \
  } \
  \
  void foreach_tensor_ternary_##OP##_slow_( \
      TensorList tensors1, TensorList tensors2, TensorList tensors3) { \
    check_foreach_api_restrictions(tensors1, tensors2, tensors3); \
    for (const auto i : c10::irange(tensors1.size())) { \
      tensors1[i].OP##_(tensors2[i], tensors3[i]); \
    } \
  }

std::vector<Tensor> foreach_tensor_ternary_lerp_slow(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
  std::vector<Tensor> result;
  for (const auto i : c10::irange(tensors1.size())) {
    result.emplace_back(tensors1[i].lerp(tensors2[i], tensors3[i]));
  }
  return result;
}

FOREACH_TERNARY_OP(lerp)

void foreach_tensor_ternary_lerp_slow_(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
  for (const auto i : c10::irange(tensors1.size())) {
    tensors1[i].lerp_(tensors2[i], tensors3[i]);
  }
}

std::vector<Tensor> foreach_tensor_lerp_scalarlist_kernel_slow(
    TensorList tensors1,
    TensorList tensors2,
    at::ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, scalars);
  std::vector<Tensor> result;
  for (const auto i : c10::irange(tensors1.size())) {
    result.emplace_back(tensors1[i].lerp(tensors2[i], scalars[i]));
  }
  return result;
}

void foreach_tensor_lerp_scalarlist_kernel_slow_(
    TensorList tensors1,
    TensorList tensors2,
    at::ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, scalars);
  for (const auto i : c10::irange(tensors1.size())) {
    tensors1[i].lerp_(tensors2[i], scalars[i]);
  }
}

void foreach_tensor_zero_slow_(TensorList tensors) {
  check_foreach_api_restrictions(tensors);
@@ -98,6 +98,19 @@ inline void check_foreach_api_restrictions(
      scalars.size());
}

inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2);
  TORCH_CHECK(
      tensors1.size() == scalars.size(),
      "Tensor list must have same number of elements as scalar list, got ",
      tensors1.size(),
      " and ",
      scalars.size());
}

// Helper function called in check_fast_path_restrictions to check whether all
// corresponding tensors (aligning in index across the tensorLists) share the
// same device and dtype.
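A minimal sketch of what this length check means for callers, assuming the Python binding surfaces the TORCH_CHECK failure as a RuntimeError; the list sizes below are illustrative.

import torch

starts = [torch.zeros(2), torch.zeros(2)]
ends = [torch.ones(2), torch.ones(2)]

# Two tensor pairs but three weights: the check above should reject this.
try:
    torch._foreach_lerp(starts, ends, [0.5, 0.5, 0.5])
except RuntimeError as err:
    print(err)  # e.g. "Tensor list must have same number of elements as scalar list, got 2 and 3"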
@@ -663,6 +663,63 @@ struct TernaryOpScalarFunctor {
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    static_assert(depth == 2 || depth == 3, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];
    const opmath_t scalar = tl.scalar_vals[tensor_loc];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 scalar);
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 scalar);
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T>
struct power_functor {
  C10_DEVICE T operator()(const T& a, const T& b) const {
@@ -156,4 +156,75 @@ void foreach_tensor_lerp_list_cuda_(
            weight.to<opmath_t>());
      });
}

std::vector<at::Tensor> foreach_tensor_lerp_scalarlist_cuda(
    TensorList tensors1,
    TensorList tensors2,
    at::ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, scalars);
  if (!can_use_fast_route({tensors1, tensors2}, scalars, true)) {
    return foreach_tensor_lerp_scalarlist_kernel_slow(
        tensors1, tensors2, scalars);
  }

  std::vector<at::Tensor> vec_res;
  vec_res.reserve(tensors1.size());
  for (const auto& t : tensors1) {
    vec_res.emplace_back(at::native::empty_like(t));
  }
  std::vector<std::vector<at::Tensor>> tensor_lists{
      tensors1.vec(), tensors2.vec(), vec_res};

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      tensors1[0].scalar_type(),
      "foreach_tensor_lerp_scalarlist_cuda",
      [&]() {
        using opmath_t = typename at::opmath_type<scalar_t>;
        multi_tensor_apply<3, opmath_t>(
            tensor_lists,
            scalars,
            TernaryOpScalarListFunctor<
                scalar_t,
                /* depth */ 3,
                /* r_args_depth */ 2,
                /* res_arg_index */ 2>(),
            LerpFunctor<opmath_t>());
      });

  return tensor_lists[2];
}

void foreach_tensor_lerp_scalarlist_cuda_(
    TensorList tensors1,
    TensorList tensors2,
    at::ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, scalars);
  if (!can_use_fast_route({tensors1, tensors2}, scalars, true)) {
    return foreach_tensor_lerp_scalarlist_kernel_slow_(
        tensors1, tensors2, scalars);
  }

  std::vector<std::vector<at::Tensor>> tensor_lists{
      tensors1.vec(), tensors2.vec()};

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      tensors1[0].scalar_type(),
      "foreach_tensor_lerp_scalarlist_cuda_",
      [&]() {
        using opmath_t = typename at::opmath_type<scalar_t>;
        multi_tensor_apply<2, opmath_t>(
            tensor_lists,
            scalars,
            TernaryOpScalarListFunctor<
                scalar_t,
                /* depth */ 2,
                /* r_args_depth */ 2,
                /* res_arg_index */ 0>(),
            LerpFunctor<opmath_t>());
      });
}
} // namespace at::native
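The CUDA entry points above fall back to the slow per-tensor loop whenever can_use_fast_route rejects the inputs; either way the result should agree with an ordinary per-tensor lerp. A hedged sketch of that equivalence (shapes and weight values are illustrative):

import torch

if torch.cuda.is_available():
    starts = [torch.rand(5, device="cuda") for _ in range(3)]
    ends = [torch.rand(5, device="cuda") for _ in range(3)]
    weights = [0.1, 0.5, 0.9]

    # Fused foreach kernel (or the slow per-tensor path if the fast route is rejected).
    fused = torch._foreach_lerp(starts, ends, weights)
    # Per-tensor reference computation.
    reference = [s.lerp(e, w) for s, e, w in zip(starts, ends, weights)]

    for a, b in zip(fused, reference):
        torch.testing.assert_close(a, b)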
@@ -11105,6 +11105,22 @@
    CUDA: foreach_tensor_lerp_list_cuda_
  autogen: _foreach_lerp.Scalar_out

- func: _foreach_lerp.ScalarList(Tensor[] self, Tensor[] tensors1, Scalar[] weight) -> Tensor[]
  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
  variants: function
  dispatch:
    CompositeExplicitAutograd: foreach_tensor_lerp_scalarlist_kernel_slow
    CUDA: foreach_tensor_lerp_scalarlist_cuda
  autogen: _foreach_lerp.ScalarList_out

- func: _foreach_lerp_.ScalarList(Tensor(a!)[] self, Tensor[] tensors1, Scalar[] weight) -> ()
  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
  variants: function
  dispatch:
    CompositeExplicitAutograd: foreach_tensor_lerp_scalarlist_kernel_slow_
    CUDA: foreach_tensor_lerp_scalarlist_cuda_
  autogen: _foreach_lerp.ScalarList_out

- func: _foreach_lgamma(Tensor[] self) -> Tensor[]
  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
  variants: function
@@ -232,9 +232,12 @@ aten::_foreach_frac_
aten::_foreach_lerp.List
aten::_foreach_lerp.List_out
aten::_foreach_lerp.Scalar
aten::_foreach_lerp.ScalarList
aten::_foreach_lerp.ScalarList_out
aten::_foreach_lerp.Scalar_out
aten::_foreach_lerp_.List
aten::_foreach_lerp_.Scalar
aten::_foreach_lerp_.ScalarList
aten::_foreach_lgamma
aten::_foreach_lgamma.out
aten::_foreach_lgamma_
@@ -1534,6 +1534,8 @@ def check_autodiff_sample(op, sample, dtype, is_inplace):
        or (isinstance(sample.args[-1], complex))
    )
    if rhs_arg_has_complex_number and dtype == torch.float64:
        if op.name == "_foreach_lerp":
            return False, "value cannot be converted to type double without overflow"
        if op.name in (
            "_foreach_clamp_max",
            "_foreach_clamp_min",
@@ -749,6 +749,20 @@ def _foreach_lerp_scalar(
    )


@register_decomposition(aten._foreach_lerp.ScalarList)
def _foreach_lerp_scalarlist(
    start_tensors: List[torch.Tensor],
    end_tensors: List[torch.Tensor],
    scalars: List[torch.types.Number],
) -> List[torch.Tensor]:
    return aten._foreach_add.List(
        start_tensors,
        aten._foreach_mul.ScalarList(
            aten._foreach_sub.List(end_tensors, start_tensors), scalars
        ),
    )


@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
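The decomposition registered above relies on the identity lerp(start, end, w) = start + w * (end - start). A small self-check of that identity against the new overload, using the same public foreach ops the decomposition is built from (the tensor values and weights are illustrative):

import torch

starts = [torch.tensor([0.0, 1.0]), torch.tensor([2.0, 3.0])]
ends = [torch.tensor([1.0, 1.0]), torch.tensor([0.0, 0.0])]
weights = [0.25, 0.5]

lerped = torch._foreach_lerp(starts, ends, weights)

# Decomposed form: start + weight * (end - start), built from foreach sub/mul/add.
decomposed = torch._foreach_add(
    starts,
    torch._foreach_mul(torch._foreach_sub(ends, starts), weights),
)

for a, b in zip(lerped, decomposed):
    torch.testing.assert_close(a, b)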
@@ -541,9 +541,7 @@ def _multi_tensor_adafactor(
            ]
            torch._foreach_mul_(row_means, row_means)
            torch._foreach_div_(row_means, [grad.size(-1) for grad in device_grads])
-           torch._foreach_mul_(device_row_vars, beta2_ts)
-           torch._foreach_mul_(row_means, one_minus_beta2_ts)
-           torch._foreach_add_(device_row_vars, row_means)
+           torch._foreach_lerp_(device_row_vars, row_means, one_minus_beta2_ts)
            del row_means

            # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g
@@ -552,9 +550,7 @@ def _multi_tensor_adafactor(
            ]
            torch._foreach_mul_(col_means, col_means)
            torch._foreach_div_(col_means, [grad.size(-2) for grad in device_grads])
-           torch._foreach_mul_(device_col_vars, beta2_ts)
-           torch._foreach_mul_(col_means, one_minus_beta2_ts)
-           torch._foreach_add_(device_col_vars, col_means)
+           torch._foreach_lerp_(device_col_vars, col_means, one_minus_beta2_ts)
            del col_means

            var_estimates = [
@@ -574,9 +570,7 @@ def _multi_tensor_adafactor(
            ), "variance should be defined when grad is a vector"

            grads_squared = torch._foreach_mul(device_grads, device_grads)
-           torch._foreach_mul_(device_variances, beta2_ts)
-           torch._foreach_mul_(grads_squared, one_minus_beta2_ts)
-           torch._foreach_add_(device_variances, grads_squared)
+           torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts)
            del grads_squared

            # avoid writing into variance during update
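The three Adafactor call sites above lean on the identity x.lerp_(y, w) = x + w * (y - x) = (1 - w) * x + w * y, so a single in-place lerp with the one_minus_beta2_ts scalar list replaces the previous mul/mul/add sequence. A small sketch of the equivalence; the variable names and values here are illustrative, not taken from the optimizer:

import torch

beta2 = 0.9
one_minus_beta2 = 1.0 - beta2

row_vars = [torch.tensor([4.0, 2.0])]
row_means = [torch.tensor([1.0, 3.0])]

# Old sequence: var = beta2 * var + (1 - beta2) * mean
expected = beta2 * row_vars[0] + one_minus_beta2 * row_means[0]

# New single call: var.lerp_(mean, 1 - beta2) computes var + (1 - beta2) * (mean - var),
# which is the same quantity.
torch._foreach_lerp_(row_vars, row_means, [one_minus_beta2])

torch.testing.assert_close(row_vars[0], expected)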
@@ -11287,7 +11287,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
foreach_other_op_db: List[ForeachFuncInfo] = [
    ForeachFuncInfo(
        "lerp",
-       sample_inputs_func=foreach_inputs_sample_func(3, True, False),
+       sample_inputs_func=foreach_inputs_sample_func(3, True, True),
        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
        supports_autograd=True,
        supports_inplace_autograd=True,
@@ -11317,8 +11317,30 @@ foreach_other_op_db: List[ForeachFuncInfo] = [
                "test_dispatch_symbolic_meta_inplace",
                dtypes=integral_types_and(torch.bool),
            ),
-           DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=integral_types_and(torch.bool)),
-           DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=integral_types_and(torch.bool)),
+           DecorateInfo(
+               unittest.expectedFailure,
+               "TestMeta",
+               "test_meta_inplace",
+               dtypes=integral_types_and(torch.bool),
+           ),
+           DecorateInfo(
+               unittest.expectedFailure,
+               "TestMeta",
+               "test_meta_outplace",
+               dtypes=integral_types_and(torch.bool),
+           ),
+           DecorateInfo(
+               unittest.expectedFailure,
+               "TestMeta",
+               "test_dispatch_symbolic_meta_inplace_all_strides",
+               dtypes=integral_types_and(torch.bool),
+           ),
+           DecorateInfo(
+               unittest.expectedFailure,
+               "TestMeta",
+               "test_dispatch_symbolic_meta_outplace_all_strides",
+               dtypes=integral_types_and(torch.bool),
+           ),
        ),
    ),
]