mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update
[ghstack-poisoned]
This commit is contained in:
@ -4,12 +4,12 @@ using namespace at::native::onednn;
|
||||
|
||||
namespace at::native::xpu {
|
||||
|
||||
onednn::Attr unary_attr_with_arg(
|
||||
onednn::Attr& unary_attr_with_arg(
|
||||
onednn::Attr& attr,
|
||||
std::string_view unary,
|
||||
torch::List<std::optional<at::Scalar>> scalars =
|
||||
torch::List<std::optional<at::Scalar>>(),
|
||||
std::optional<std::string_view> algorithm = std::nullopt,
|
||||
onednn::Attr attr = Attr()) {
|
||||
torch::List<std::optional<at::Scalar>> scalars,
|
||||
std::optional<std::string_view> algorithm
|
||||
) {
|
||||
if (unary == "hardswish") {
|
||||
return attr.append_post_eltwise(
|
||||
1.0f, 1.f / 6.f, 1.f / 2.f, attr.kind_with_hardswish);
|
||||
@ -23,6 +23,7 @@ onednn::Attr unary_attr_with_arg(
|
||||
auto beta = scalars[1].get().toOptional<at::Scalar>().value().to<float>();
|
||||
return attr.append_post_eltwise(1.0f, alpha, beta, attr.kind_with_clip);
|
||||
} else if (unary == "gelu") {
|
||||
TORCH_CHECK(algorithm.has_value(), "GELU algorithm is not specified");
|
||||
enum dnnl::algorithm gelu_type;
|
||||
if (algorithm.value() == "none") {
|
||||
gelu_type = attr.kind_with_gelu_erf;
|
||||
@ -42,7 +43,7 @@ onednn::Attr unary_attr_with_arg(
|
||||
return attr;
|
||||
}
|
||||
|
||||
onednn::Attr string_to_unary_attr(std::string_view unary, onednn::Attr attr) {
|
||||
onednn::Attr& string_to_unary_attr(onednn::Attr& attr, std::string_view unary) {
|
||||
if (unary == "relu") {
|
||||
return attr.append_post_eltwise(1.0f, 0.0f, 0.0f, attr.kind_with_relu);
|
||||
} else if (unary == "sigmoid") {
|
||||
@ -50,29 +51,25 @@ onednn::Attr string_to_unary_attr(std::string_view unary, onednn::Attr attr) {
|
||||
} else if (unary == "tanh") {
|
||||
return attr.append_post_eltwise(1.0f, 0.0f, 0.0f, attr.kind_with_tanh);
|
||||
} else if (unary == "hardswish") {
|
||||
return unary_attr_with_arg(
|
||||
"hardswish",
|
||||
torch::List<std::optional<at::Scalar>>(),
|
||||
std::nullopt,
|
||||
attr);
|
||||
return unary_attr_with_arg(attr, "hardswish", torch::List<std::optional<at::Scalar>>(), std::nullopt);
|
||||
} else if (unary == "swish") {
|
||||
return unary_attr_with_arg(
|
||||
"silu", torch::List<std::optional<at::Scalar>>(), std::nullopt, attr);
|
||||
return unary_attr_with_arg(attr, "silu", torch::List<std::optional<at::Scalar>>(), std::nullopt);
|
||||
}
|
||||
return attr;
|
||||
}
|
||||
|
||||
onednn::Attr construct_unary_attr(
|
||||
onednn::Attr& construct_unary_attr(
|
||||
onednn::Attr& attr,
|
||||
std::string_view unary,
|
||||
torch::List<std::optional<at::Scalar>> scalars,
|
||||
std::optional<std::string_view> algorithm,
|
||||
Attr attr = Attr()) {
|
||||
std::set<std::string_view> simple_unary = {
|
||||
std::optional<std::string_view> algorithm
|
||||
) {
|
||||
static const std::set<std::string_view> simple_unary = {
|
||||
"relu", "sigmoid", "tanh", "hardswish", "swish"};
|
||||
if (simple_unary.find(unary) != simple_unary.end()) {
|
||||
return string_to_unary_attr(unary, attr);
|
||||
return string_to_unary_attr(attr, unary);
|
||||
} else {
|
||||
return unary_attr_with_arg(unary, scalars, algorithm, attr);
|
||||
return unary_attr_with_arg(attr, unary, scalars, algorithm);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -9,28 +9,27 @@
|
||||
//
|
||||
|
||||
namespace at::native::xpu {
|
||||
at::native::onednn::Attr unary_attr_with_arg(
|
||||
at::native::onednn::Attr& unary_attr_with_arg(
|
||||
onednn::Attr& attr,
|
||||
std::string_view unary,
|
||||
torch::List<std::optional<at::Scalar>> scalars,
|
||||
std::optional<std::string_view> algorithm,
|
||||
onednn::Attr attr);
|
||||
std::optional<std::string_view> algorithm);
|
||||
|
||||
at::native::onednn::Attr string_to_unary_attr(
|
||||
std::string_view unary,
|
||||
onednn::Attr attr);
|
||||
at::native::onednn::Attr& string_to_unary_attr(
|
||||
onednn::Attr& attr,
|
||||
std::string_view unary);
|
||||
|
||||
at::native::onednn::Attr construct_unary_attr(
|
||||
at::native::onednn::Attr& construct_unary_attr(
|
||||
onednn::Attr& attr,
|
||||
std::string_view unary,
|
||||
torch::List<std::optional<at::Scalar>> scalars,
|
||||
std::optional<std::string_view> algorithm,
|
||||
onednn::Attr attr);
|
||||
std::optional<std::string_view> algorithm);
|
||||
|
||||
template <bool is_matmul = false>
|
||||
onednn::Attr construct_binary_attr(
|
||||
onednn::Attr& construct_binary_attr(
|
||||
onednn::Attr& attr,
|
||||
std::string_view binary,
|
||||
std::optional<at::Scalar> alpha,
|
||||
const Tensor& other,
|
||||
onednn::Attr attr) {
|
||||
const Tensor& other) {
|
||||
if (binary == "mul") {
|
||||
attr.append_post_binary<is_matmul>(attr.kind_with_binary_mul, other);
|
||||
} else if (binary == "sub") {
|
||||
|
@ -5,6 +5,22 @@
|
||||
|
||||
namespace at::native::xpu {
|
||||
|
||||
// Computes the shapes needed to run a linear op on `input` as a 2-D matmul.
//
// Parameters:
//   input  - activation tensor; its last dimension is treated as K.
//   dim    - rank the caller treats `input` as having (callers in this file
//            pass input.dim()).
//   weight - weight tensor; only weight.size(0) (N) is read.
//
// Returns a tuple of three shapes, in this order:
//   1. input_reshaped_size  : {product of all leading dims, K},
//                             e.g. [B, M, K] -> [BM, K]
//   2. output_size          : input's sizes with the last dim replaced by N,
//                             e.g. [B, M, K] -> [B, M, N]
//   3. output_reshaped_size : {input_reshaped_size[0], N}, e.g. [BM, N]
std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>>
collapse_in_out_dim(at::Tensor input, int64_t dim, at::Tensor weight) {
  const int64_t last_dim = input.size(input.dim() - 1);
  const int64_t out_features = weight.size(0);

  // dim collapse, e.g. [B, M, K] -> [BM, K]
  std::vector<int64_t> input_reshaped_size;
  if (dim == 2) {
    // Brace-init on purpose: this must be the two-element shape {rows, cols}.
    // The parenthesised (count, value) constructor would instead create
    // size(0) copies of size(1).
    input_reshaped_size = {input.size(0), input.size(1)};
  } else {
    // Multiply the leading dims rather than computing numel() / last_dim so
    // that an empty last dimension (K == 0) does not divide by zero.
    int64_t leading = 1;
    for (int64_t d = 0; d < input.dim() - 1; ++d) {
      leading *= input.size(d);
    }
    input_reshaped_size = {leading, last_dim};
  }

  // [B, M, K] -> [B, M]
  std::vector<int64_t> output_size(
      input.sizes().begin(), input.sizes().end() - 1);
  // [B, M] -> [B, M, N]
  output_size.push_back(out_features);

  // [BM, N]
  std::vector<int64_t> output_reshaped_size{
      input_reshaped_size[0], out_features};

  return {input_reshaped_size, output_size, output_reshaped_size};
}
|
||||
|
||||
Tensor linear_pointwise(
|
||||
const Tensor& input_t, // [M, K] or [B, M, K]
|
||||
const Tensor& weight_t, // [N, K]
|
||||
@ -14,27 +30,17 @@ Tensor linear_pointwise(
|
||||
std::optional<std::string_view> algorithm) {
|
||||
onednn::Attr att;
|
||||
const OptionalDeviceGuard device_guard(device_of(input_t));
|
||||
att = construct_unary_attr(attr, scalars, algorithm, att);
|
||||
att = construct_unary_attr(att, attr, scalars, algorithm);
|
||||
auto input = input_t.contiguous();
|
||||
|
||||
auto input_size = input.sizes();
|
||||
const int64_t dim = input.dim();
|
||||
|
||||
// dim collapse
|
||||
// [B, M, K] -> [BM, K]
|
||||
auto input_reshaped =
|
||||
dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
|
||||
// [B, M]
|
||||
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
|
||||
// [BM, N]
|
||||
output_size.push_back(weight_t.size(0));
|
||||
|
||||
auto [input_reshaped_size, output_size, output_reshaped_size] = collapse_in_out_dim(input, dim, weight_t);
|
||||
Tensor output = at::empty(output_size, input.options());
|
||||
if (dim != 2) {
|
||||
// collapse output
|
||||
std::vector<int64_t> output_size_reshaped = {
|
||||
input_reshaped.size(0), weight_t.size(0)};
|
||||
output = output.reshape(output_size_reshaped);
|
||||
Tensor input_reshaped = input;
|
||||
if(dim!= 2){
|
||||
output = output.reshape(output_reshaped_size);
|
||||
input_reshaped = input_reshaped.reshape(input_reshaped_size);
|
||||
}
|
||||
|
||||
auto bias = bias_opt.has_value() ? bias_opt.value() : at::Tensor();
|
||||
@ -56,24 +62,20 @@ Tensor linear_pointwise_binary(
|
||||
std::string_view binary_attr) {
|
||||
const OptionalDeviceGuard device_guard(device_of(input_t));
|
||||
onednn::Attr attr;
|
||||
attr = construct_binary_attr<true>(binary_attr, /*alpha*/ 1.f, other_t, attr);
|
||||
attr = construct_binary_attr<true>(attr, binary_attr, other_t);
|
||||
auto input = input_t.contiguous();
|
||||
|
||||
auto input_size = input.sizes();
|
||||
const int64_t dim = input.dim();
|
||||
|
||||
// dim collapse
|
||||
auto input_reshaped =
|
||||
dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
|
||||
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
|
||||
output_size.push_back(weight_t.size(0));
|
||||
|
||||
auto [input_reshaped_size, output_size, output_reshaped_size] = collapse_in_out_dim(input, dim, weight_t);
|
||||
Tensor output = at::empty(output_size, input.options());
|
||||
Tensor input_reshaped = input;
|
||||
|
||||
if (dim != 2) {
|
||||
// input [m, k], weight [n, k], output [m, n]
|
||||
std::vector<int64_t> output_size_reshaped = {
|
||||
input_reshaped.size(0), weight_t.size(0)};
|
||||
output = output.reshape(output_size_reshaped);
|
||||
output = output.reshape(output_reshaped_size);
|
||||
input_reshaped = input_reshaped.reshape(input_reshaped_size);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
output.dim() == other_t.dim(),
|
||||
|
Reference in New Issue
Block a user