mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
xe: concat: add ext_idx calculation for blocked write=1 cases
This commit is contained in:
@ -140,6 +140,11 @@ static status_t normalize(simple_params_t &conf,
|
||||
rt_conf.gws_d[1] = extern_axis / rt_conf.read_overlap;
|
||||
rt_conf.gws_d[2] = concat_dim_size;
|
||||
|
||||
const bool is_blocked_write1 = (conf.n_blocks
|
||||
&& conf.write_block <= static_cast<dim_t>(data_type_size));
|
||||
VDISPATCH_CONCAT_IC(IMPLICATION(is_blocked_write1, data_type_size <= 2),
|
||||
VERBOSE_BLOCKING_FAIL, "unsupported blocked write1 data type");
|
||||
|
||||
// Lots of zero padding byte writes -- very costly in this kernel
|
||||
const bool padding_ok = !(conf.write_block * conf.data_type_size == 1
|
||||
&& 4 * dst_md.dims[axis::concat]
|
||||
@ -399,9 +404,12 @@ void push_idx_kernel_args(compute::kernel_arg_list_t &partial_list,
|
||||
|
||||
// Workgroup reads may extend past the concat dimension, so we must also
|
||||
// consider the external axis when computing write indices
|
||||
const bool is_blocked_write1 = (conf.n_blocks
|
||||
&& conf.write_block <= static_cast<dim_t>(conf.data_type_size));
|
||||
|
||||
bool must_compute_ext_idx
|
||||
= (rt_conf.read_overlap * rt_conf.gws0_block > rt_conf.inner_axis)
|
||||
|| cutoff;
|
||||
|| cutoff || is_blocked_write1;
|
||||
partial_list.append(static_cast<std::uint8_t>(must_compute_ext_idx));
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user