# This is a copy of rnn_attention from MLPerf, with some common sizes hardcoded
# for benchmarking and some control flow stripped out.
# https://github.com/mlcommons/training/blob/master/retired_benchmarks/gnmt/pytorch/seq2seq/models/attention.py

import torch

from . import benchmark

class BahdanauAttention(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, b, t_q, t_k, n):
        super().__init__(mode, device, dtype)
        self.b = b
        self.t_q = t_q
        self.t_k = t_k
        self.n = n
        self.att_query = self.rand(
            [b, t_q, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.att_keys = self.rand(
            [b, t_k, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.normalize_bias = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.linear_att = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [
            self.att_query,
            self.att_keys,
            self.normalize_bias,
            self.linear_att,
        ]

    def forward(self, att_query, att_keys, normalize_bias, linear_att):
        """
        Calculate Bahdanau score

        :param att_query: b x t_q x n
        :param att_keys: b x t_k x n

        return b x t_q x t_k scores
        """

        b, t_k, n = att_keys.size()
        t_q = att_query.size(1)

        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys + normalize_bias
        out = torch.tanh(sum_qk).matmul(linear_att)
        return out
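
    # A minimal, illustrative sketch of the same score computation on small,
    # made-up shapes (b=2, t_q=3, t_k=5, n=4); these sizes are hypothetical
    # and not one of the benchmark configs below:
    #
    #   q = torch.randn(2, 3, 4)    # att_query:      b x t_q x n
    #   k = torch.randn(2, 5, 4)    # att_keys:       b x t_k x n
    #   bias = torch.randn(4)       # normalize_bias: n
    #   v = torch.randn(4)          # linear_att:     n
    #   scores = torch.tanh(q.unsqueeze(2) + k.unsqueeze(1) + bias).matmul(v)
    #   assert scores.shape == (2, 3, 5)   # b x t_q x t_k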

    def reference(self):
        return self.numpy(self.forward(*self.inputs))

    def config(self):
        return [self.b, self.t_q, self.t_k, self.n]

    @staticmethod
    def module():
        return "attention"

    def memory_workload(self):
        def memsize(t):
            return t.numel() * t.element_size()

        input_size = (
            memsize(self.att_query)
            + memsize(self.att_keys)
            + memsize(self.normalize_bias)
            + memsize(self.linear_att)
        )
        output_size = 4 * torch.Size([self.b, self.t_q, self.t_k]).numel()
        io_size = input_size + output_size

        # If matmul is not fused, must write and then read `sum_qk`.
        intermediate_size = (
            2 * 4 * torch.Size([self.b, self.t_q, self.t_k, self.n]).numel()
        )
        return {"sol": io_size, "algorithmic": io_size + intermediate_size}
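
    # Rough worked example for the mlperf_inference config [1280, 1, 66, 1024]
    # below (assuming 4-byte float32 elements, matching the hardcoded factor
    # of 4 above):
    #   inputs:  4 * (1280*1*1024 + 1280*66*1024 + 1024 + 1024) ~ 351.3 MB
    #   output:  4 * 1280*1*66                                  ~ 0.3 MB
    #   sum_qk:  2 * 4 * 1280*1*66*1024 (one write + one read)  ~ 692.1 MB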

    @staticmethod
    def default_configs():
        # Each config is [b, t_q, t_k, n].
        mlperf_inference = [1280, 1, 66, 1024]
        nvidia = [128, 10, 128, 1024]
        return [mlperf_inference, nvidia]


benchmark.register_benchmark_class(BahdanauAttention)