Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[4/N] Avoid copy in std::get (#142285)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142285
Approved by: https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Committed by: PyTorch MergeBot
Parent: 2cc01cc6d3
Commit: a108b282ff
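
All of the hunks below apply one technique: where a tuple element or a local is consumed for the last time, bind it with a reference or a structured binding, or hand it over with std::move, instead of copying it. For refcounted handles such as at::Tensor or c10::IValue, each avoided copy is an avoided atomic refcount round-trip. A minimal, self-contained sketch (not PyTorch code) of the std::get half of the pattern, using a copy-counting stand-in type:

    #include <cstdio>
    #include <tuple>

    // Copy-counting stand-in for a refcounted handle such as at::Tensor.
    struct Tracked {
      static inline int copies = 0;
      Tracked() = default;
      Tracked(const Tracked&) { ++copies; }
      Tracked(Tracked&&) noexcept = default;
    };

    std::tuple<Tracked, Tracked> produce() { return {}; }

    int main() {
      // Old pattern: name the tuple, then pull elements out with std::get.
      auto t = produce();
      auto a = std::get<0>(t);   // copy: `a` is a fresh object
      auto b = std::get<1>(t);   // copy
      (void)a; (void)b;

      // New pattern: structured bindings name the elements in place; no copies.
      auto [x, y] = produce();
      (void)x; (void)y;

      std::printf("copies: %d\n", Tracked::copies);  // prints 2
      return 0;
    }

For trivially copyable elements (for example the uint64_t seed/offset pairs in the CUDA hunks further down) the structured-binding form is mainly a readability win; the measurable savings come from the Tensor-typed tuples.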
@@ -353,7 +353,7 @@ void expectOutOfPlaceMultiUnboxedCallingWorks(const KernelFunction& func) {
 auto t1 = at::zeros({1});
 auto t2 = at::zeros({1});

-auto [t1_out, t2_out] = func.call<
+const auto [t1_out, t2_out] = func.call<
 std::tuple<at::Tensor&, at::Tensor&>, at::Scalar, at::Scalar, at::Tensor&, at::Tensor&
 >(dummy, CPU_TEST_SET, s1, s2, t1, t2);

@@ -181,8 +181,8 @@ convolution_backward_input_batch_rule(
 const auto result = at::convolution_backward_symint(
 grad_output_, dummy_input, weight_, std::nullopt, stride, padding,
 dilation, transposed, output_padding, groups * batch_size, mask);
-const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
-return std::make_tuple(grad_input, 1);
+auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
+return std::make_tuple(std::move(grad_input), 1);
 } else if (grad_output_bdim && !weight_bdim) {
 // BNO, OI -> (BN)O, OI -> (BN)I
 // transposed is the same.
@@ -192,8 +192,8 @@ convolution_backward_input_batch_rule(
 const auto result = at::convolution_backward_symint(
 grad_output_, dummy_input, weight, std::nullopt, stride, padding,
 dilation, transposed, output_padding, groups, mask);
-const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
-return std::make_tuple(grad_input, 0);
+auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result));
+return std::make_tuple(std::move(grad_input), 0);
 } else if (!grad_output_bdim && weight_bdim) {
 const auto batch_size = weight.size(*weight_bdim);
 if (groups == 1) {
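
Both convolution_backward_input_batch_rule hunks change the same return pattern: grad_input was a const local, so std::make_tuple had to copy the Tensor handle (an atomic refcount bump); making the local non-const and wrapping its final use in std::move lets make_tuple take ownership instead. NRVO does not help here because the returned object is the make_tuple expression, not the local itself. A hedged sketch with std::string standing in for at::Tensor and a made-up reshape() helper:

    #include <string>
    #include <tuple>
    #include <utility>

    // Made-up stand-in for reshape_dim_outof(): returns an expensive-to-copy value.
    std::string reshape(const std::string& s) { return s + s; }

    // Old shape: `const` forces make_tuple to copy the local.
    std::tuple<std::string, int> old_style(const std::string& in) {
      const auto out = reshape(in);
      return std::make_tuple(out, 1);            // copies `out`
    }

    // New shape: non-const local, moved at its last use.
    std::tuple<std::string, int> new_style(const std::string& in) {
      auto out = reshape(in);
      return std::make_tuple(std::move(out), 1); // moves `out`
    }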
@@ -359,7 +359,6 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
 const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
 const c10::OptionalArrayRef<SymInt> bias_sizes_opt,
 c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
-// NOLINTNEXTLINE(performance-unnecessary-value-param)
 c10::SymIntArrayRef output_padding, c10::SymInt groups, std::array<bool, 3> output_mask) {
 const auto maybe_layer = maybeCurrentDynamicLayer();
 vmap_check_escaped(maybe_layer, "convolution_backward_plumbing");
@@ -369,14 +368,14 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
 c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
 return at::convolution_backward_symint(
 grad_output_, input_, weight_, bias_sizes_opt, stride, padding,
-dilation, transposed, output_padding, groups, output_mask);
+dilation, transposed, output_padding, std::move(groups), output_mask);
 }

 auto [grad_output, grad_output_bdim] = unwrapTensorAtLevel(grad_output_, cur_level);
 auto [input, input_bdim] = unwrapTensorAtLevel(input_, cur_level);
 auto [weight, weight_bdim] = unwrapTensorAtLevel(weight_, cur_level);

-const auto grad_bias = compute_grad_bias(grad_output_, output_mask);
+auto grad_bias = compute_grad_bias(grad_output_, output_mask);
 output_mask[2] = false;

 // TODO: A little bird says that unfold + matmul is actually faster than
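
In the early-return path above, groups is a c10::SymInt taken by value; in the symbolic case a SymInt owns a refcounted node, so forwarding it with std::move at its last use skips one refcount bump, which is presumably why the NOLINT(performance-unnecessary-value-param) in the previous hunk could be dropped. The same reasoning drops const from grad_bias so it can be moved into the final tuple later. A rough illustration, with std::shared_ptr<int> standing in for the refcounted payload and consume() standing in for the callee:

    #include <memory>
    #include <utility>

    // Made-up sink modeling at::convolution_backward_symint's by-value `groups`.
    void consume(std::shared_ptr<int> groups) { (void)groups; }

    // Old shape: passing the by-value parameter as an lvalue copies it (refcount bump).
    void forward_old(std::shared_ptr<int> groups) {
      consume(groups);
    }

    // New shape: this is the last use of `groups`, so hand the handle over instead.
    void forward_new(std::shared_ptr<int> groups) {
      consume(std::move(groups));
    }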
@@ -408,14 +407,14 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
 grad_output, input, weight, std::nullopt, stride, padding, dilation,
 transposed, output_padding, batch_size * groups, output_mask);
 // N(BI), (BO)I -> NBI, BOI
-const auto grad_input = output_mask[0] ?
+auto grad_input = output_mask[0] ?
 reshape_dim_outof(1, batch_size, std::get<0>(result)) : Tensor();
-const auto grad_weight = output_mask[1] ?
+auto grad_weight = output_mask[1] ?
 reshape_dim_outof(0, batch_size, std::get<1>(result)) : Tensor();
 return std::make_tuple(
-output_mask[0] ? makeBatched(grad_input, 1, cur_level) : grad_input,
-output_mask[1] ? makeBatched(grad_weight, 0, cur_level) : grad_weight,
-grad_bias);
+output_mask[0] ? makeBatched(std::move(grad_input), 1, cur_level) : std::move(grad_input),
+output_mask[1] ? makeBatched(std::move(grad_weight), 0, cur_level) : std::move(grad_weight),
+std::move(grad_bias));
 }

 Tensor grad_input;
@@ -426,7 +425,7 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
 input, input_bdim,
 weight, weight_bdim,
 stride, padding, dilation, transposed, output_padding, groups);
-grad_input = makeBatched(tensor, bdim, cur_level);
+grad_input = makeBatched(std::move(tensor), bdim, cur_level);
 }

 Tensor grad_weight;
@@ -437,9 +436,9 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
 input, input_bdim,
 weight, weight_bdim,
 stride, padding, dilation, transposed, output_padding, groups);
-grad_weight = makeBatched(tensor, bdim, cur_level);
+grad_weight = makeBatched(std::move(tensor), bdim, cur_level);
 }
-return std::make_tuple(grad_input, grad_weight, grad_bias);
+return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias));

 // Someone's definitely going to find a problem with this batching rule so
 // I'm leaving the following fallback if we need it back.
@@ -121,7 +121,7 @@ batch_norm_batch_rule(
 result0 = result0 + bias_;
 }
 result0 = result0.transpose(1, 2); // [B0, B, C, *], because some arg must have been batched, the output must be batched
-return std::make_tuple(result0, 0, mean, stats_bdim, rstd, stats_bdim);
+return std::make_tuple(std::move(result0), 0, std::move(mean), stats_bdim, std::move(rstd), stats_bdim);
 }

 template<typename F, F Func>
@@ -727,8 +727,7 @@ struct LSTMCell : Cell<std::tuple<Tensor, Tensor>, cell_params> {
 const hidden_type& hidden,
 const cell_params& params,
 bool pre_compute_input = false) const override {
-const auto& hx = std::get<0>(hidden);
-const auto& cx = std::get<1>(hidden);
+const auto& [hx, cx] = hidden;

 if (input.is_cuda() || input.is_xpu() || input.is_privateuseone()) {
 TORCH_CHECK(!pre_compute_input);
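
hx and cx were already const references into the hidden tuple, so this hunk saves lines rather than copies: a structured binding with const auto& introduces both names at once and still binds to the tuple's own elements. A small self-contained check of that property:

    #include <cassert>
    #include <string>
    #include <tuple>

    int main() {
      const std::tuple<std::string, std::string> hidden{"hx", "cx"};

      // One declaration, two reference-like names; no copies are made.
      const auto& [hx, cx] = hidden;

      assert(&hx == &std::get<0>(hidden));
      assert(&cx == &std::get<1>(hidden));
      return 0;
    }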
@@ -1052,10 +1052,7 @@ static inline void grid_sample_2d_grid_slice_iterator(
 std::min(step, len * 2));
 auto vec2 = Vec::loadu(grid_ptr + grid_offset + step,
 std::max(static_cast<int64_t>(0), len * 2 - step));
-auto vec_xy_pair = deinterleave2(vec1, vec2);
-
-auto x = std::get<0>(vec_xy_pair);
-auto y = std::get<1>(vec_xy_pair);
+auto [x, y] = deinterleave2(vec1, vec2);

 // make sure that x and y are valid grid sample locations
 if (len < step) {
@@ -67,13 +67,10 @@ __global__ void distribution_elementwise_grid_stride_kernel(int64_t numel,
 PhiloxCudaState philox_args,
 const dist_t dist_func,
 const transform_t transform_func) {
-auto seeds = at::cuda::philox::unpack(philox_args);
+auto [seed, offset] = at::cuda::philox::unpack(philox_args);
 int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
 curandStatePhilox4_32_10_t state;
-curand_init(std::get<0>(seeds),
-idx,
-std::get<1>(seeds),
-&state);
+curand_init(seed, idx, offset, &state);

 int64_t rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
 blockDim.x * gridDim.x * unroll_factor;
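
at::cuda::philox::unpack returns a pair of trivially copyable 64-bit values, so the std::get calls here were never expensive; the gain in this and the following dropout-kernel hunks is that the components get names and the four-line curand_init call collapses to one. A host-side sketch, where unpack_stub and init_stub are made-up stand-ins for at::cuda::philox::unpack and curand_init:

    #include <cstdint>
    #include <cstdio>
    #include <tuple>

    // Made-up stand-in for at::cuda::philox::unpack(philox_args).
    std::tuple<uint64_t, uint64_t> unpack_stub() { return {42, 7}; }

    // Made-up stand-in for curand_init(seed, subsequence, offset, &state).
    void init_stub(uint64_t seed, uint64_t idx, uint64_t offset) {
      std::printf("seed=%llu idx=%llu offset=%llu\n",
                  static_cast<unsigned long long>(seed),
                  static_cast<unsigned long long>(idx),
                  static_cast<unsigned long long>(offset));
    }

    int main() {
      uint64_t idx = 0;

      // Old shape of the code: keep the tuple, index into it at each use.
      auto seeds = unpack_stub();
      init_stub(std::get<0>(seeds), idx, std::get<1>(seeds));

      // New shape: one structured binding, one readable call.
      auto [seed, offset] = unpack_stub();
      init_stub(seed, idx, offset);
      return 0;
    }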
@@ -56,13 +56,10 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>
 using LoadT = memory::aligned_vector<scalar_t, VEC>;
 using MaskLoadT = memory::aligned_vector<mask_t, VEC>;

-auto seeds = at::cuda::philox::unpack(philox_args);
+auto [seed, offset] = at::cuda::philox::unpack(philox_args);
 IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
 curandStatePhilox4_32_10_t state;
-curand_init(std::get<0>(seeds),
-idx,
-std::get<1>(seeds),
-&state);
+curand_init(seed, idx, offset, &state);

 // Helps align the total number of times curand_uniform4 is called by each thread for the same totalElements
 // in the vec=2 and vec=4 cases.
@@ -138,13 +135,10 @@ fused_dropout_kernel(cuda::detail::TensorInfo<const scalar_t, IndexType> a,
 cuda::detail::TensorInfo<mask_t, IndexType> c,
 IndexType totalElements, accscalar_t p,
 PhiloxCudaState philox_args) {
-auto seeds = at::cuda::philox::unpack(philox_args);
+auto [seed, offset] = at::cuda::philox::unpack(philox_args);
 IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
 curandStatePhilox4_32_10_t state;
-curand_init(std::get<0>(seeds),
-idx,
-std::get<1>(seeds),
-&state);
+curand_init(seed, idx, offset, &state);
 accscalar_t scale = 1.0 / p;

 IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) *
@@ -39,10 +39,10 @@ static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self
 // Cannot reassign to mask_temp and self_temp here! if they are
 // owning and expand_outplace returns a borrow, the returned borrow
 // would dangle.
-auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
+auto [mask_expanded, self_expanded] = expand_outplace(*mask_temp, *self_temp);
 at::cuda::index_out(
-result, *std::get<1>(mask_self_expanded),
-c10::List<std::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}));
+result, *self_expanded,
+c10::List<std::optional<at::Tensor>>({*std::move(mask_expanded)}));

 return result;
 }
@@ -25,7 +25,7 @@ __global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T

 // do random permutation inside each island.
 data += tid;
-auto [seed, offset] = at::cuda::philox::unpack(philox_args);
+const auto [seed, offset] = at::cuda::philox::unpack(philox_args);
 curandStatePhilox4_32_10_t state;
 curand_init(seed, tid, offset, &state);
 for (int i = island_size - 1; i > 0; i--) {
@@ -81,11 +81,7 @@ inline void _rrelu_with_noise_cuda_train(

 int64_t numel = input.numel();
 const int unroll_factor = std::is_same_v<scalar_t, double> ? 2 : 4;
-auto execution_policy = calc_execution_policy(numel, unroll_factor);
-
-auto counter_offset = std::get<0>(execution_policy);
-auto grid = std::get<1>(execution_policy);
-auto block = std::get<2>(execution_policy);
+auto [counter_offset, grid, block] = calc_execution_policy(numel, unroll_factor);

 auto gen = get_generator_or_default<CUDAGeneratorImpl>(
 generator, cuda::detail::getDefaultCUDAGenerator());
@@ -2560,9 +2560,10 @@ std::pair<Tensor, hidden_type> _cudnn_impl(
 dropout_state.buffer);

 return {
-std::get<0>(cudnn_output),
+std::move(std::get<0>(cudnn_output)),
 pack_hidden<hidden_type>(
-std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
+std::move(std::get<1>(cudnn_output)),
+std::move(std::get<2>(cudnn_output)))};
 }

 template <typename hidden_type>
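
Here cudnn_output is a named local tuple whose three Tensors are each used exactly once while assembling the return value; std::get<i> on that lvalue yields an lvalue reference, so the old code copied each handle, while wrapping the access in std::move turns it into an rvalue that the pair and pack_hidden can steal. A sketch with std::string elements and a made-up pack() standing in for pack_hidden:

    #include <string>
    #include <tuple>
    #include <utility>

    // Made-up stand-in for pack_hidden<hidden_type>(h, c).
    std::pair<std::string, std::string> pack(std::string h, std::string c) {
      return {std::move(h), std::move(c)};
    }

    using Out = std::tuple<std::string, std::string, std::string>;

    // Old shape: std::get<i>(out) is an lvalue, so every element is copied.
    std::pair<std::string, std::pair<std::string, std::string>> repack_old(Out out) {
      return {std::get<0>(out), pack(std::get<1>(out), std::get<2>(out))};
    }

    // New shape: each element has exactly one remaining use; move it into place.
    std::pair<std::string, std::pair<std::string, std::string>> repack_new(Out out) {
      return {std::move(std::get<0>(out)),
              pack(std::move(std::get<1>(out)), std::move(std::get<2>(out)))};
    }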
@@ -2621,9 +2622,10 @@ std::pair<Tensor, hidden_type> _cudnn_impl(
 dropout_state.buffer);

 return {
-std::get<0>(cudnn_output),
+std::move(std::get<0>(cudnn_output)),
 pack_hidden<hidden_type>(
-std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
+std::move(std::get<1>(cudnn_output)),
+std::move(std::get<2>(cudnn_output)))};
 }

 #define ONE_HIDDEN_RNN(NAME, MODE) \
@@ -144,7 +144,6 @@ Tensor& _index_put_impl_quantized_cpu_(Tensor & self, const torch::List<std::opt
 value_ = value.to(self.device());
 }
 at::assert_no_overlap(self, value);
-// NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
 for (const std::optional<Tensor>& index: indices) {
 if (index.has_value()) {
 at::assert_no_overlap(self, *index);
@@ -400,7 +400,7 @@ register_conv_params() {
 },
 // __setstate__ takes c10::IValue because we support parsing historical
 // serialization versions.
-[](c10::IValue v)
+[](const c10::IValue& v)
 -> c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> { // __setstate__
 ConvParamsSerializationTypeV3 state = parse_conv_serialized_state<kSpatialDim>(v);
 return deserialize_conv<kSpatialDim>(state);
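
c10::IValue is a boxed value whose copy can mean an atomic refcount bump (it may hold a Tensor or another intrusive object), so taking it by value in the __setstate__ lambda copied the argument on every deserialization; the lambda only reads v, so const c10::IValue& is enough. The same guideline in miniature, with std::string in place of c10::IValue:

    #include <string>

    // Old shape: by-value parameter, so every call copies the argument.
    auto setstate_old = [](std::string v) { return v.size(); };

    // New shape: the body only reads `v`, so bind it by const reference instead.
    auto setstate_new = [](const std::string& v) { return v.size(); };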
@@ -43,16 +43,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 constexpr int kHeadDim = Kernel_traits::kHeadDim;
 constexpr int kNWarps = Kernel_traits::kNWarps;

-auto seed_offset = at::cuda::philox::unpack(params.philox_args);
-pytorch_flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
+auto [seed, offset] = at::cuda::philox::unpack(params.philox_args);
+pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t,
 bidb, bidh, tidx, params.h);

 // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
 // exit early and no one saves the rng state.
 if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
 if (params.philox_args.captured_) {
-*params.seed = std::get<0>(seed_offset);
-*params.extragraph_offset = std::get<1>(seed_offset);
+*params.seed = seed;
+*params.extragraph_offset = offset;
 }
 }

@@ -29,7 +29,7 @@ class C10_CUDA_API FreeMemoryCallback {

 C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
 #define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \
-C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__);
+C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__)
 } // namespace c10
 //
 // TODO: Turn this into an honest to goodness class. I briefly attempted to do
@@ -261,6 +261,6 @@ bool CudaIPCCollect() {

 namespace c10 {
 namespace {
-REGISTER_FREE_MEMORY_CALLBACK("cuda_ipc_collect", CudaIPCCollectCallback);
+REGISTER_FREE_MEMORY_CALLBACK("cuda_ipc_collect", CudaIPCCollectCallback)
 }
 } // namespace c10
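
The REGISTER_FREE_MEMORY_CALLBACK hunks (and the `};` to `}` hunk that follows) are semicolon hygiene rather than copy avoidance: C10_REGISTER_CLASS appears to expand to a complete declaration that already ends in a semicolon, so the extra `;` in the wrapper macro and at the call site each left behind an empty declaration of the kind -Wextra-semi-style checks flag. A rough sketch of the convention, with made-up macro and callback names:

    #include <cstdio>

    // Inner macro: expands to a complete declaration, trailing semicolon included
    // (modeled on how C10_REGISTER_CLASS appears to behave; names are made up).
    #define INNER_REGISTER(name, fn) \
      static const bool registered_##fn = ((void)std::puts(name), true);

    // Wrapper macro: forwards without appending a second semicolon...
    #define REGISTER_CALLBACK(name, fn) INNER_REGISTER(name, fn)

    // ...and the use site does not add one either; an extra ';' here would be an
    // empty declaration that -Wextra-semi / clang-tidy style checks warn about.
    REGISTER_CALLBACK("demo_callback", DemoCallback)

    int main() { return 0; }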
@@ -174,6 +174,6 @@ TORCH_LIBRARY(cuda, m) {
 .def("record", &CUDAEvent::record)
 .def("synchronize", &CUDAEvent::synchronize)
 .def("wait", &CUDAEvent::wait);
-};
+}

 } // namespace torch::jit