[ghstack-poisoned]
Author: ZhiweiYan-96
Date:   2025-05-23 03:28:41 +00:00
Parent: 60d2cb79ce
Commit: ba9b9162d9
3 changed files with 56 additions and 58 deletions

View File

@@ -4,12 +4,12 @@ using namespace at::native::onednn;
 namespace at::native::xpu {

-onednn::Attr unary_attr_with_arg(
+onednn::Attr& unary_attr_with_arg(
+    onednn::Attr& attr,
     std::string_view unary,
-    torch::List<std::optional<at::Scalar>> scalars =
-        torch::List<std::optional<at::Scalar>>(),
-    std::optional<std::string_view> algorithm = std::nullopt,
-    onednn::Attr attr = Attr()) {
+    torch::List<std::optional<at::Scalar>> scalars,
+    std::optional<std::string_view> algorithm
+) {
   if (unary == "hardswish") {
     return attr.append_post_eltwise(
         1.0f, 1.f / 6.f, 1.f / 2.f, attr.kind_with_hardswish);
@@ -23,6 +23,7 @@ onednn::Attr unary_attr_with_arg(
     auto beta = scalars[1].get().toOptional<at::Scalar>().value().to<float>();
     return attr.append_post_eltwise(1.0f, alpha, beta, attr.kind_with_clip);
   } else if (unary == "gelu") {
+    TORCH_CHECK(algorithm.has_value(), "GELU algorithm is not specified");
     enum dnnl::algorithm gelu_type;
     if (algorithm.value() == "none") {
       gelu_type = attr.kind_with_gelu_erf;
@@ -42,7 +43,7 @@ onednn::Attr unary_attr_with_arg(
   return attr;
 }

-onednn::Attr string_to_unary_attr(std::string_view unary, onednn::Attr attr) {
+onednn::Attr& string_to_unary_attr(onednn::Attr& attr, std::string_view unary) {
   if (unary == "relu") {
     return attr.append_post_eltwise(1.0f, 0.0f, 0.0f, attr.kind_with_relu);
   } else if (unary == "sigmoid") {
@@ -50,29 +51,25 @@ onednn::Attr string_to_unary_attr(std::string_view unary, onednn::Attr attr) {
   } else if (unary == "tanh") {
     return attr.append_post_eltwise(1.0f, 0.0f, 0.0f, attr.kind_with_tanh);
   } else if (unary == "hardswish") {
-    return unary_attr_with_arg(
-        "hardswish",
-        torch::List<std::optional<at::Scalar>>(),
-        std::nullopt,
-        attr);
+    return unary_attr_with_arg(attr, "hardswish", torch::List<std::optional<at::Scalar>>(), std::nullopt);
   } else if (unary == "swish") {
-    return unary_attr_with_arg(
-        "silu", torch::List<std::optional<at::Scalar>>(), std::nullopt, attr);
+    return unary_attr_with_arg(attr, "silu", torch::List<std::optional<at::Scalar>>(), std::nullopt);
   }
   return attr;
 }

-onednn::Attr construct_unary_attr(
+onednn::Attr& construct_unary_attr(
+    onednn::Attr& attr,
     std::string_view unary,
     torch::List<std::optional<at::Scalar>> scalars,
-    std::optional<std::string_view> algorithm,
-    Attr attr = Attr()) {
-  std::set<std::string_view> simple_unary = {
+    std::optional<std::string_view> algorithm
+) {
+  static const std::set<std::string_view> simple_unary = {
       "relu", "sigmoid", "tanh", "hardswish", "swish"};
   if (simple_unary.find(unary) != simple_unary.end()) {
-    return string_to_unary_attr(unary, attr);
+    return string_to_unary_attr(attr, unary);
   } else {
-    return unary_attr_with_arg(unary, scalars, algorithm, attr);
+    return unary_attr_with_arg(attr, unary, scalars, algorithm);
   }
 }
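The net effect in this file is that the caller now owns the oneDNN Attr and passes it in by reference as the first argument; the helpers append post-ops to it in place and return a reference instead of handing back a copy through a defaulted trailing parameter. A minimal sketch of the new call pattern, assuming only the signatures shown above (the scalar list and the algorithm string are simply forwarded to the gelu/hardtanh-style branches):

    // Sketch: caller-owned Attr, mutated in place by the unary fusion helpers.
    onednn::Attr att;
    torch::List<std::optional<at::Scalar>> scalars;  // "relu" consumes no scalars
    at::native::xpu::construct_unary_attr(
        att, /*unary=*/"relu", scalars, /*algorithm=*/std::nullopt);
    // `att` now carries the eltwise-relu post-op.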

View File

@@ -9,28 +9,27 @@
 //
 namespace at::native::xpu {

-at::native::onednn::Attr unary_attr_with_arg(
+at::native::onednn::Attr& unary_attr_with_arg(
+    onednn::Attr& attr,
     std::string_view unary,
     torch::List<std::optional<at::Scalar>> scalars,
-    std::optional<std::string_view> algorithm,
-    onednn::Attr attr);
+    std::optional<std::string_view> algorithm);

-at::native::onednn::Attr string_to_unary_attr(
-    std::string_view unary,
-    onednn::Attr attr);
+at::native::onednn::Attr& string_to_unary_attr(
+    onednn::Attr& attr,
+    std::string_view unary);

-at::native::onednn::Attr construct_unary_attr(
+at::native::onednn::Attr& construct_unary_attr(
+    onednn::Attr& attr,
     std::string_view unary,
     torch::List<std::optional<at::Scalar>> scalars,
-    std::optional<std::string_view> algorithm,
-    onednn::Attr attr);
+    std::optional<std::string_view> algorithm);

 template <bool is_matmul = false>
-onednn::Attr construct_binary_attr(
+onednn::Attr& construct_binary_attr(
+    onednn::Attr& attr,
     std::string_view binary,
     std::optional<at::Scalar> alpha,
-    const Tensor& other,
-    onednn::Attr attr) {
+    const Tensor& other) {
   if (binary == "mul") {
     attr.append_post_binary<is_matmul>(attr.kind_with_binary_mul, other);
   } else if (binary == "sub") {
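The binary-fusion helper follows the same convention: the Attr goes in first, by reference, and the tensor operand is appended as a binary post-op. A small illustrative sketch, assuming the declaration above and hypothetical tensors chosen only for this example (compare the call site in linear_pointwise_binary below):

    // Sketch: append a binary-mul post-op for a matmul-style fusion.
    onednn::Attr attr;
    at::Tensor other_t = at::randn({8, 32});  // hypothetical fused operand
    at::native::xpu::construct_binary_attr</*is_matmul=*/true>(
        attr, /*binary=*/"mul", /*alpha=*/1.f, other_t);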

View File

@@ -5,6 +5,22 @@
 namespace at::native::xpu {

+std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>>
+collapse_in_out_dim(at::Tensor input, int64_t dim, at::Tensor weight) {
+  // dim collapse, e.g. [B, M, K] -> [BM, K]
+  std::vector<int64_t> input_reshaped_size = (dim == 2)
+      ? std::vector<int64_t>{input.size(0), input.size(1)}
+      : std::vector<int64_t>{input.numel() / input.size(input.dim() - 1),
+                             input.size(input.dim() - 1)};
+  // [B, M, K] -> [B, M]
+  std::vector<int64_t> output_size(input.sizes().begin(), input.sizes().end() - 1);
+  // [B, M, N]
+  output_size.push_back(weight.size(0));
+  // [BM, N]
+  std::vector<int64_t> output_reshaped_size{input_reshaped_size[0], weight.size(0)};
+  return {input_reshaped_size, output_size, output_reshaped_size};
+}
+
 Tensor linear_pointwise(
     const Tensor& input_t, // [M, K] or [B, M, K]
     const Tensor& weight_t, // [N, K]
@@ -14,27 +30,17 @@ Tensor linear_pointwise(
     std::optional<std::string_view> algorithm) {
   onednn::Attr att;
   const OptionalDeviceGuard device_guard(device_of(input_t));
-  att = construct_unary_attr(attr, scalars, algorithm, att);
+  att = construct_unary_attr(att, attr, scalars, algorithm);

   auto input = input_t.contiguous();
-  auto input_size = input.sizes();
   const int64_t dim = input.dim();
-  // dim collapse
-  // [B, M, K] -> [BM, K]
-  auto input_reshaped =
-      dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
-  // [B, M]
-  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
-  // [BM, N]
-  output_size.push_back(weight_t.size(0));
+  auto [input_reshaped_size, output_size, output_reshaped_size] =
+      collapse_in_out_dim(input, dim, weight_t);
   Tensor output = at::empty(output_size, input.options());
-  if (dim != 2) {
-    // collapse output
-    std::vector<int64_t> output_size_reshaped = {
-        input_reshaped.size(0), weight_t.size(0)};
-    output = output.reshape(output_size_reshaped);
+  Tensor input_reshaped = input;
+  if(dim!= 2){
+    output = output.reshape(output_reshaped_size);
+    input_reshaped = input_reshaped.reshape(input_reshaped_size);
   }

   auto bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
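For a quick sanity check of the collapsed shapes, consider hypothetical sizes chosen only for illustration: an input of shape [B, M, K] = [4, 8, 16] and a weight of shape [N, K] = [32, 16]. The helper introduced above would return input_reshaped_size = {32, 16} (BM = 4 * 8 = 32), output_size = {4, 8, 32}, and output_reshaped_size = {32, 32}, the [BM, N] view that the dim != 2 branch uses for the flattened computation.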
@@ -56,24 +62,20 @@ Tensor linear_pointwise_binary(
     std::string_view binary_attr) {
   const OptionalDeviceGuard device_guard(device_of(input_t));
   onednn::Attr attr;
-  attr = construct_binary_attr<true>(binary_attr, /*alpha*/ 1.f, other_t, attr);
+  attr = construct_binary_attr<true>(attr, binary_attr, other_t);

   auto input = input_t.contiguous();
-  auto input_size = input.sizes();
   const int64_t dim = input.dim();
-  // dim collapse
-  auto input_reshaped =
-      dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
-  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
-  output_size.push_back(weight_t.size(0));
+  auto [input_reshaped_size, output_size, output_reshaped_size] =
+      collapse_in_out_dim(input, dim, weight_t);
   Tensor output = at::empty(output_size, input.options());
+  Tensor input_reshaped = input;
   if (dim != 2) {
     // input [m, k], weight [n, k], output [m, n]
-    std::vector<int64_t> output_size_reshaped = {
-        input_reshaped.size(0), weight_t.size(0)};
-    output = output.reshape(output_size_reshaped);
+    output = output.reshape(output_reshaped_size);
+    input_reshaped = input_reshaped.reshape(input_reshaped_size);
   } else {
     TORCH_CHECK(
         output.dim() == other_t.dim(),