Optimize `to` if the datatype of the source tensor is the same as the dest datatype (#85140)

AMP automatically inserts `_autocast_to_reduced_precision` and `_autocast_to_full_precision`. The ATen implementation of these ops provides a fast path that bypasses the conversion when the tensor is already in the reduced/full precision, but NNC always performs the conversion, which can cause a >5% end-to-end (E2E) performance regression.
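For reference, a minimal sketch of the ATen fast path (`_autocast_to_reduced_precision` is an internal op; the data-pointer check below assumes the current ATen behavior of returning the input tensor unchanged on the fast path):

```python
import torch

x_fp32 = torch.rand(2, 2, dtype=torch.float32)
x_bf16 = x_fp32.to(torch.bfloat16)

# Conversion path: the fp32 tensor is actually cast, allocating new storage.
y = x_fp32._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16)
assert y.dtype == torch.bfloat16

# Fast path: the tensor is already in reduced precision, so ATen returns
# it as-is instead of materializing a copy (assumed fast-path behavior).
z = x_bf16._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16)
assert z.data_ptr() == x_bf16.data_ptr()
```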

This PR addresses the performance issue the same way ATen does: we no longer pull `_autocast_to_reduced_precision` and `_autocast_to_full_precision` into the NNC fusion group, and instead fall back to ATen so that its fast path is triggered when the tensor is already in the reduced/full precision.
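A minimal sketch of how the fallback can be observed (the fuser toggles are internal JIT APIs whose availability may vary across builds, and `f` here is illustrative, not the test added by this PR):

```python
import torch

# Internal toggles to enable NNC (TensorExpr) fusion on CPU (assumption:
# these knobs behave this way on recent PyTorch builds).
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_set_texpr_fuser_enabled(True)

def f(x):
    y = torch.sigmoid(x)
    # For an already-bf16 input this is a no-op on ATen's fast path.
    z = y._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16)
    return z + 1

x = torch.rand(2, 2, dtype=torch.bfloat16)
traced = torch.jit.trace(f, x)
traced(x)
traced(x)  # warm-up runs so the profiling executor can fuse
# After this change, the autocast op should appear outside any
# prim::TensorExprGroup node, leaving ATen to take its fast path.
print(traced.graph_for(x))
```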

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85140
Approved by: https://github.com/frank-wei
Wang, Eikan
2022-09-26 07:50:32 +00:00
committed by PyTorch MergeBot
parent 83261ff9a8
commit 45be74cc63
3 changed files with 109 additions and 1 deletion

@@ -2348,6 +2348,32 @@ class TestTEFuser(JitTestCase):
        scr(x)
        self.assertLastGraphAllFused()

    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
    def test_to_dtype(self):
        def f(x):
            y = torch.sigmoid(x)
            z = y._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16)
            h = z._autocast_to_full_precision(True, True)
            i = h.to(dtype=torch.bfloat16)
            j = i.to(dtype=torch.float32)
            return j

        # fp32 input: every conversion is a real cast, so the whole graph fuses
        x = torch.rand((2, 2), dtype=torch.float32)
        scr = torch.jit.trace(f, x)
        scr(x)
        scr(x)
        self.assertLastGraphAllFused()
        self.assertEqual(f(x), scr(x), atol=4e-3, rtol=4e-3)

        # bf16 input: the no-op conversion stays out of the fusion group,
        # splitting the fusible region into two fusion groups
        bf_x = torch.rand((2, 2), dtype=torch.bfloat16)
        bf_scr = torch.jit.trace(f, bf_x)
        bf_scr(bf_x)
        bf_scr(bf_x)
        graph = bf_scr.graph_for(bf_x)
        fusion_groups = self.findFusionGroups(graph)
        self.assertEqual(len(fusion_groups), 2)
        self.assertEqual(f(bf_x), bf_scr(bf_x), atol=4e-3, rtol=4e-3)

    def test_with_strict_fusion(self):
        def success(x):