mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[SR] Remove unused operator() overload (#67001)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67001 The overload of `operator()` taking `std::vector<at::Tensor>` was only used for testing. In a diff following this one, I will add a new overload that takes `std::vector<c10::IValue> args` and no `kwargs` so we can avoid default-constructing `kwargs` everywhere. This new overload will probably take a forwarding reference, so to avoid problems with overloading on forwarding reference and simplify the interface, it's best to remove this unused one. Test Plan: `buck test caffe2/benchmarks/static_runtime/...` `buck test caffe2/test:static_runtime` Reviewed By: hlu1 Differential Revision: D31821990 fbshipit-source-id: 6d2e4a75ca4abe6e262651532eb96c3b274c6f4a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
364645cd9d
commit
a0495b3cdb
@ -17,7 +17,7 @@ class StaticModule:
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if not kwargs:
|
||||
return self.static_module(args)
|
||||
return self.static_module(args, {})
|
||||
else:
|
||||
return self.static_module(args, kwargs)
|
||||
|
||||
@ -227,20 +227,20 @@ class TestStaticModule(TestCase):
|
||||
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
|
||||
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
|
||||
ref_bot = bot_l(bot_inp)
|
||||
acc_bot = bot_l_acc(bot_inp)[0]
|
||||
acc_bot = bot_l_acc(bot_inp)
|
||||
torch.testing.assert_close(acc_bot, ref_bot)
|
||||
ref_top = top_l(top_inp)
|
||||
acc_top = top_l_acc(top_inp)[0]
|
||||
acc_top = top_l_acc(top_inp)
|
||||
torch.testing.assert_close(acc_top, ref_top)
|
||||
for _ in range(5):
|
||||
with torch.no_grad():
|
||||
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
|
||||
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
|
||||
ref_bot = bot_l(bot_inp)
|
||||
acc_bot = bot_l_acc(bot_inp)[0]
|
||||
acc_bot = bot_l_acc(bot_inp)
|
||||
torch.testing.assert_close(acc_bot, ref_bot)
|
||||
ref_top = top_l(top_inp)
|
||||
acc_top = top_l_acc(top_inp)[0]
|
||||
acc_top = top_l_acc(top_inp)
|
||||
torch.testing.assert_close(acc_top, ref_top)
|
||||
|
||||
def test_trivial_graph(self):
|
||||
@ -248,7 +248,7 @@ class TestStaticModule(TestCase):
|
||||
tg = torch.jit.script(trivial_graph)
|
||||
o_ref = tg(s, s, s)
|
||||
tg_a = StaticModule(tg)
|
||||
o_test = tg_a(s, s, s)[0]
|
||||
o_test = tg_a(s, s, s)
|
||||
torch.testing.assert_close(o_ref, o_test)
|
||||
|
||||
def test_leaky_relu(self):
|
||||
@ -256,7 +256,7 @@ class TestStaticModule(TestCase):
|
||||
tg = torch.jit.script(nn.LeakyReLU(0.1))
|
||||
o_ref = tg(s)
|
||||
tg_a = StaticModule(tg)
|
||||
o_test = tg_a(s)[0]
|
||||
o_test = tg_a(s)
|
||||
torch.testing.assert_close(o_ref, o_test)
|
||||
|
||||
def test_attr(self):
|
||||
@ -292,7 +292,7 @@ class TestStaticModule(TestCase):
|
||||
|
||||
ms = torch.jit.script(m)
|
||||
sm = StaticModule(ms)
|
||||
output_sm = sm(input)[0]
|
||||
output_sm = sm(input)
|
||||
torch.testing.assert_close(output_s, output_sm)
|
||||
sm.benchmark([input], {}, 2, 2)
|
||||
sm.benchmark_individual_ops([input], {}, 2, 2)
|
||||
|
Reference in New Issue
Block a user