xe: concat: add ext_idx calculation for blocked write=1 cases

This commit is contained in:
syurkevi
2025-10-16 12:15:23 -07:00
parent 3a756b982b
commit 540067c082

View File

@ -140,6 +140,11 @@ static status_t normalize(simple_params_t &conf,
rt_conf.gws_d[1] = extern_axis / rt_conf.read_overlap;
rt_conf.gws_d[2] = concat_dim_size;
const bool is_blocked_write1 = (conf.n_blocks
&& conf.write_block <= static_cast<dim_t>(data_type_size));
VDISPATCH_CONCAT_IC(IMPLICATION(is_blocked_write1, data_type_size <= 2),
VERBOSE_BLOCKING_FAIL, "unsupported blocked write1 data type");
// Lots of zero padding byte writes -- very costly in this kernel
const bool padding_ok = !(conf.write_block * conf.data_type_size == 1
&& 4 * dst_md.dims[axis::concat]
@ -399,9 +404,12 @@ void push_idx_kernel_args(compute::kernel_arg_list_t &partial_list,
// Workgroup reads may extend past the concat dimension, so we must also
// consider the external axis when computing write indices
const bool is_blocked_write1 = (conf.n_blocks
&& conf.write_block <= static_cast<dim_t>(conf.data_type_size));
bool must_compute_ext_idx
= (rt_conf.read_overlap * rt_conf.gws0_block > rt_conf.inner_axis)
|| cutoff;
|| cutoff || is_blocked_write1;
partial_list.append(static_cast<std::uint8_t>(must_compute_ext_idx));
}