mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[JIT] additional support for CallMethod with autocasting (#67925)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67925 Previously, the following would always fail, because autocasting would not be enabled in the called method: ``` torch.jit.script def fn(x, y): with autocast(): # CallMethod() to some method fn(x, y) ``` This allows the above, if autocasting is globally enabled, e.g. ``` torch.jit.script def fn(x, y): with autocast(): # CallMethod() to some method with autocast(): fn(x, y) # now ``` ghstack-source-id: 142667351 Test Plan: added test in test_jit_autocast.py Reviewed By: navahgar Differential Revision: D32214439 fbshipit-source-id: bb7db054e25e18f5e3d2fdb449c35b5942ab303e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f57c63032e
commit
2e523ed229
@ -623,5 +623,42 @@ class TestAutocast(JitTestCase):
|
|||||||
self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
|
self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
|
||||||
self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)
|
self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
@unittest.skipIf(not TEST_CUDA, "No cuda")
|
||||||
|
def test_jit_call_method_under_autocast(self):
|
||||||
|
@torch.jit.interface
|
||||||
|
class Iface(torch.nn.Module):
|
||||||
|
def forward(self, x, y) -> torch.Tensor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Impl(Iface):
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.mm(x, y)
|
||||||
|
|
||||||
|
class Thing1(torch.nn.Module):
|
||||||
|
impl: Iface
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
a = torch.mm(x, y)
|
||||||
|
b = self.impl.forward(a, x)
|
||||||
|
return b
|
||||||
|
|
||||||
|
scripted_impl = torch.jit.script(Impl())
|
||||||
|
thing1 = Thing1()
|
||||||
|
thing1.impl = scripted_impl
|
||||||
|
scripted_thing1 = torch.jit.script(thing1)
|
||||||
|
x = torch.rand([2, 2])
|
||||||
|
y = torch.rand([2, 2])
|
||||||
|
|
||||||
|
# make sure this doesn't throw an error
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
ans = scripted_thing1.forward(x, y)
|
||||||
|
self.assertEqual(torch.mm(torch.mm(x, y), x), ans)
|
||||||
|
|
||||||
|
# sanity check: this isn't supported currently when global autocasting
|
||||||
|
# isn't enabled
|
||||||
|
self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
@ -242,6 +242,16 @@ void handleBlock(Block* block, AutocastContext initial_state) {
|
|||||||
switch (node->kind()) {
|
switch (node->kind()) {
|
||||||
case prim::CallFunction:
|
case prim::CallFunction:
|
||||||
// TODO: limit it only to amp related node;
|
// TODO: limit it only to amp related node;
|
||||||
|
if (current_state() == initial_state) {
|
||||||
|
// if the current autocasting state is the same as the global state,
|
||||||
|
// then autocasting will be done correctly on subsequent method and
|
||||||
|
// function calls
|
||||||
|
if (current_state()) {
|
||||||
|
castTensorInputs(
|
||||||
|
node, aten::_autocast_to_full_precision, current_state());
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
TORCH_INTERNAL_ASSERT(
|
TORCH_INTERNAL_ASSERT(
|
||||||
!incompatible_amp.has_value() || incompatible_amp.value(),
|
!incompatible_amp.has_value() || incompatible_amp.value(),
|
||||||
"Calls are not expected with AMP & JIT");
|
"Calls are not expected with AMP & JIT");
|
||||||
@ -250,6 +260,16 @@ void handleBlock(Block* block, AutocastContext initial_state) {
|
|||||||
|
|
||||||
case prim::CallMethod:
|
case prim::CallMethod:
|
||||||
// TODO: limit it only to amp related node;
|
// TODO: limit it only to amp related node;
|
||||||
|
if (current_state() == initial_state) {
|
||||||
|
// if the current autocasting state is the same as the global state,
|
||||||
|
// then autocasting will be done correctly on subsequent method and
|
||||||
|
// function calls
|
||||||
|
if (current_state()) {
|
||||||
|
castTensorInputs(
|
||||||
|
node, aten::_autocast_to_full_precision, current_state());
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
|
if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
|
||||||
const auto& name = node->s(attr::name);
|
const auto& name = node->s(attr::name);
|
||||||
const auto& function = class_type->getMethod(name);
|
const auto& function = class_type->getMethod(name);
|
||||||
|
Reference in New Issue
Block a user