[ROCm] Improve perf for elementwise broadcast with mixed dtype (#163562)
* Unroll loops manually to hide memory access latency

Co-author: @amd-hhashemi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163562
Approved by: https://github.com/jeffdaily
committed by PyTorch MergeBot
parent fde929c8a8
commit 2aadcea05c
@@ -999,12 +999,41 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
       dtypes[i] = iter.dtype(i);
     }
     auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
+#ifdef USE_ROCM
+    constexpr int grp_sz = 128;
+    launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
+      if (unrl) {
+        auto offsets0 = offset_calc.get(idx);
+        auto offsets1 = offset_calc.get(idx + grp_sz);
+        auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+        auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+        void* out0 = data[0] + offsets0[0];
+        void* out1 = data[0] + offsets1[0];
+        void* out2 = data[0] + offsets2[0];
+        void* out3 = data[0] + offsets3[0];
+        arg0_t result0 = invoke(f, &data[1], &offsets0[1], &dtypes[1], 1);
+        arg0_t result1 = invoke(f, &data[1], &offsets1[1], &dtypes[1], 1);
+        arg0_t result2 = invoke(f, &data[1], &offsets2[1], &dtypes[1], 1);
+        arg0_t result3 = invoke(f, &data[1], &offsets3[1], &dtypes[1], 1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out0, result0);
+        c10::cast_and_store<arg0_t>(dtypes[0], out1, result1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out2, result2);
+        c10::cast_and_store<arg0_t>(dtypes[0], out3, result3);
+      } else {
+        auto offsets = offset_calc.get(idx);
+        void* out = data[0] + offsets[0];
+        arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+      }
+    });
+#else
     launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
       auto offsets = offset_calc.get(idx);
       void* out = data[0] + offsets[0];
       arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
       c10::cast_and_store<arg0_t>(dtypes[0], out, result);
     });
+#endif
   }
 }
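The launcher itself is outside this hunk. As a rough sketch of the shape such a launcher can take, the following hypothetical CUDA/HIP code matches the lambda signature used above, f(int idx, bool unrl): each thread covers 4 elements spaced one block (grp_sz = 128 threads) apart, takes the fully unrolled path only when all four indices are in range, and otherwise falls back to the guarded one-element path. The kernel name and the exact launch configuration here are assumptions for illustration, not the actual launch_legacy_kernel_manual_unroll implementation.

    // Hypothetical sketch of a manual-unroll launcher (names and launch details
    // assumed; not the actual PyTorch implementation).
    #include <cuda_runtime.h>
    #include <cstdint>

    template <int nt, int vt, typename func_t>
    __global__ void elementwise_kernel_manual_unroll(int N, func_t f) {
      // Each block covers nt * vt elements; a thread's vt elements are spaced
      // nt apart, matching idx, idx + nt, ..., idx + (vt - 1) * nt above.
      int idx = nt * vt * blockIdx.x + threadIdx.x;
      if (idx + (vt - 1) * nt < N) {
        // All vt elements in range: the lambda issues its loads, computes,
        // and stores back to back, overlapping the global-memory latency.
        f(idx, true);
      } else {
        // Tail block: process one element at a time with a bounds check.
    #pragma unroll
        for (int i = 0; i < vt; i++) {
          if (idx < N) {
            f(idx, false);
          }
          idx += nt;
        }
      }
    }

    template <int nt, int vt, typename func_t>
    void launch_legacy_kernel_manual_unroll(int64_t N, const func_t& f) {
      if (N == 0) {
        return;
      }
      const dim3 block(nt);
      const dim3 grid((N + nt * vt - 1) / (nt * vt));
      elementwise_kernel_manual_unroll<nt, vt>
          <<<grid, block>>>(static_cast<int>(N), f);
    }

The payoff comes from the unrl branch in the diff: issuing the four offset_calc.get and invoke calls with no intervening branches lets the compiler schedule the four sets of global loads before any of their results are consumed, which is the latency hiding the commit message refers to.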