Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Optimize `to` if the datatype of the source tensor is the same as the dest datatype (#85140)
AMP automatically inserts `_autocast_to_reduced_precision` and `_autocast_to_full_precision` nodes. The aten implementation provides a fast path that bypasses the conversion when the tensor is already in the reduced/full precision, but NNC always performs the conversion, which can cause a >5% end-to-end performance regression. This PR addresses the performance issue the same way aten does: `_autocast_to_reduced_precision` and `_autocast_to_full_precision` are not pulled into the NNC fusion group and instead fall back to aten, triggering its fast path, when the tensor is already in the reduced/full precision.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85140
Approved by: https://github.com/frank-wei
Committed by: PyTorch MergeBot
Parent: 83261ff9a8
Commit: 45be74cc63
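The diff below touches the NNC kernel tests (test/cpp/tensorexpr/test_kernel.cpp), the TE fuser Python tests (test/test_jit_fuser_te.py), and the fuser pass itself (torch/csrc/jit/passes/tensorexpr_fuser.cpp). For context, here is a minimal Python sketch, not part of the diff, of the aten fast path the commit message refers to: when the input is already in the target precision, the `_autocast_*` ops can skip the conversion entirely, so leaving such nodes to aten avoids the redundant cast that NNC used to emit.

```python
import torch

# Input is already reduced precision, so the aten implementation can bypass
# the float -> bf16 conversion on CPU (the fast path described above).
x = torch.rand(2, 2, dtype=torch.bfloat16)
y = x._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16)
assert y.dtype == torch.bfloat16

# Input is already full precision, so _autocast_to_full_precision likewise
# has nothing to convert.
f = torch.rand(2, 2, dtype=torch.float32)
g = f._autocast_to_full_precision(True, True)
assert g.dtype == torch.float32
```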
@@ -579,6 +579,54 @@ TEST_F(Kernel, CatInputTypesPromotion) {
   }
 }
 
+TEST_F(Kernel, ToDType) {
+#ifdef TORCH_ENABLE_LLVM
+  const auto graph_string = R"IR(
+      graph(%x.1 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
+        %1 : NoneType = prim::Constant()
+        %2 : bool = prim::Constant[value=0]()
+        %3 : int = prim::Constant[value=6]()
+        %4 : int = prim::Constant[value=15]()
+        %5 : int = prim::Constant[value=5]()
+        %6 : bool = prim::Constant[value=1]()
+        %y.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::sigmoid(%x.1)
+        %z.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_reduced_precision(%y.3, %6, %6, %5, %4)
+        %h.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_full_precision(%z.3, %6, %6)
+        %i.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%h.3, %3, %2, %2, %1)
+        %j.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%i.3, %4, %2, %2, %1)
+        %k.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%j.3, %3, %2, %2, %1)
+        return (%k.3))IR";
+
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+  TensorExprKernel k(graph);
+  StmtPtr s = k.getCodeGenStmt();
+  std::ostringstream oss;
+  oss << *s;
+
+  const std::string& verification_pattern =
+      R"IR(
+# CHECK: for
+# CHECK-NEXT: for
+# CHECK-NEXT: aten_to
+# CHECK-NEXT: }
+# CHECK-NEXT: })IR";
+  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+
+  auto a = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kBFloat16));
+  auto ref =
+      at::_to_copy(at::sigmoid(a), TensorOptions(kCPU).dtype(at::kFloat));
+
+  std::vector<at::Tensor> inputs = {a};
+  std::vector<IValue> stack = fmap<IValue>(inputs);
+  k.run(stack);
+  auto o = stack[0].toTensor();
+  ASSERT_EQ(o.sizes(), ref.sizes());
+  ASSERT_EQ(o.dtype(), ref.dtype());
+  ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
+#endif
+}
+
 TEST_F(Kernel, CatAndInlineWithAConstantDim) {
   const auto graph_string = R"IR(
       graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu),
@@ -915,7 +963,7 @@ TEST_F(Kernel, SumOneAxis) {
       o = stack[0].toTensor();
       ASSERT_EQ(o.sizes(), ref.sizes());
       ASSERT_EQ(o.dtype(), ref.dtype());
-      ASSERT_TRUE(at::allclose(o, ref));
+      ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
     }
   }
 }
@@ -2348,6 +2348,32 @@ class TestTEFuser(JitTestCase):
         scr(x)
         self.assertLastGraphAllFused()
 
+    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
+    def test_to_dtype(self):
+        def f(x):
+            y = torch.sigmoid(x)
+            z = y._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16)
+            h = z._autocast_to_full_precision(True, True)
+            i = h.to(dtype=torch.bfloat16)
+            j = i.to(dtype=torch.float32)
+            return j
+
+        x = torch.rand((2, 2), dtype=torch.float32)
+        scr = torch.jit.trace(f, x)
+        scr(x)
+        scr(x)
+        self.assertLastGraphAllFused()
+        self.assertEqual(f(x), scr(x), atol=4e-3, rtol=4e-3)
+
+        bf_x = torch.rand((2, 2), dtype=torch.bfloat16)
+        bf_scr = torch.jit.trace(f, bf_x)
+        bf_scr(bf_x)
+        bf_scr(bf_x)
+        graph = bf_scr.graph_for(bf_x)
+        fusion_groups = self.findFusionGroups(graph)
+        self.assertEqual(len(fusion_groups), 2)
+        self.assertEqual(f(bf_x), bf_scr(bf_x), atol=4e-3, rtol=4e-3)
+
     def test_with_strict_fusion(self):
 
         def success(x):
@@ -1039,6 +1039,40 @@ class TensorExprFuser {
       }
     }
 
+    bool is_reduced_precision =
+        node->kind() == aten::_autocast_to_reduced_precision;
+    bool is_full_precision =
+        node->kind() == aten::_autocast_to_full_precision;
+    auto self_tensor = node->inputs()[0]; // input tensor
+
+    if (auto const& tt = self_tensor->type()->cast<TensorType>()) {
+      auto st = tt->scalarType();
+      if (!st.has_value()) {
+        return false;
+      }
+
+      auto device = tt->device();
+      if (!device.has_value()) {
+        return false;
+      }
+
+      bool is_cpu = device->is_cpu();
+
+      if (*st != at::kFloat && is_reduced_precision && is_cpu) {
+        // Regarding CPU, aten would do nothing if the data type is
+        // float. Then the aten performance is better than NNC. So NNC
+        // does not pull it into its fusion group.
+        return false;
+      }
+
+      if (*st != at::kBFloat16 && is_full_precision && is_cpu) {
+        // Regarding CPU, aten would do nothing if the data type is
+        // BFloat16. Then the aten performance is better than NNC. So NNC
+        // does not pull it into its fusion group.
+        return false;
+      }
+    }
+
     if (has_unsupported_pin_memory(node)) {
       return false;
     }
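As a usage-level sketch of the observable effect (assumptions: an LLVM-enabled CPU build with the TE fuser active, the same setup the new test_to_dtype relies on), an `_autocast_to_reduced_precision` call on an input that is already bfloat16 should now stay outside the fusion group and run through aten:

```python
import torch

def f(x):
    y = torch.sigmoid(x)
    z = y._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16)
    return z._autocast_to_full_precision(True, True)

bf_x = torch.rand(2, 2, dtype=torch.bfloat16)
traced = torch.jit.trace(f, bf_x)
traced(bf_x)
traced(bf_x)  # warm up the profiling executor so fusion groups are formed
# With this change, the reduced-precision cast on a bf16 CPU tensor falls back
# to aten, so it should appear outside any TensorExprGroup in the optimized
# graph (the new Python test expects two separate fusion groups in this case).
print(traced.graph_for(bf_x))
```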