mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54450 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D27243175 Pulled By: huiguoo fbshipit-source-id: 820cf0d6cd1dd984d4153628e0f419d234668c82
124 lines
4.6 KiB
Python
124 lines
4.6 KiB
Python
import torch
|
|
import numpy as np
|
|
|
|
from torch.testing._internal.common_utils import run_tests
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
import unittest
|
|
|
|
# Queried once at import time; the TensorExprKernel tests below are skipped
# (via unittest.skipIf) when this is false.
LLVM_ENABLED = torch._C._llvm_enabled()
|
|
|
|
class kernel_arena_scope:
    """Context manager that holds a ``torch._C._te.KernelScope`` for the ``with`` body.

    The scope object is created on entry and stored on the instance; on exit
    the reference is replaced with ``None``. Nothing is bound by ``as`` because
    ``__enter__`` returns ``None``.
    """

    def __enter__(self):
        # Keep the scope referenced from the instance until __exit__ runs.
        active_scope = torch._C._te.KernelScope()
        self.scope = active_scope

    def __exit__(self, typ, val, traceback):
        # Drop the reference taken in __enter__; exceptions are not suppressed
        # because the method returns None.
        self.scope = None
|
|
|
|
class TestTensorExprPyBind(JitTestCase):
    """Tests for the ``torch._C._te`` Python bindings.

    Covers a hand-built elementwise loop nest, an external call statement, and
    ``TensorExprKernel`` construction from parsed TorchScript IR.
    """

    def _check_kernel(self, graph_str, inputs, reference_fn):
        # Shared body of the TensorExprKernel tests: parse the IR, build the
        # kernel inside an arena scope, and compare both `run` and `fallback`
        # against `reference_fn(*inputs)`.
        graph = torch._C.parse_ir(graph_str)

        with kernel_arena_scope():
            kernel = torch._C._te.TensorExprKernel(graph)
            res1 = kernel.run(inputs)
            res2 = kernel.fallback(inputs)
            # Reference is evaluated after both kernel invocations, matching
            # the order the individual tests originally used.
            correct = reference_fn(*inputs)
            np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
            np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    def test_simple_sum(self):
        # Builds C[i] = A[i] + B[i] for N floats and evaluates it with the
        # 'ir_eval' codegen.
        te = torch._C._te
        with kernel_arena_scope():
            dtype = te.Dtype.Float
            N = 32
            dN = te.ExprHandle.int(N)

            A = te.Placeholder('A', dtype, [dN])
            B = te.Placeholder('B', dtype, [dN])

            def compute(i):
                return A.load([i]) + B.load([i])

            C = te.Compute('C', [te.DimArg(dN, 'i')], compute)

            loopnest = te.LoopNest([C])
            loopnest.prepare_for_codegen()
            stmt = te.simplify(loopnest.root_stmt())

            cg = te.construct_codegen('ir_eval', stmt, [te.BufferArg(x) for x in [A, B, C]])

            tA = torch.rand(N) * 5
            tB = torch.rand(N) * 6
            tC = torch.empty(N)
            cg.call([tA, tB, tC])
            torch.testing.assert_allclose(tA + tB, tC)

    def test_external_calls(self):
        # Issues a (1x4) @ (4x1) matmul through an ExternalCall statement
        # named "nnc_aten_matmul" and checks it against torch.matmul.
        te = torch._C._te
        with kernel_arena_scope():
            dtype = te.Dtype.Float

            ONE = te.ExprHandle.int(1)
            FOUR = te.ExprHandle.int(4)
            A = te.BufHandle('A', [ONE, FOUR], dtype)
            B = te.BufHandle('B', [FOUR, ONE], dtype)
            C = te.BufHandle('C', [ONE, ONE], dtype)

            s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

            loopnest = te.LoopNest(s, [C])
            loopnest.prepare_for_codegen()
            # NOTE(review): the codegen is built from `s` directly, not from
            # loopnest.root_stmt() -- confirm this is intentional.
            codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])

            tA = torch.ones(1, 4)
            tB = torch.ones(4, 1)
            tC = torch.empty(1, 1)
            codegen.call([tA, tB, tC])
            torch.testing.assert_allclose(torch.matmul(tA, tB), tC)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_tensor_inputs(self):
        # a + b + c on three 4x4 CPU float tensors, via a kernel built from IR.
        def f(a, b, c):
            return a + b + c

        device, size = 'cpu', (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)
        z = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
%b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
%c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
%6 : int = prim::Constant[value=1]()
%7 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %6)
%3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%7, %c.1, %6)
return (%3)
"""
        self._check_kernel(graph_str, (x, y, z), f)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_scalar_inputs(self):
        # Same a + b + c check, but with zero-dimensional float tensors.
        def f(a, b, c):
            return a + b + c

        x = torch.tensor(0.1, dtype=torch.float, device='cpu')
        y = torch.tensor(0.6, dtype=torch.float, device='cpu')
        z = torch.tensor(0.7, dtype=torch.float, device='cpu')

        graph_str = """
graph(%a.1 : Float(requires_grad=0, device=cpu),
%b.1 : Float(requires_grad=0, device=cpu),
%c.1 : Float(requires_grad=0, device=cpu)):
%3 : int = prim::Constant[value=1]()
%6 : Float(requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %3)
%9 : Float(requires_grad=0, device=cpu) = aten::add(%6, %c.1, %3)
return (%9)
"""
        self._check_kernel(graph_str, (x, y, z), f)
|
|
|
|
if __name__ == "__main__":
    # Executed directly: hand off to run_tests from
    # torch.testing._internal.common_utils.
    run_tests()
|