Files
pytorch/benchmarks/tensorexpr/rnn_eltwise.py
Aaron Gokaslan 1d6c5972c1 [BE]: Optimize min/max/sum comprehensions C419 (#123960)
Automatic fixes that replace certain list comprehensions with generator expressions where appropriate, so that they are consumed immediately. This is preview functionality in ruff for rule C419, and it was applied automatically.

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00

125 lines
3.3 KiB
Python

import torch
from . import benchmark
class RNNEltwise(benchmark.Benchmark):
    """Benchmark the pointwise (elementwise) stage of an LSTM cell.

    The matrix multiplies of a real LSTM are omitted: the four gate
    pre-activations arrive as ready-made ``[b, 4 * hs]`` tensors and only
    the sigmoid/tanh gating arithmetic is timed.
    """

    def __init__(self, mode, device, dtype, b, hs):
        super().__init__(mode, device, dtype)
        self.b = b
        self.hs = hs
        # Random operands in the shapes forward() expects; cx is the only
        # [b, hs] tensor, all others are [b, 4 * hs].
        operand_shapes = [
            ("input", [b, 4 * hs]),
            ("hx", [b, 4 * hs]),
            ("cx", [b, hs]),
            ("b_ih", [b, 4 * hs]),
            ("b_hh", [b, 4 * hs]),
        ]
        for attr, shape in operand_shapes:
            setattr(
                self,
                attr,
                self.rand(
                    shape,
                    device=device,
                    dtype=dtype,
                    requires_grad=self.requires_grad,
                ),
            )
        self.inputs = [self.input, self.hx, self.cx, self.b_ih, self.b_hh]

    def forward(self, input, hx, cx, b_ih, b_hh):
        """Apply the LSTM gating math and return ``(hy, cy)``."""
        pre_activation = input + hx + b_ih + b_hh
        in_g, forget_g, cell_g, out_g = pre_activation.chunk(4, 1)
        in_g = torch.sigmoid(in_g)
        forget_g = torch.sigmoid(forget_g)
        cell_g = torch.tanh(cell_g)
        out_g = torch.sigmoid(out_g)
        cy = forget_g * cx + in_g * cell_g
        hy = out_g * torch.tanh(cy)
        return hy, cy

    def config(self):
        """Current problem size as ``[batch, hidden_size]``."""
        return [self.b, self.hs]

    @staticmethod
    def module():
        return "rnn_eltwise"

    def memory_workload(self):
        """Estimate bytes moved: every input read plus hy/cy written."""

        def nbytes(t):
            return t.numel() * t.element_size()

        read_bytes = sum(nbytes(t) for t in self.inputs)
        # hy and cy are both shaped like cx, hence the factor of two.
        written_bytes = 2 * nbytes(self.cx)
        total = read_bytes + written_bytes
        return {"sol": total, "algorithmic": total}

    @staticmethod
    def default_configs():
        return [[64, 512]]
# Register so the benchmark driver can discover this workload by name.
benchmark.register_benchmark_class(RNNEltwise)
class DynamicLSTM(benchmark.DynamicShape, RNNEltwise):
    """Dynamic-shape variant of RNNEltwise.

    Operand tensors are rebuilt with freshly sampled sizes via
    ``instantiate_input`` instead of being fixed at construction time.
    """

    def __init__(self, mode, device, dtype, b, hs):
        benchmark.DynamicShape.__init__(self)
        RNNEltwise.__init__(self, mode, device, dtype, b, hs)

    def instantiate_input(self):
        """Resample ``(b, hs)`` and rebuild every operand at the new shape."""
        b, hs = self.rand_shape([self.b, self.hs])

        def fresh(shape):
            # All operands share the same device/dtype/grad settings.
            return self.rand(
                shape,
                device=self.device,
                dtype=self.dtype,
                requires_grad=self.requires_grad,
            )

        self.input = fresh([b, 4 * hs])
        self.hx = fresh([b, 4 * hs])
        self.cx = fresh([b, hs])
        self.b_ih = fresh([b, 4 * hs])
        self.b_hh = fresh([b, 4 * hs])
        self.inputs = [self.input, self.hx, self.cx, self.b_ih, self.b_hh]

    @staticmethod
    def module():
        return "dynamic_lstm"
# Register the dynamic-shape variant with the benchmark driver as well.
benchmark.register_benchmark_class(DynamicLSTM)