Add meta registration for _foreach_norm (#118604)
This PR also fixes the discrepancy between the _foreach_norm fast path and slow path, where storage_offsets would differ between the lists of result tensors. Here are some profile results showing that we aren't significantly slower. Do note that we're replacing N `as_strided`/`select` calls with N `empty` calls.

For the following script:

```
import torch

ts = [torch.rand(32, 16, device="cuda") for _ in range(128)]

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as p:
    res = torch._foreach_norm(ts)

print(p.key_averages().table(sort_by="cpu_time_total"))
```

OG baseline:

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (7cf98987)]$ python playground2.py
STAGE:2024-01-30 13:16:48 2740431:2740431 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-01-30 13:16:48 2740431:2740431 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-01-30 13:16:48 2740431:2740431 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
Name                                                      Self CPU %  Self CPU   CPU total %  CPU total  CPU time avg  Self CUDA  Self CUDA %  CUDA total  CUDA time avg  # of Calls
aten::_foreach_norm                                       25.36%      4.209ms    99.94%       16.586ms   16.586ms      8.000us    88.89%       9.000us     9.000us        1
cudaLaunchKernel                                          61.21%      10.159ms   61.21%       10.159ms   2.540ms       0.000us    0.00%        0.000us     0.000us        4
aten::zeros                                               0.43%       71.000us   58.35%       9.683ms    9.683ms       0.000us    0.00%        1.000us     1.000us        1
aten::zero_                                               0.33%       55.000us   57.35%       9.517ms    9.517ms       0.000us    0.00%        1.000us     1.000us        1
aten::fill_                                               0.42%       69.000us   57.01%       9.462ms    9.462ms       1.000us    11.11%       1.000us     1.000us        1
aten::select                                              8.04%       1.335ms    11.29%       1.873ms    14.633us      0.000us    0.00%        0.000us     0.000us        128
aten::as_strided                                          3.24%       538.000us  3.24%        538.000us  4.203us       0.000us    0.00%        0.000us     0.000us        128
aten::empty                                               0.90%       150.000us  0.90%        150.000us  75.000us      0.000us    0.00%        0.000us     0.000us        2
cudaDeviceSynchronize                                     0.06%       10.000us   0.06%        10.000us   10.000us      0.000us    0.00%        0.000us     0.000us        1
void at::native::vectorized_elementwise_kernel<4, at...   0.00%       0.000us    0.00%        0.000us    0.000us       1.000us    11.11%       1.000us     1.000us        1
void at::native::(anonymous namespace)::multi_tensor...   0.00%       0.000us    0.00%        0.000us    0.000us       6.000us    66.67%       6.000us     3.000us        2
void at::native::lpnorm_cleanup<float, (at::native::...   0.00%       0.000us    0.00%        0.000us    0.000us       2.000us    22.22%       2.000us     2.000us        1

Self CPU time total: 16.596ms
Self CUDA time total: 9.000us
```

And here's after this PR:

```
STAGE:2024-02-05 08:27:02 1127843:1127843 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-05 08:27:02 1127843:1127843 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-05 08:27:02 1127843:1127843 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
Name                                                      Self CPU %  Self CPU   CPU total %  CPU total  CPU time avg  Self CUDA  Self CUDA %  CUDA total  CUDA time avg  # of Calls
aten::_foreach_norm                                       30.95%      4.653ms    99.95%       15.026ms   15.026ms      9.000us    90.00%       10.000us    10.000us       1
cudaLaunchKernel                                          52.41%      7.879ms    52.41%       7.879ms    1.970ms       0.000us    0.00%        0.000us     0.000us        4
aten::zeros                                               0.39%       58.000us   48.29%       7.260ms    7.260ms       0.000us    0.00%        1.000us     1.000us        1
aten::zero_                                               0.35%       53.000us   47.25%       7.103ms    7.103ms       0.000us    0.00%        1.000us     1.000us        1
aten::fill_                                               0.43%       65.000us   46.90%       7.050ms    7.050ms       1.000us    10.00%       1.000us     1.000us        1
aten::empty                                               15.42%      2.318ms    15.42%       2.318ms    17.969us      0.000us    0.00%        0.000us     0.000us        129
cudaDeviceSynchronize                                     0.05%       7.000us    0.05%        7.000us    7.000us       0.000us    0.00%        0.000us     0.000us        1
void at::native::vectorized_elementwise_kernel<4, at...   0.00%       0.000us    0.00%        0.000us    0.000us       1.000us    10.00%       1.000us     1.000us        1
void at::native::(anonymous namespace)::multi_tensor...   0.00%       0.000us    0.00%        0.000us    0.000us       6.000us    60.00%       6.000us     3.000us        2
void at::native::lpnorm_cleanup<float, (at::native::...   0.00%       0.000us    0.00%        0.000us    0.000us       3.000us    30.00%       3.000us     3.000us        1

Self CPU time total: 15.033ms
Self CUDA time total: 10.000us
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118604
Approved by: https://github.com/albanD
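For intuition, here is a small sketch (not part of the PR; it uses plain `torch.empty`/`select` on CPU rather than the CUDA fast path) of the output-layout difference described above: results carved out of one length-N buffer carry increasing storage offsets, while N independently allocated 0-dim tensors all report offset 0.

```
import torch

n = 4

# Old fast-path layout, schematically: one length-n buffer plus n select() views,
# so result i reports storage_offset() == i.
buf = torch.empty(n)
old_style = [buf.select(0, i) for i in range(n)]
print([t.storage_offset() for t in old_style])  # [0, 1, 2, 3]

# New fast-path layout, schematically: n independent scalar tensors from
# n empty() calls, so every result reports storage_offset() == 0.
new_style = [torch.empty(()) for _ in range(n)]
print([t.storage_offset() for t in new_style])  # [0, 0, 0, 0]
```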
Committed by: PyTorch MergeBot
Parent: 51e096114b
Commit: b8bb12cd45
```
@@ -112,7 +112,7 @@ template <
     typename opmath_t = at::opmath_type<T>>
 __global__ void lpnorm_cleanup(
     const opmath_t* output_per_tensor,
-    T* ret_per_tensor,
+    TensorListMetadata<1> vec_res_meta,
     int max_chunks_per_tensor) {
   __shared__ opmath_t vals[512];
 
@@ -130,7 +130,7 @@ __global__ void lpnorm_cleanup(
           ? at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals)
           : at::native::cuda_utils::BlockReduceMax(val, vals);
   if (threadIdx.x == 0) {
-    ret_per_tensor[blockIdx.x] =
+    *(T*)vec_res_meta.addresses[0][blockIdx.x] =
         norm_type == NormType::L1 || norm_type == NormType::LInf
         ? final_val
         : ::sqrt(final_val);
@@ -180,7 +180,12 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
   auto output_per_tensor = at::zeros(
       {ntensors * max_chunks_per_tensor},
       options.dtype(toOpMathType(tensors[0].scalar_type())));
-  auto ret_per_tensor = at::empty({ntensors}, options);
 
+  std::vector<at::Tensor> vec_res;
+  vec_res.reserve(ntensors);
+  for (int i = 0; i < ntensors; i++) {
+    vec_res.push_back(at::empty({}, options));
+  }
+
   auto tensor_lists = std::vector<std::vector<Tensor>>{tensors.vec()};
   if (p == static_cast<double>(1)) {
@@ -200,9 +205,14 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
         const at::cuda::OptionalCUDAGuard device_guard(
             device_of(output_per_tensor));
         auto stream = at::cuda::getCurrentCUDAStream();
+        TensorListMetadata<1> vecResMeta;
+        for (int i = 0; i < ntensors; i++) {
+          vecResMeta.addresses[0][i] =
+              vec_res[i].mutable_data_ptr<scalar_t>();
+        }
         lpnorm_cleanup<scalar_t, NormType::L1><<<ntensors, 512, 0, stream>>>(
             output_per_tensor.const_data_ptr<opmath_t>(),
-            ret_per_tensor.mutable_data_ptr<scalar_t>(),
+            vecResMeta,
             max_chunks_per_tensor);
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
@@ -223,9 +233,14 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
         const at::cuda::OptionalCUDAGuard device_guard(
             device_of(output_per_tensor));
         auto stream = at::cuda::getCurrentCUDAStream();
+        TensorListMetadata<1> vecResMeta;
+        for (int i = 0; i < ntensors; i++) {
+          vecResMeta.addresses[0][i] =
+              vec_res[i].mutable_data_ptr<scalar_t>();
+        }
         lpnorm_cleanup<scalar_t, NormType::L2><<<ntensors, 512, 0, stream>>>(
             output_per_tensor.const_data_ptr<opmath_t>(),
-            ret_per_tensor.mutable_data_ptr<scalar_t>(),
+            vecResMeta,
             max_chunks_per_tensor);
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
@@ -246,10 +261,15 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
         const at::cuda::OptionalCUDAGuard device_guard(
             device_of(output_per_tensor));
         auto stream = at::cuda::getCurrentCUDAStream();
+        TensorListMetadata<1> vecResMeta;
+        for (int i = 0; i < ntensors; i++) {
+          vecResMeta.addresses[0][i] =
+              vec_res[i].mutable_data_ptr<scalar_t>();
+        }
         lpnorm_cleanup<scalar_t, NormType::LInf>
             <<<ntensors, 512, 0, stream>>>(
                 output_per_tensor.const_data_ptr<opmath_t>(),
-                ret_per_tensor.mutable_data_ptr<scalar_t>(),
+                vecResMeta,
                 max_chunks_per_tensor);
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
@@ -267,7 +287,7 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
   int i = 0;
   for (const auto& t : tensors) {
    if (t.numel() != 0) {
-      result.emplace_back(ret_per_tensor[i]);
+      result.emplace_back(vec_res[i]);
      i++;
    } else {
      result.emplace_back(at::zeros({}, options));
```
```
@@ -401,22 +401,24 @@ def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
         if not isinstance(r, torch.Tensor):
             continue
         test_assert(isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor")
-        test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}")
-        test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}")
+        test_assert(meta_r.dtype == r.dtype, f"for element {i}, was {meta_r.dtype} but real dtype was {r.dtype}")
+        test_assert(meta_r.shape == r.shape, f"for element {i}, was {meta_r.shape} but real shape was {r.shape}")
         # See https://github.com/pytorch/pytorch/issues/78050
         if should_check_strides(func) == CheckStrides.ALL:
             same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
-            test_assert(same_strides, f"but real stride was {r.stride()}")
+            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
         elif should_check_strides(func) == CheckStrides.SIGNIFICANT:
             same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
-            test_assert(same_strides, f"but real stride was {r.stride()}")
+            test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
         test_assert(
             meta_r.storage_offset() == r.storage_offset(),
-            f"but real storage_offset was {r.storage_offset()}")
-        test_assert(meta_r.requires_grad == r.requires_grad, f"but real requires_grad was {r.requires_grad}")
+            f"for element {i}, was {meta_r.storage_offset()} but real storage_offset was {r.storage_offset()}")
+        test_assert(meta_r.requires_grad == r.requires_grad,
+                    f"for element {i}, was {meta_r.requires_grad} but real requires_grad was {r.requires_grad}")
         if func not in CHECK_CONJ_SKIPS:
-            test_assert(meta_r.is_conj() == r.is_conj(), f"but real is_conj was {r.is_conj()}")
-            test_assert(meta_r.is_neg() == r.is_neg(), f"but real is_neg was {r.is_neg()}")
+            test_assert(meta_r.is_conj() == r.is_conj(),
+                        f"for element {i}, was {meta_r.is_conj()} but real is_conj was {r.is_conj()}")
+            test_assert(meta_r.is_neg() == r.is_neg(), f"for element {i}, was {meta_r.is_neg()} but real is_neg was {r.is_neg()}")
 
 
 # This environment variable controls whether or not we print expected failure
```
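The per-element checks above can be reproduced directly from Python. A minimal sketch (assuming a CUDA build; this is not the actual test, which checks strides, conj/neg bits, and more) comparing the meta results of `_foreach_norm` against the eager results element by element:

```
import torch

ts = [torch.rand(32, 16, device="cuda") for _ in range(8)]
real = torch._foreach_norm(ts)
meta = torch._foreach_norm([t.to("meta") for t in ts])

for i, (m, r) in enumerate(zip(meta, real)):
    # The new meta registration should agree with the CUDA fast path on all of
    # these, including storage_offset() now that each result is its own tensor.
    assert m.dtype == r.dtype, f"element {i}: {m.dtype} vs {r.dtype}"
    assert m.shape == r.shape, f"element {i}: {m.shape} vs {r.shape}"
    assert m.storage_offset() == r.storage_offset(), f"element {i}"
```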
```
@@ -3070,6 +3070,7 @@ def register_meta_foreach(ops):
         aten._foreach_log1p,
         aten._foreach_log2,
         aten._foreach_neg,
+        aten._foreach_norm,
         aten._foreach_reciprocal,
         aten._foreach_round,
         aten._foreach_sigmoid,
```
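For context, the generic foreach meta path that `aten._foreach_norm` is added to above only propagates shapes and dtypes, returning one result per input. A rough illustration of what that amounts to for a reduction like `_foreach_norm` (the function name below is hypothetical and dtype-promotion details are ignored; this is not the actual `_meta_registrations` code):

```
import torch

def foreach_norm_meta_sketch(tensors, ord=2):
    # Shape/dtype propagation only: one 0-dim result per input, no data computed.
    return [torch.empty((), dtype=t.dtype, device="meta") for t in tensors]

inputs = [torch.empty(32, 16, device="meta") for _ in range(4)]
outs = foreach_norm_meta_sketch(inputs)
print([o.shape for o in outs])  # four results, each torch.Size([])
```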
```
@@ -9688,11 +9688,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
             DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
-            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
         ),
     ),
 ]
```