import torch

from . import benchmark


class RNNEltwise(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, b, hs):
        super().__init__(mode, device, dtype)
        self.b = b
        self.hs = hs
        # The four gate pre-activations are packed along dim 1, hence the
        # 4 * hs widths; cx (the cell state) stays at width hs.
        self.input = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.hx = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.cx = self.rand(
            [b, hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.b_ih = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.b_hh = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [
            self.input,
            self.hx,
            self.cx,
            self.b_ih,
            self.b_hh,
        ]

    # Element-wise portion of one LSTM cell step; the input/hidden matmuls
    # are replaced by the precomputed `input` and `hx` terms, so only the
    # pointwise ops are measured.
    def forward(self, input, hx, cx, b_ih, b_hh):
        gates = input + hx + b_ih + b_hh

        # Split the packed pre-activations into the four gates.
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, cy
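
    # For reference, forward() above is the standard LSTM cell update (a
    # restatement of the code, not additional behavior):
    #   ingate     = sigmoid(g_i)
    #   forgetgate = sigmoid(g_f)
    #   cellgate   = tanh(g_c)
    #   outgate    = sigmoid(g_o)
    #   cy = forgetgate * cx + ingate * cellgate
    #   hy = outgate * tanh(cy)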

    def config(self):
        return [self.b, self.hs]

    @staticmethod
    def module():
        return "rnn_eltwise"

    def memory_workload(self):
        def memsize(t):
            return t.numel() * t.element_size()

        # Traffic estimate: every input tensor is read once, and the two
        # outputs (hy, cy), each the size of cx, are written once.
        input_size = sum(memsize(t) for t in self.inputs)
        output_size = 2 * memsize(self.cx)
        io_size = input_size + output_size
        return {"sol": io_size, "algorithmic": io_size}
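
    # Worked example for the default config [b=64, hs=512] in float32:
    # each of the four [64, 2048] inputs takes 64 * 2048 * 4 = 524288 bytes
    # and cx ([64, 512]) takes 131072 bytes, so input_size = 2228224 bytes,
    # output_size = 2 * 131072 = 262144 bytes, and io_size = 2490368 bytes
    # (about 2.4 MiB of traffic per iteration).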

    @staticmethod
    def default_configs():
        return [[64, 512]]


benchmark.register_benchmark_class(RNNEltwise)
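
# A minimal usage sketch (hypothetical driver code; these classes are normally
# executed through the tensorexpr benchmark harness, and the "fwd" mode string
# is an assumption about what `mode` accepts):
#
#   bench = RNNEltwise("fwd", "cpu", torch.float32, 64, 512)
#   hy, cy = bench.forward(*bench.inputs)  # one element-wise LSTM cell step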


class DynamicLSTM(benchmark.DynamicShape, RNNEltwise):
    def __init__(self, mode, device, dtype, b, hs):
        benchmark.DynamicShape.__init__(self)
        RNNEltwise.__init__(self, mode, device, dtype, b, hs)

    def instantiate_input(self):
        # Re-randomize the benchmark shapes via the DynamicShape mixin, then
        # rebuild all inputs at the new (b, hs).
        b, hs = self.rand_shape([self.b, self.hs])

        self.input = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.hx = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.cx = self.rand(
            [b, hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.b_ih = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.b_hh = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.inputs = [
            self.input,
            self.hx,
            self.cx,
            self.b_ih,
            self.b_hh,
        ]

    @staticmethod
    def module():
        return "dynamic_lstm"


benchmark.register_benchmark_class(DynamicLSTM)
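
# Sketch of what the dynamic variant adds (assumes the harness calls
# instantiate_input() between measurements to re-randomize shapes; the loop
# below is hypothetical driver code):
#
#   bench = DynamicLSTM("fwd", "cpu", torch.float32, 64, 512)
#   for _ in range(3):
#       bench.instantiate_input()  # fresh shapes drawn around the base config
#       hy, cy = bench.forward(*bench.inputs)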