mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Optimize `to` if the datatype of the source tensor is the same as the dest datatype (#85140)
AMP inserts `_autocast_to_reduced_precision` and `_autocast_to_full_precision` automatically. The aten implementation provides a fast path that bypasses the conversion if the tensor's data type is already the reduced/full precision. But NNC always performs the conversion, which could bring a >5% E2E performance regression. This PR addresses the performance issue in the same way as aten: we no longer pull `_autocast_to_reduced_precision` and `_autocast_to_full_precision` into the NNC fusion group, and instead fall back to aten to trigger its fast path when the tensor's data type is already the reduced/full precision. Pull Request resolved: https://github.com/pytorch/pytorch/pull/85140 Approved by: https://github.com/frank-wei
This commit is contained in:
committed by
PyTorch MergeBot
parent
83261ff9a8
commit
45be74cc63
@@ -2348,6 +2348,32 @@ class TestTEFuser(JitTestCase):
|
||||
scr(x)
|
||||
self.assertLastGraphAllFused()
|
||||
|
||||
@unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
def test_to_dtype(self):
    """Check NNC fusion of dtype conversions around the autocast ops.

    For an fp32 input, the traced graph should fuse entirely into a single
    TE fusion group. For a bf16 input, the autocast precision-conversion ops
    are expected to fall back to aten (their fast path applies when the
    tensor is already at the target precision), splitting the graph into
    exactly two fusion groups.
    """
    def f(x):
        y = torch.sigmoid(x)
        # Reduced precision (half/bfloat16), then back to full precision.
        z = y._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16)
        h = z._autocast_to_full_precision(True, True)
        # Plain dtype round-trip: fp32 -> bf16 -> fp32.
        i = h.to(dtype=torch.bfloat16)
        j = i.to(dtype=torch.float32)
        return j

    # fp32 input: everything should end up in one fusion group.
    x = torch.rand((2, 2), dtype=torch.float32)
    scr = torch.jit.trace(f, x)
    # Two warm-up runs so the profiling executor specializes and fuses.
    scr(x)
    scr(x)
    self.assertLastGraphAllFused()
    # Loose tolerances: the round-trip through bf16 loses precision.
    self.assertEqual(f(x), scr(x), atol=4e-3, rtol=4e-3)

    # bf16 input: the autocast ops fall back to aten, so the graph is
    # expected to split into exactly two fusion groups.
    bf_x = torch.rand((2, 2), dtype=torch.bfloat16)
    bf_scr = torch.jit.trace(f, bf_x)
    bf_scr(bf_x)
    bf_scr(bf_x)
    graph = bf_scr.graph_for(bf_x)
    fusion_groups = self.findFusionGroups(graph)
    self.assertEqual(len(fusion_groups), 2)
    self.assertEqual(f(bf_x), bf_scr(bf_x), atol=4e-3, rtol=4e-3)
|
||||
|
||||
def test_with_strict_fusion(self):
|
||||
|
||||
def success(x):
|
||||
|
Reference in New Issue
Block a user