[JIT] additional support for CallMethod with autocasting (#67925)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67925

Previously, the following would always fail, because autocasting would not be enabled in the called method:

```python
@torch.jit.script
def fn(x, y):
    with autocast():
        # CallMethod() to some method

fn(x, y)
```

This change allows the above, provided autocasting is also enabled globally, e.g.:

```python
@torch.jit.script
def fn(x, y):
    with autocast():
        # CallMethod() to some method

with autocast():
    fn(x, y)  # now succeeds, since autocast is also enabled globally
```
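
To make the schematic above concrete, here is a minimal sketch of the now-supported pattern. It mirrors the new test added in this PR (the `Caller` name is invented for illustration) and assumes a CUDA-capable build with the TorchScript autocast pass enabled:

```python
import torch

@torch.jit.interface
class Iface(torch.nn.Module):
    def forward(self, x, y) -> torch.Tensor:
        pass

class Impl(Iface):
    def forward(self, x, y):
        return torch.mm(x, y)

class Caller(torch.nn.Module):
    impl: Iface  # interface-typed submodule, so the call stays a CallMethod

    def forward(self, x, y):
        with torch.cuda.amp.autocast():
            # CallMethod into the scripted submodule -- the case this change
            # makes work when autocast is also enabled globally
            return self.impl.forward(torch.mm(x, y), x)

caller = Caller()
caller.impl = torch.jit.script(Impl())
scripted = torch.jit.script(caller)

x, y = torch.rand(2, 2), torch.rand(2, 2)
with torch.cuda.amp.autocast():  # globally enabled -> the nested call is supported
    out = scripted(x, y)
# calling scripted(x, y) without the surrounding autocast() still raises a RuntimeError
```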
ghstack-source-id: 142667351

Test Plan: added test in test_jit_autocast.py

Reviewed By: navahgar

Differential Revision: D32214439

fbshipit-source-id: bb7db054e25e18f5e3d2fdb449c35b5942ab303e
Author: David Berard
Date: 2021-11-08 14:31:36 -08:00
Committed by: Facebook GitHub Bot
Parent: f57c63032e
Commit: 2e523ed229
2 changed files with 58 additions and 1 deletion

test_jit_autocast.py

@@ -623,5 +623,42 @@ class TestAutocast(JitTestCase):
        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_call_method_under_autocast(self):
        @torch.jit.interface
        class Iface(torch.nn.Module):
            def forward(self, x, y) -> torch.Tensor:
                pass

        class Impl(Iface):
            def forward(self, x, y):
                return torch.mm(x, y)

        class Thing1(torch.nn.Module):
            impl: Iface

            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    a = torch.mm(x, y)
                    b = self.impl.forward(a, x)
                    return b

        scripted_impl = torch.jit.script(Impl())
        thing1 = Thing1()
        thing1.impl = scripted_impl
        scripted_thing1 = torch.jit.script(thing1)
        x = torch.rand([2, 2])
        y = torch.rand([2, 2])

        # make sure this doesn't throw an error
        with torch.cuda.amp.autocast():
            ans = scripted_thing1.forward(x, y)
        self.assertEqual(torch.mm(torch.mm(x, y), x), ans)

        # sanity check: this isn't supported currently when global autocasting
        # isn't enabled
        self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))

if __name__ == "__main__":
    run_tests()
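
One practical note for reproducing the new test outside the suite: scripted autocasting is gated behind an internal switch, which (as far as I can tell) the test file enables during its setup. A hedged sketch, since `_jit_set_autocast_mode` is an underscored, internal API:

```python
import torch

# Enable the TorchScript autocast pass so that `with torch.cuda.amp.autocast()`
# inside scripted code is actually rewritten by the pass shown below.
torch._C._jit_set_autocast_mode(True)
```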


@@ -242,6 +242,16 @@ void handleBlock(Block* block, AutocastContext initial_state) {
    switch (node->kind()) {
      case prim::CallFunction:
        // TODO: limit it only to amp related node;
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        TORCH_INTERNAL_ASSERT(
            !incompatible_amp.has_value() || incompatible_amp.value(),
            "Calls are not expected with AMP & JIT");
@@ -250,6 +260,16 @@ void handleBlock(Block* block, AutocastContext initial_state) {
      case prim::CallMethod:
        // TODO: limit it only to amp related node;
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
          const auto& name = node->s(attr::name);
          const auto& function = class_type->getMethod(name);
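
The inserted `aten::_autocast_to_full_precision` node has an eager-mode counterpart, which makes it easy to see what the pass does to the call's tensor inputs when the local and global autocast states agree. A hedged illustration: the underscored tensor method is internal, and its `(cuda_enabled, cpu_enabled)` signature is assumed from recent PyTorch versions rather than stated in this diff:

```python
import torch

# A reduced-precision tensor, like the ones autocast produces inside the block.
t = torch.ones(2, 2, dtype=torch.bfloat16)

# The pass wraps tensor inputs of CallFunction/CallMethod nodes in
# aten::_autocast_to_full_precision; eagerly, the same op restores float32 so
# the callee's own autocasting starts from full-precision inputs.
restored = t._autocast_to_full_precision(False, True)  # (cuda_enabled, cpu_enabled)
print(restored.dtype)  # torch.float32
```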