[SR] Remove unused operator() overload (#67001)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67001

The overload of `operator()` taking `std::vector<at::Tensor>` was only used for testing. In a follow-up diff, I will add a new overload that takes `std::vector<c10::IValue> args` and no `kwargs`, so that we can avoid default-constructing `kwargs` everywhere.
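
For context, a minimal sketch of why the two-argument entry point is noisy for args-only callers. `Runtime`, `IValue`, and `KwargMap` below are toy stand-ins, not the real `c10`/`StaticModule` declarations:

    #include <string>
    #include <unordered_map>
    #include <vector>

    // Toy stand-ins for c10::IValue and the runtime's kwargs map.
    using IValue = int;
    using KwargMap = std::unordered_map<std::string, IValue>;

    struct Runtime {
      // The general entry point: positional args plus kwargs.
      IValue operator()(const std::vector<IValue>& args, const KwargMap& kwargs) {
        return args.empty() ? IValue{} : args.front();
      }
    };

    int main() {
      Runtime run;
      // Without an args-only overload, every such call site has to
      // default-construct an empty kwargs map:
      run({1, 2, 3}, {});
    }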

The new overload will probably take a forwarding reference. To avoid the pitfalls of overloading on a forwarding reference, and to keep the interface simple, it's best to remove this unused overload first.
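
Overloading on a forwarding reference is a well-known C++ pitfall: the deduced parameter is an exact match for almost any argument, so it can hijack calls intended for other overloads. A minimal standalone demo with toy types, unrelated to the actual `StaticModule` signatures:

    #include <iostream>
    #include <vector>

    struct Runner {
      // Forwarding reference: deduces an exact match for nearly any
      // argument, so it competes with every other overload.
      template <typename Args>
      void operator()(Args&& args) {
        std::cout << "forwarding-reference overload\n";
      }

      // Intended overload for a concrete container type.
      void operator()(const std::vector<int>& args) {
        std::cout << "const std::vector<int>& overload\n";
      }
    };

    int main() {
      Runner r;
      std::vector<int> v{1, 2, 3};
      // Non-const lvalue: Args deduces to std::vector<int>&, an exact
      // match that beats the const-ref overload (which would need a
      // const qualification conversion). The "wrong" overload wins.
      r(v);
      // Const lvalue: both overloads are exact matches, and the
      // non-template wins the tie, so this call picks the vector one.
      const std::vector<int> cv{4, 5, 6};
      r(cv);
    }

Keeping the overload set small before introducing the forwarding reference means this kind of resolution surprise cannot creep in.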

Test Plan:
`buck test caffe2/benchmarks/static_runtime/...`

`buck test caffe2/test:static_runtime`

Reviewed By: hlu1

Differential Revision: D31821990

fbshipit-source-id: 6d2e4a75ca4abe6e262651532eb96c3b274c6f4a
Author: Mike Iovine
Date: 2021-10-25 08:16:14 -07:00
Committed by: Facebook GitHub Bot
Parent: 364645cd9d
Commit: a0495b3cdb
6 changed files with 47 additions and 85 deletions

@@ -17,7 +17,7 @@ class StaticModule:
     def __call__(self, *args, **kwargs):
         if not kwargs:
-            return self.static_module(args)
+            return self.static_module(args, {})
         else:
             return self.static_module(args, kwargs)
@@ -227,20 +227,20 @@ class TestStaticModule(TestCase):
         bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
         top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
         ref_bot = bot_l(bot_inp)
-        acc_bot = bot_l_acc(bot_inp)[0]
+        acc_bot = bot_l_acc(bot_inp)
         torch.testing.assert_close(acc_bot, ref_bot)
         ref_top = top_l(top_inp)
-        acc_top = top_l_acc(top_inp)[0]
+        acc_top = top_l_acc(top_inp)
         torch.testing.assert_close(acc_top, ref_top)
         for _ in range(5):
             with torch.no_grad():
                 bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                 top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
                 ref_bot = bot_l(bot_inp)
-                acc_bot = bot_l_acc(bot_inp)[0]
+                acc_bot = bot_l_acc(bot_inp)
                 torch.testing.assert_close(acc_bot, ref_bot)
                 ref_top = top_l(top_inp)
-                acc_top = top_l_acc(top_inp)[0]
+                acc_top = top_l_acc(top_inp)
                 torch.testing.assert_close(acc_top, ref_top)

     def test_trivial_graph(self):
@@ -248,7 +248,7 @@ class TestStaticModule(TestCase):
         tg = torch.jit.script(trivial_graph)
         o_ref = tg(s, s, s)
         tg_a = StaticModule(tg)
-        o_test = tg_a(s, s, s)[0]
+        o_test = tg_a(s, s, s)
         torch.testing.assert_close(o_ref, o_test)

     def test_leaky_relu(self):
@@ -256,7 +256,7 @@ class TestStaticModule(TestCase):
         tg = torch.jit.script(nn.LeakyReLU(0.1))
         o_ref = tg(s)
         tg_a = StaticModule(tg)
-        o_test = tg_a(s)[0]
+        o_test = tg_a(s)
         torch.testing.assert_close(o_ref, o_test)

     def test_attr(self):
@@ -292,7 +292,7 @@ class TestStaticModule(TestCase):
         ms = torch.jit.script(m)
         sm = StaticModule(ms)
-        output_sm = sm(input)[0]
+        output_sm = sm(input)
         torch.testing.assert_close(output_s, output_sm)
         sm.benchmark([input], {}, 2, 2)
         sm.benchmark_individual_ops([input], {}, 2, 2)