mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[JIT] additional support for CallMethod with autocasting (#67925)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67925 Previously, the following would always fail, because autocasting would not be enabled in the called method: ``` torch.jit.script def fn(x, y): with autocast(): # CallMethod() to some method fn(x, y) ``` This allows the above, if autocasting is globally enabled, e.g. ``` torch.jit.script def fn(x, y): with autocast(): # CallMethod() to some method with autocast(): fn(x, y) # now ``` ghstack-source-id: 142667351 Test Plan: added test in test_jit_autocast.py Reviewed By: navahgar Differential Revision: D32214439 fbshipit-source-id: bb7db054e25e18f5e3d2fdb449c35b5942ab303e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f57c63032e
commit
2e523ed229
@ -623,5 +623,42 @@ class TestAutocast(JitTestCase):
|
||||
self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
|
||||
self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)
|
||||
|
||||
if __name__ == '__main__':
|
||||
@unittest.skipIf(not TEST_CUDA, "No cuda")
|
||||
def test_jit_call_method_under_autocast(self):
|
||||
@torch.jit.interface
|
||||
class Iface(torch.nn.Module):
|
||||
def forward(self, x, y) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
class Impl(Iface):
|
||||
def forward(self, x, y):
|
||||
return torch.mm(x, y)
|
||||
|
||||
class Thing1(torch.nn.Module):
|
||||
impl: Iface
|
||||
|
||||
def forward(self, x, y):
|
||||
with torch.cuda.amp.autocast():
|
||||
a = torch.mm(x, y)
|
||||
b = self.impl.forward(a, x)
|
||||
return b
|
||||
|
||||
scripted_impl = torch.jit.script(Impl())
|
||||
thing1 = Thing1()
|
||||
thing1.impl = scripted_impl
|
||||
scripted_thing1 = torch.jit.script(thing1)
|
||||
x = torch.rand([2, 2])
|
||||
y = torch.rand([2, 2])
|
||||
|
||||
# make sure this doesn't throw an error
|
||||
with torch.cuda.amp.autocast():
|
||||
ans = scripted_thing1.forward(x, y)
|
||||
self.assertEqual(torch.mm(torch.mm(x, y), x), ans)
|
||||
|
||||
# sanity check: this isn't supported currently when global autocasting
|
||||
# isn't enabled
|
||||
self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -242,6 +242,16 @@ void handleBlock(Block* block, AutocastContext initial_state) {
|
||||
switch (node->kind()) {
|
||||
case prim::CallFunction:
|
||||
// TODO: limit it only to amp related node;
|
||||
if (current_state() == initial_state) {
|
||||
// if the current autocasting state is the same as the global state,
|
||||
// then autocasting will be done correctly on subsequent method and
|
||||
// function calls
|
||||
if (current_state()) {
|
||||
castTensorInputs(
|
||||
node, aten::_autocast_to_full_precision, current_state());
|
||||
}
|
||||
break;
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!incompatible_amp.has_value() || incompatible_amp.value(),
|
||||
"Calls are not expected with AMP & JIT");
|
||||
@ -250,6 +260,16 @@ void handleBlock(Block* block, AutocastContext initial_state) {
|
||||
|
||||
case prim::CallMethod:
|
||||
// TODO: limit it only to amp related node;
|
||||
if (current_state() == initial_state) {
|
||||
// if the current autocasting state is the same as the global state,
|
||||
// then autocasting will be done correctly on subsequent method and
|
||||
// function calls
|
||||
if (current_state()) {
|
||||
castTensorInputs(
|
||||
node, aten::_autocast_to_full_precision, current_state());
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
|
||||
const auto& name = node->s(attr::name);
|
||||
const auto& function = class_type->getMethod(name);
|
||||
|
Reference in New Issue
Block a user