Summary: Profiling, even when disabled, costs about 1.5% CPU for a model I'm looking into. This patch simply splits execution into with-profile and without-profile runs. The potential downside is that a script can no longer enable profiling on itself; that doesn't seem to be used anywhere. If it turns out to be a crucial use case we can do something about it, but ideally we wouldn't.

Test Plan: Link with profiles: https://fburl.com/scuba/strobelight_services/ihxsl7pj

```
buck2 run fbcode//caffe2/test/cpp/jit:jit
```

Reviewed By: zhxchen17

Differential Revision: D54066589

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121404
Approved by: https://github.com/zhxchen17
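The idea behind the split can be pictured with a minimal Python sketch, assuming a list of callable ops; `run_ops` and the (name, seconds) record format are illustrative, not the real C++ interpreter API touched by this diff. The dispatcher checks once whether any profiler is enabled, and the disabled case takes a path with no per-op bookkeeping at all:

```
import time

# Hypothetical sketch of the with/without-profile split; `run_ops` and the
# record format are illustrative, not the actual interpreter internals.
def run_ops(ops, profilers=()):
    if not profilers:
        # Fast path: no timing or bookkeeping when profiling is off.
        for op in ops:
            op()
        return
    # Instrumented path, taken only while at least one profiler is enabled.
    for op in ops:
        start = time.perf_counter()
        op()
        for records in profilers:
            records.append((op.__name__, time.perf_counter() - start))

records = []
run_ops([list, dict])             # profiling off: zero per-op overhead
run_ops([list, dict], [records])  # profiling on: timings are collected
```

One consequence of checking up front, which matches the caveat in the summary, is that profiling enabled from inside the run itself would not take effect until the next dispatch.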
# Owner(s): ["oncall: jit"]

import os
import sys

import torch
from torch import nn

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")

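# Toy two-layer LSTMCell model; its scripted form is the workload profiled below.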
class Sequence(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input):
        outputs = []
        h_t = torch.zeros(input.size(0), 51)
        c_t = torch.zeros(input.size(0), 51)
        h_t2 = torch.zeros(input.size(0), 51)
        c_t2 = torch.zeros(input.size(0), 51)

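        # Step through the input one column at a time, chaining the two cells.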
        for input_t in input.split(1, dim=1):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
        outputs = torch.cat(outputs, dim=1)
        return outputs

class TestScriptProfile(JitTestCase):

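    # Basic enable -> run -> disable cycle; the dumped stats should be non-empty.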
    def test_basic(self):
        seq = torch.jit.script(Sequence())
        p = torch.jit._ScriptProfile()
        p.enable()
        seq(torch.rand((10, 100)))
        p.disable()
        self.assertNotEqual(p.dump_string(), "")

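    # Scripting a function while a profile is already enabled should still record it.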
    def test_script(self):
        seq = Sequence()

        p = torch.jit._ScriptProfile()
        p.enable()

        @torch.jit.script
        def fn():
            _ = seq(torch.rand((10, 100)))

        fn()
        p.disable()

        self.assertNotEqual(p.dump_string(), "")

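    # Several profiles can be active at once; each sees only the runs up to its disable().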
    def test_multi(self):
        seq = torch.jit.script(Sequence())
        profiles = [torch.jit._ScriptProfile() for _ in range(5)]
        for p in profiles:
            p.enable()

        last = None
        while profiles:
            seq(torch.rand((10, 10)))
            p = profiles.pop()
            p.disable()
            stats = p.dump_string()
            self.assertNotEqual(stats, "")
            if last:
                self.assertNotEqual(stats, last)
            last = stats

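    # Runs made while a profile is disabled must not change its stats (s0 == s1).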
    def test_section(self):
        seq = Sequence()

        @torch.jit.script
        def fn(max: int):
            _ = seq(torch.rand((10, max)))

        p = torch.jit._ScriptProfile()
        p.enable()
        fn(100)
        p.disable()
        s0 = p.dump_string()

        fn(10)
        p.disable()
        s1 = p.dump_string()

        p.enable()
        fn(10)
        p.disable()
        s2 = p.dump_string()

        self.assertEqual(s0, s1)
        self.assertNotEqual(s1, s2)

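    # A profile that never observes a run should dump empty stats.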
    def test_empty(self):
        p = torch.jit._ScriptProfile()
        p.enable()
        p.disable()
        self.assertEqual(p.dump_string(), "")