Add ScalarList overload to _foreach_lerp (#134482)

Related:
- https://github.com/pytorch/pytorch/issues/133367

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134482
Approved by: https://github.com/janeyx99
Author: Masaki Kozuki
Date: 2024-11-12 19:03:38 +00:00
Committed by: PyTorch MergeBot
Parent: 7624d625c0
Commit: 6a368b3fc5
10 changed files with 245 additions and 30 deletions
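
As a quick illustration of what the new overload enables, here is a minimal usage sketch (assuming a PyTorch build that contains this commit): the lerp weight is now a per-tensor scalar list rather than a single scalar or a tensor list.

import torch

starts = [torch.randn(3), torch.randn(4)]
ends = [torch.randn(3), torch.randn(4)]
weights = [0.25, 0.75]  # one scalar weight per tensor in the lists

# Out-of-place overload added by this commit: returns a new list of tensors.
out = torch._foreach_lerp(starts, ends, weights)
for o, s, e, w in zip(out, starts, ends, weights):
    torch.testing.assert_close(o, torch.lerp(s, e, w))

# In-place overload: writes the same results into `starts`.
torch._foreach_lerp_(starts, ends, weights)
for s, o in zip(starts, out):
    torch.testing.assert_close(s, o)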


@@ -411,26 +411,49 @@ FOREACH_POINTWISE_OP_SCALARLIST(addcmul)
FOREACH_POINTWISE_OP_TENSOR(addcdiv)
FOREACH_POINTWISE_OP_TENSOR(addcmul)
#define FOREACH_TERNARY_OP(OP) \
std::vector<Tensor> foreach_tensor_ternary_##OP##_slow( \
TensorList tensors1, TensorList tensors2, TensorList tensors3) { \
check_foreach_api_restrictions(tensors1, tensors2, tensors3); \
std::vector<Tensor> result; \
for (const auto i : c10::irange(tensors1.size())) { \
result.emplace_back(tensors1[i].OP(tensors2[i], tensors3[i])); \
} \
return result; \
} \
\
void foreach_tensor_ternary_##OP##_slow_( \
TensorList tensors1, TensorList tensors2, TensorList tensors3) { \
check_foreach_api_restrictions(tensors1, tensors2, tensors3); \
for (const auto i : c10::irange(tensors1.size())) { \
tensors1[i].OP##_(tensors2[i], tensors3[i]); \
} \
std::vector<Tensor> foreach_tensor_ternary_lerp_slow(
TensorList tensors1,
TensorList tensors2,
TensorList tensors3) {
check_foreach_api_restrictions(tensors1, tensors2, tensors3);
std::vector<Tensor> result;
for (const auto i : c10::irange(tensors1.size())) {
result.emplace_back(tensors1[i].lerp(tensors2[i], tensors3[i]));
}
return result;
}
FOREACH_TERNARY_OP(lerp)
void foreach_tensor_ternary_lerp_slow_(
TensorList tensors1,
TensorList tensors2,
TensorList tensors3) {
check_foreach_api_restrictions(tensors1, tensors2, tensors3);
for (const auto i : c10::irange(tensors1.size())) {
tensors1[i].lerp_(tensors2[i], tensors3[i]);
}
}
std::vector<Tensor> foreach_tensor_lerp_scalarlist_kernel_slow(
TensorList tensors1,
TensorList tensors2,
at::ArrayRef<Scalar> scalars) {
check_foreach_api_restrictions(tensors1, tensors2, scalars);
std::vector<Tensor> result;
for (const auto i : c10::irange(tensors1.size())) {
result.emplace_back(tensors1[i].lerp(tensors2[i], scalars[i]));
}
return result;
}
void foreach_tensor_lerp_scalarlist_kernel_slow_(
TensorList tensors1,
TensorList tensors2,
at::ArrayRef<Scalar> scalars) {
check_foreach_api_restrictions(tensors1, tensors2, scalars);
for (const auto i : c10::irange(tensors1.size())) {
tensors1[i].lerp_(tensors2[i], scalars[i]);
}
}
void foreach_tensor_zero_slow_(TensorList tensors) {
check_foreach_api_restrictions(tensors);
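
For reference, the ScalarList slow path above amounts to a per-tensor loop; a rough Python sketch of the same semantics follows (the helper name is illustrative, not part of ATen):

import torch

def foreach_lerp_scalarlist_reference(tensors1, tensors2, scalars):
    # One torch.lerp per tensor, with the i-th scalar as that tensor's weight,
    # mirroring foreach_tensor_lerp_scalarlist_kernel_slow above.
    return [t1.lerp(t2, s) for t1, t2, s in zip(tensors1, tensors2, scalars)]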


@@ -98,6 +98,19 @@ inline void check_foreach_api_restrictions(
scalars.size());
}
inline void check_foreach_api_restrictions(
TensorList tensors1,
TensorList tensors2,
ArrayRef<Scalar> scalars) {
check_foreach_api_restrictions(tensors1, tensors2);
TORCH_CHECK(
tensors1.size() == scalars.size(),
"Tensor list must have same number of elements as scalar list, got ",
tensors1.size(),
" and ",
scalars.size());
}
// Helper function called in check_fast_path_restrictions to check whether all
// corresponding tensors (aligning in index across the tensorLists) share the
// same device and dtype.
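
The new check_foreach_api_restrictions overload only adds a length check between the tensor lists and the scalar list. From Python, a mismatch would be expected to surface roughly like this (a sketch; the exact error text may differ):

import torch

a = [torch.randn(2), torch.randn(2)]
b = [torch.randn(2), torch.randn(2)]
try:
    torch._foreach_lerp(a, b, [0.5])  # two tensors, only one scalar
except RuntimeError as err:
    # Expected: tensor lists and scalar list must have the same number of elements.
    print(err)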


@@ -663,6 +663,63 @@ struct TernaryOpScalarFunctor {
}
};
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarListFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListScalarListMetadata<opmath_t, depth>& tl,
Op op) {
static_assert(depth == 2 || depth == 3, "");
static_assert(depth >= r_args_depth, "");
static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
auto n = tl.numel_for_tensor[tensor_loc];
T* args[depth];
const bool all_aligned =
init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
T r_args[r_args_depth][kILP];
const opmath_t scalar = tl.scalar_vals[tensor_loc];
// to make things simple, we put aligned case in a different code path
if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
i_start * kILP < n && i_start * kILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_args[0], args[0], 0, i_start);
load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
r_args[0][ii] =
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii]),
scalar);
}
// store
load_store(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * kILP) {
load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
r_args[0][ii] =
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii]),
scalar);
}
store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
}
}
}
};
template <typename T>
struct power_functor {
C10_DEVICE T operator()(const T& a, const T& b) const {
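
Conceptually, TernaryOpScalarListFunctor above loads a chunk of each operand, upcasts the elements to opmath_t, applies op(a, b, scalar) with that tensor's scalar from tl.scalar_vals, and writes the result back in the storage dtype. A host-side sketch of that numeric behavior for low-precision inputs (illustrative only, not the kernel):

import torch

def lerp_in_opmath(a, b, weight):
    # Half/bfloat16 inputs are computed in float32 ("opmath") and cast back,
    # mirroring the per-element math in the functor above.
    opmath = torch.float32 if a.dtype in (torch.float16, torch.bfloat16) else a.dtype
    return torch.lerp(a.to(opmath), b.to(opmath), weight).to(a.dtype)

a = torch.randn(8, dtype=torch.bfloat16)
b = torch.randn(8, dtype=torch.bfloat16)
print(lerp_in_opmath(a, b, 0.3))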


@@ -156,4 +156,75 @@ void foreach_tensor_lerp_list_cuda_(
weight.to<opmath_t>());
});
}
std::vector<at::Tensor> foreach_tensor_lerp_scalarlist_cuda(
TensorList tensors1,
TensorList tensors2,
at::ArrayRef<Scalar> scalars) {
check_foreach_api_restrictions(tensors1, tensors2, scalars);
if (!can_use_fast_route({tensors1, tensors2}, scalars, true)) {
return foreach_tensor_lerp_scalarlist_kernel_slow(
tensors1, tensors2, scalars);
}
std::vector<at::Tensor> vec_res;
vec_res.reserve(tensors1.size());
for (const auto& t : tensors1) {
vec_res.emplace_back(at::native::empty_like(t));
}
std::vector<std::vector<at::Tensor>> tensor_lists{
tensors1.vec(), tensors2.vec(), vec_res};
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
tensors1[0].scalar_type(),
"foreach_tensor_lerp_scalarlist_cuda",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
multi_tensor_apply<3, opmath_t>(
tensor_lists,
scalars,
TernaryOpScalarListFunctor<
scalar_t,
/* depth */ 3,
/* r_args_depth */ 2,
/* res_arg_index */ 2>(),
LerpFunctor<opmath_t>());
});
return tensor_lists[2];
}
void foreach_tensor_lerp_scalarlist_cuda_(
TensorList tensors1,
TensorList tensors2,
at::ArrayRef<Scalar> scalars) {
check_foreach_api_restrictions(tensors1, tensors2, scalars);
if (!can_use_fast_route({tensors1, tensors2}, scalars, true)) {
return foreach_tensor_lerp_scalarlist_kernel_slow_(
tensors1, tensors2, scalars);
}
std::vector<std::vector<at::Tensor>> tensor_lists{
tensors1.vec(), tensors2.vec()};
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
tensors1[0].scalar_type(),
"foreach_tensor_lerp_scalarlist_cuda_",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
multi_tensor_apply<2, opmath_t>(
tensor_lists,
scalars,
TernaryOpScalarListFunctor<
scalar_t,
/* depth */ 2,
/* r_args_depth */ 2,
/* res_arg_index */ 0>(),
LerpFunctor<opmath_t>());
});
}
} // namespace at::native
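
The CUDA entry points above take the fused multi_tensor_apply path only when can_use_fast_route accepts the lists; otherwise they dispatch to the slow per-tensor kernels, so heterogeneous inputs still work. A usage sketch (assuming a CUDA build containing this commit; which path is taken is an internal detail):

import torch

if torch.cuda.is_available():
    # Homogeneous float32 lists: eligible for the fused CUDA kernel.
    a = [torch.randn(4, device="cuda") for _ in range(3)]
    b = [torch.randn(4, device="cuda") for _ in range(3)]
    out = torch._foreach_lerp(a, b, [0.1, 0.5, 0.9])

    # Mixed dtypes: expected to fall back to the slow per-tensor path,
    # producing the same per-tensor lerp results.
    c = [torch.randn(4, device="cuda"), torch.randn(4, device="cuda", dtype=torch.float64)]
    d = [torch.randn(4, device="cuda"), torch.randn(4, device="cuda", dtype=torch.float64)]
    out_mixed = torch._foreach_lerp(c, d, [0.2, 0.8])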


@@ -11105,6 +11105,22 @@
CUDA: foreach_tensor_lerp_list_cuda_
autogen: _foreach_lerp.Scalar_out
- func: _foreach_lerp.ScalarList(Tensor[] self, Tensor[] tensors1, Scalar[] weight) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
variants: function
dispatch:
CompositeExplicitAutograd: foreach_tensor_lerp_scalarlist_kernel_slow
CUDA: foreach_tensor_lerp_scalarlist_cuda
autogen: _foreach_lerp.ScalarList_out
- func: _foreach_lerp_.ScalarList(Tensor(a!)[] self, Tensor[] tensors1, Scalar[] weight) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
variants: function
dispatch:
CompositeExplicitAutograd: foreach_tensor_lerp_scalarlist_kernel_slow_
CUDA: foreach_tensor_lerp_scalarlist_cuda_
autogen: _foreach_lerp.ScalarList_out
- func: _foreach_lgamma(Tensor[] self) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function


@@ -232,9 +232,12 @@ aten::_foreach_frac_
aten::_foreach_lerp.List
aten::_foreach_lerp.List_out
aten::_foreach_lerp.Scalar
aten::_foreach_lerp.ScalarList
aten::_foreach_lerp.ScalarList_out
aten::_foreach_lerp.Scalar_out
aten::_foreach_lerp_.List
aten::_foreach_lerp_.Scalar
aten::_foreach_lerp_.ScalarList
aten::_foreach_lgamma
aten::_foreach_lgamma.out
aten::_foreach_lgamma_


@@ -1534,6 +1534,8 @@ def check_autodiff_sample(op, sample, dtype, is_inplace):
or (isinstance(sample.args[-1], complex))
)
if rhs_arg_has_complex_number and dtype == torch.float64:
if op.name == "_foreach_lerp":
return False, "value cannot be converted to type double without overflow"
if op.name in (
"_foreach_clamp_max",
"_foreach_clamp_min",


@@ -749,6 +749,20 @@ def _foreach_lerp_scalar(
)
@register_decomposition(aten._foreach_lerp.ScalarList)
def _foreach_lerp_scalarlist(
start_tensors: List[torch.Tensor],
end_tensors: List[torch.Tensor],
scalars: List[torch.types.Number],
) -> List[torch.Tensor]:
return aten._foreach_add.List(
start_tensors,
aten._foreach_mul.ScalarList(
aten._foreach_sub.List(end_tensors, start_tensors), scalars
),
)
@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
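
The decomposition registered above rewrites _foreach_lerp.ScalarList as start + weight * (end - start) using existing foreach primitives. A quick check of that identity (a sketch, not part of the diff; results agree up to floating-point rounding):

import torch

starts = [torch.randn(5), torch.randn(7)]
ends = [torch.randn(5), torch.randn(7)]
weights = [0.25, 0.75]

decomposed = torch._foreach_add(
    starts,
    torch._foreach_mul(torch._foreach_sub(ends, starts), weights),
)
direct = torch._foreach_lerp(starts, ends, weights)
for d, r in zip(decomposed, direct):
    torch.testing.assert_close(d, r)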


@@ -541,9 +541,7 @@ def _multi_tensor_adafactor(
]
torch._foreach_mul_(row_means, row_means)
torch._foreach_div_(row_means, [grad.size(-1) for grad in device_grads])
torch._foreach_mul_(device_row_vars, beta2_ts)
torch._foreach_mul_(row_means, one_minus_beta2_ts)
torch._foreach_add_(device_row_vars, row_means)
torch._foreach_lerp_(device_row_vars, row_means, one_minus_beta2_ts)
del row_means
# same as (g * g).mean(dim=-2) w/o materializing an intermediate size g
@@ -552,9 +550,7 @@ def _multi_tensor_adafactor(
]
torch._foreach_mul_(col_means, col_means)
torch._foreach_div_(col_means, [grad.size(-2) for grad in device_grads])
torch._foreach_mul_(device_col_vars, beta2_ts)
torch._foreach_mul_(col_means, one_minus_beta2_ts)
torch._foreach_add_(device_col_vars, col_means)
torch._foreach_lerp_(device_col_vars, col_means, one_minus_beta2_ts)
del col_means
var_estimates = [
@@ -574,9 +570,7 @@ def _multi_tensor_adafactor(
), "variance should be defined when grad is a vector"
grads_squared = torch._foreach_mul(device_grads, device_grads)
torch._foreach_mul_(device_variances, beta2_ts)
torch._foreach_mul_(grads_squared, one_minus_beta2_ts)
torch._foreach_add_(device_variances, grads_squared)
torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts)
del grads_squared
# avoid writing into variance during update
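
The Adafactor changes above fold each mul_/mul_/add_ triple into a single in-place lerp: lerp(v, m, 1 - beta2) = v + (1 - beta2) * (m - v) = beta2 * v + (1 - beta2) * m, so the exponential moving average is unchanged. A small check of that identity (a sketch with made-up values):

import torch

beta2 = 0.98
v = torch.randn(4)  # stands in for device_row_vars / device_col_vars / device_variances
m = torch.randn(4)  # stands in for row_means / col_means / grads_squared

old = beta2 * v + (1 - beta2) * m   # previous three-op formulation

new = v.clone()
new.lerp_(m, 1 - beta2)             # single fused update

torch.testing.assert_close(old, new)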


@@ -11287,7 +11287,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
foreach_other_op_db: List[ForeachFuncInfo] = [
ForeachFuncInfo(
"lerp",
sample_inputs_func=foreach_inputs_sample_func(3, True, False),
sample_inputs_func=foreach_inputs_sample_func(3, True, True),
dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
supports_autograd=True,
supports_inplace_autograd=True,
@@ -11317,8 +11317,30 @@ foreach_other_op_db: List[ForeachFuncInfo] = [
"test_dispatch_symbolic_meta_inplace",
dtypes=integral_types_and(torch.bool),
),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=integral_types_and(torch.bool)),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=integral_types_and(torch.bool)),
DecorateInfo(
unittest.expectedFailure,
"TestMeta",
"test_meta_inplace",
dtypes=integral_types_and(torch.bool),
),
DecorateInfo(
unittest.expectedFailure,
"TestMeta",
"test_meta_outplace",
dtypes=integral_types_and(torch.bool),
),
DecorateInfo(
unittest.expectedFailure,
"TestMeta",
"test_dispatch_symbolic_meta_inplace_all_strides",
dtypes=integral_types_and(torch.bool),
),
DecorateInfo(
unittest.expectedFailure,
"TestMeta",
"test_dispatch_symbolic_meta_outplace_all_strides",
dtypes=integral_types_and(torch.bool),
),
),
),
]