Add meta registration for _foreach_norm (#118604)

This PR also fixes a discrepancy between the _foreach_norm fast path and slow path, where the storage_offsets differed between the lists of tensors returned by the two paths. The profile results below show that we aren't significantly slower. Do note that we're replacing N `as_strided`/`select` calls with N `empty` calls.

For script:
```
import torch

# Input list for the foreach op: 128 random 32x16 tensors created on the CUDA device.
ts = [torch.rand(32, 16, device="cuda") for _ in range(128)]

# Record both CPU and CUDA activity for the single _foreach_norm call below.
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as p:
    res = torch._foreach_norm(ts)
# Print the averaged per-op table, sorted by the "cpu_time_total" key.
print(p.key_averages().table(sort_by="cpu_time_total"))
```

Original baseline (before this PR):
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (7cf98987)]$ python playground2.py
STAGE:2024-01-30 13:16:48 2740431:2740431 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-01-30 13:16:48 2740431:2740431 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-01-30 13:16:48 2740431:2740431 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                    aten::_foreach_norm        25.36%       4.209ms        99.94%      16.586ms      16.586ms       8.000us        88.89%       9.000us       9.000us             1
                                       cudaLaunchKernel        61.21%      10.159ms        61.21%      10.159ms       2.540ms       0.000us         0.00%       0.000us       0.000us             4
                                            aten::zeros         0.43%      71.000us        58.35%       9.683ms       9.683ms       0.000us         0.00%       1.000us       1.000us             1
                                            aten::zero_         0.33%      55.000us        57.35%       9.517ms       9.517ms       0.000us         0.00%       1.000us       1.000us             1
                                            aten::fill_         0.42%      69.000us        57.01%       9.462ms       9.462ms       1.000us        11.11%       1.000us       1.000us             1
                                           aten::select         8.04%       1.335ms        11.29%       1.873ms      14.633us       0.000us         0.00%       0.000us       0.000us           128
                                       aten::as_strided         3.24%     538.000us         3.24%     538.000us       4.203us       0.000us         0.00%       0.000us       0.000us           128
                                            aten::empty         0.90%     150.000us         0.90%     150.000us      75.000us       0.000us         0.00%       0.000us       0.000us             2
                                  cudaDeviceSynchronize         0.06%      10.000us         0.06%      10.000us      10.000us       0.000us         0.00%       0.000us       0.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us        11.11%       1.000us       1.000us             1
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us        66.67%       6.000us       3.000us             2
void at::native::lpnorm_cleanup<float, (at::native::...         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us        22.22%       2.000us       2.000us             1
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 16.596ms
Self CUDA time total: 9.000us
```

And here are the results after this PR:
```
STAGE:2024-02-05 08:27:02 1127843:1127843 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-05 08:27:02 1127843:1127843 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-05 08:27:02 1127843:1127843 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                    aten::_foreach_norm        30.95%       4.653ms        99.95%      15.026ms      15.026ms       9.000us        90.00%      10.000us      10.000us             1
                                       cudaLaunchKernel        52.41%       7.879ms        52.41%       7.879ms       1.970ms       0.000us         0.00%       0.000us       0.000us             4
                                            aten::zeros         0.39%      58.000us        48.29%       7.260ms       7.260ms       0.000us         0.00%       1.000us       1.000us             1
                                            aten::zero_         0.35%      53.000us        47.25%       7.103ms       7.103ms       0.000us         0.00%       1.000us       1.000us             1
                                            aten::fill_         0.43%      65.000us        46.90%       7.050ms       7.050ms       1.000us        10.00%       1.000us       1.000us             1
                                            aten::empty        15.42%       2.318ms        15.42%       2.318ms      17.969us       0.000us         0.00%       0.000us       0.000us           129
                                  cudaDeviceSynchronize         0.05%       7.000us         0.05%       7.000us       7.000us       0.000us         0.00%       0.000us       0.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us        10.00%       1.000us       1.000us             1
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us        60.00%       6.000us       3.000us             2
void at::native::lpnorm_cleanup<float, (at::native::...         0.00%       0.000us         0.00%       0.000us       0.000us       3.000us        30.00%       3.000us       3.000us             1
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 15.033ms
Self CUDA time total: 10.000us
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118604
Approved by: https://github.com/albanD
This commit is contained in:
Jane Xu
2024-02-05 07:37:20 -08:00
committed by PyTorch MergeBot
parent 51e096114b
commit b8bb12cd45
4 changed files with 38 additions and 19 deletions

View File

@ -112,7 +112,7 @@ template <
typename opmath_t = at::opmath_type<T>>
__global__ void lpnorm_cleanup(
const opmath_t* output_per_tensor,
T* ret_per_tensor,
TensorListMetadata<1> vec_res_meta,
int max_chunks_per_tensor) {
__shared__ opmath_t vals[512];
@ -130,7 +130,7 @@ __global__ void lpnorm_cleanup(
? at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals)
: at::native::cuda_utils::BlockReduceMax(val, vals);
if (threadIdx.x == 0) {
ret_per_tensor[blockIdx.x] =
*(T*)vec_res_meta.addresses[0][blockIdx.x] =
norm_type == NormType::L1 || norm_type == NormType::LInf
? final_val
: ::sqrt(final_val);
@ -180,7 +180,12 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
auto output_per_tensor = at::zeros(
{ntensors * max_chunks_per_tensor},
options.dtype(toOpMathType(tensors[0].scalar_type())));
auto ret_per_tensor = at::empty({ntensors}, options);
std::vector<at::Tensor> vec_res;
vec_res.reserve(ntensors);
for (int i = 0; i < ntensors; i++) {
vec_res.push_back(at::empty({}, options));
}
auto tensor_lists = std::vector<std::vector<Tensor>>{tensors.vec()};
if (p == static_cast<double>(1)) {
@ -200,9 +205,14 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
const at::cuda::OptionalCUDAGuard device_guard(
device_of(output_per_tensor));
auto stream = at::cuda::getCurrentCUDAStream();
TensorListMetadata<1> vecResMeta;
for (int i = 0; i < ntensors; i++) {
vecResMeta.addresses[0][i] =
vec_res[i].mutable_data_ptr<scalar_t>();
}
lpnorm_cleanup<scalar_t, NormType::L1><<<ntensors, 512, 0, stream>>>(
output_per_tensor.const_data_ptr<opmath_t>(),
ret_per_tensor.mutable_data_ptr<scalar_t>(),
vecResMeta,
max_chunks_per_tensor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
@ -223,9 +233,14 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
const at::cuda::OptionalCUDAGuard device_guard(
device_of(output_per_tensor));
auto stream = at::cuda::getCurrentCUDAStream();
TensorListMetadata<1> vecResMeta;
for (int i = 0; i < ntensors; i++) {
vecResMeta.addresses[0][i] =
vec_res[i].mutable_data_ptr<scalar_t>();
}
lpnorm_cleanup<scalar_t, NormType::L2><<<ntensors, 512, 0, stream>>>(
output_per_tensor.const_data_ptr<opmath_t>(),
ret_per_tensor.mutable_data_ptr<scalar_t>(),
vecResMeta,
max_chunks_per_tensor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
@ -246,10 +261,15 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
const at::cuda::OptionalCUDAGuard device_guard(
device_of(output_per_tensor));
auto stream = at::cuda::getCurrentCUDAStream();
TensorListMetadata<1> vecResMeta;
for (int i = 0; i < ntensors; i++) {
vecResMeta.addresses[0][i] =
vec_res[i].mutable_data_ptr<scalar_t>();
}
lpnorm_cleanup<scalar_t, NormType::LInf>
<<<ntensors, 512, 0, stream>>>(
output_per_tensor.const_data_ptr<opmath_t>(),
ret_per_tensor.mutable_data_ptr<scalar_t>(),
vecResMeta,
max_chunks_per_tensor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
@ -267,7 +287,7 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
int i = 0;
for (const auto& t : tensors) {
if (t.numel() != 0) {
result.emplace_back(ret_per_tensor[i]);
result.emplace_back(vec_res[i]);
i++;
} else {
result.emplace_back(at::zeros({}, options));

View File

@ -401,22 +401,24 @@ def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
if not isinstance(r, torch.Tensor):
continue
test_assert(isinstance(meta_r, torch.Tensor), f"but real {i}th result is Tensor")
test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}")
test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}")
test_assert(meta_r.dtype == r.dtype, f"for element {i}, was {meta_r.dtype} but real dtype was {r.dtype}")
test_assert(meta_r.shape == r.shape, f"for element {i}, was {meta_r.shape} but real shape was {r.shape}")
# See https://github.com/pytorch/pytorch/issues/78050
if should_check_strides(func) == CheckStrides.ALL:
same_strides, _ = torch._prims_common.check_all_strides(meta_r, r)
test_assert(same_strides, f"but real stride was {r.stride()}")
test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
elif should_check_strides(func) == CheckStrides.SIGNIFICANT:
same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
test_assert(same_strides, f"but real stride was {r.stride()}")
test_assert(same_strides, f"for element {i}, was {meta_r.stride()} but real stride was {r.stride()}")
test_assert(
meta_r.storage_offset() == r.storage_offset(),
f"but real storage_offset was {r.storage_offset()}")
test_assert(meta_r.requires_grad == r.requires_grad, f"but real requires_grad was {r.requires_grad}")
f"for element {i}, was {meta_r.storage_offset()} but real storage_offset was {r.storage_offset()}")
test_assert(meta_r.requires_grad == r.requires_grad,
f"for element {i}, was {meta_r.requires_grad} but real requires_grad was {r.requires_grad}")
if func not in CHECK_CONJ_SKIPS:
test_assert(meta_r.is_conj() == r.is_conj(), f"but real is_conj was {r.is_conj()}")
test_assert(meta_r.is_neg() == r.is_neg(), f"but real is_neg was {r.is_neg()}")
test_assert(meta_r.is_conj() == r.is_conj(),
f"for element {i}, was {meta_r.is_conj()} but real is_conj was {r.is_conj()}")
test_assert(meta_r.is_neg() == r.is_neg(), f"for element {i}, was {meta_r.is_neg()} but real is_neg was {r.is_neg()}")
# This environment variable controls whether or not we print expected failure

View File

@ -3070,6 +3070,7 @@ def register_meta_foreach(ops):
aten._foreach_log1p,
aten._foreach_log2,
aten._foreach_neg,
aten._foreach_norm,
aten._foreach_reciprocal,
aten._foreach_round,
aten._foreach_sigmoid,

View File

@ -9688,11 +9688,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
),
),
]