[TensorExpr] Python binding improvements (#59920)

Summary:
Some minor quality-of-life improvements for the NNC Python bindings, combined in the sketch after this list:
- expose `call_raw()`
- support passing integers to `call()` (for dynamic shapes)
- implicit conversions to clean up `[BufferArg(x) for x in [A, B, C]]` into just `[A, B, C]`
- don't silently default to "ir_eval" for an unknown mode (e.g. "LLVM")

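For illustration, a rough sketch (not part of the diff) of how the new pieces combine, patterned on the tests added below; it assumes the internal `torch._C._te` API and the `kernel_arena_scope` helper defined in the test file:

```python
import torch
import torch._C._te as te

# kernel_arena_scope (from the test file) owns a te.KernelScope()
# for the duration of the block.
with kernel_arena_scope():
    dN = te.VarHandle("n", te.Dtype.Int)  # dynamic dimension
    A = te.Placeholder('A', te.Dtype.Float, [dN])
    B = te.Placeholder('B', te.Dtype.Float, [dN])

    def compute(i):
        return A.load([i]) + B.load([i])

    C = te.Compute('C', [te.DimArg(dN, 'i')], compute)

    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())

    # Implicit conversions: pass [A, B, C, dN] directly instead of
    # [te.BufferArg(x) for x in [A, B, C, dN]].
    cg = te.construct_codegen('ir_eval', stmt, [A, B, C, dN])

    n = 8
    tA, tB, tC = torch.randn(n), torch.randn(n), torch.empty(n)

    # call() now accepts plain ints for dynamic shape variables.
    cg.call([tA, tB, tC, n])

    # call_raw() is exposed too; it takes raw data pointers
    # (used with a static-shape codegen in test_call_raw):
    #   cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])

    # An unknown mode name is now an error rather than silently
    # falling back to "ir_eval":
    #   te.construct_codegen('LLVM', stmt, [A, B, C, dN])
```

The tests in this diff exercise each of these paths.
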
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59920

Reviewed By: ZolotukhinM

Differential Revision: D29090904

Pulled By: jansel

fbshipit-source-id: 154ace82725ae2046cfe2e6eb324fd37f5d209a7
---
 2 files changed, 106 insertions(+), 33 deletions(-)

@@ -1,5 +1,6 @@
 import torch
 import numpy as np
+import torch._C._te as te
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.jit_utils import JitTestCase
@@ -7,6 +8,7 @@ import unittest

 LLVM_ENABLED = torch._C._llvm_enabled()

+
 class kernel_arena_scope(object):
     def __enter__(self):
         self.scope = torch._C._te.KernelScope()
@@ -14,48 +16,62 @@ class kernel_arena_scope(object):
     def __exit__(self, typ, val, traceback):
         self.scope = None

+
+def construct_adder(n: int, dtype=te.Dtype.Float):
+    dN = te.ExprHandle.int(n)
+    A = te.Placeholder('A', dtype, [dN])
+    B = te.Placeholder('B', dtype, [dN])
+
+    def compute(i):
+        return A.load([i]) + B.load([i])
+
+    C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+
+    loopnest = te.LoopNest([C])
+    loopnest.prepare_for_codegen()
+    stmt = te.simplify(loopnest.root_stmt())
+
+    return te.construct_codegen('ir_eval', stmt, [A, B, C])
+
+
 class TestTensorExprPyBind(JitTestCase):
     def test_simple_sum(self):
         with kernel_arena_scope():
-            dtype = torch._C._te.Dtype.Float
-            N = 32
-            dN = torch._C._te.ExprHandle.int(N)
+            n = 32
+            cg = construct_adder(n)

-            A = torch._C._te.Placeholder('A', dtype, [dN])
-            B = torch._C._te.Placeholder('B', dtype, [dN])
-
-            def compute(i):
-                return A.load([i]) + B.load([i])
-
-            C = torch._C._te.Compute('C', [torch._C._te.DimArg(dN, 'i')], compute)
-
-            loopnest = torch._C._te.LoopNest([C])
-            loopnest.prepare_for_codegen()
-            stmt = torch._C._te.simplify(loopnest.root_stmt())
-
-            cg = torch._C._te.construct_codegen('ir_eval', stmt, [torch._C._te.BufferArg(x) for x in [A, B, C]])
-
-            tA = torch.rand(N) * 5
-            tB = torch.rand(N) * 6
-            tC = torch.empty(N)
+            tA = torch.randn(n)
+            tB = torch.randn(n)
+            tC = torch.empty(n)
             cg.call([tA, tB, tC])
             torch.testing.assert_allclose(tA + tB, tC)

+    def test_call_raw(self):
+        with kernel_arena_scope():
+            n = 16
+            cg = construct_adder(n, dtype=te.Dtype.Double)
+
+            tA = torch.randn(n, dtype=torch.float64)
+            tB = torch.randn(n, dtype=torch.float64)
+            tC = torch.empty(n, dtype=torch.float64)
+            cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])
+            torch.testing.assert_allclose(tA + tB, tC)
+
     def test_external_calls(self):
         with kernel_arena_scope():
-            dtype = torch._C._te.Dtype.Float
+            dtype = te.Dtype.Float

-            ZERO = torch._C._te.ExprHandle.int(0)
-            ONE = torch._C._te.ExprHandle.int(1)
-            FOUR = torch._C._te.ExprHandle.int(4)
-            A = torch._C._te.BufHandle('A', [ONE, FOUR], dtype)
-            B = torch._C._te.BufHandle('B', [FOUR, ONE], dtype)
-            C = torch._C._te.BufHandle('C', [ONE, ONE], dtype)
+            ONE = te.ExprHandle.int(1)
+            FOUR = te.ExprHandle.int(4)
+            A = te.BufHandle('A', [ONE, FOUR], dtype)
+            B = te.BufHandle('B', [FOUR, ONE], dtype)
+            C = te.BufHandle('C', [ONE, ONE], dtype)

-            s = torch._C._te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])
+            s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

-            loopnest = torch._C._te.LoopNest(s, [C])
+            loopnest = te.LoopNest(s, [C])
             loopnest.prepare_for_codegen()
-            codegen = torch._C._te.construct_codegen('ir_eval', s, [torch._C._te.BufferArg(x) for x in [A, B, C]])
+            codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])

             tA = torch.ones(1, 4)
             tB = torch.ones(4, 1)
@@ -63,10 +79,41 @@ class TestTensorExprPyBind(JitTestCase):
             codegen.call([tA, tB, tC])
             torch.testing.assert_allclose(torch.matmul(tA, tB), tC)

+    def test_dynamic_shape(self):
+        with kernel_arena_scope():
+            dN = te.VarHandle("n", te.Dtype.Int)
+            A = te.Placeholder('A', te.Dtype.Double, [dN])
+            B = te.Placeholder('B', te.Dtype.Double, [dN])
+
+            def compute(i):
+                return A.load([i]) - B.load([i])
+
+            C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+
+            loopnest = te.LoopNest([C])
+            loopnest.prepare_for_codegen()
+            stmt = te.simplify(loopnest.root_stmt())
+
+            cg = te.construct_codegen(
+                'ir_eval',
+                stmt,
+                [A, B, C, dN])
+
+            def test_with_shape(n):
+                tA = torch.randn(n, dtype=torch.double)
+                tB = torch.randn(n, dtype=torch.double)
+                tC = torch.empty(n, dtype=torch.double)
+                cg.call([tA, tB, tC, n])
+                torch.testing.assert_allclose(tA - tB, tC)
+
+            test_with_shape(8)
+            test_with_shape(31)
+
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_kernel_with_tensor_inputs(self):
         def f(a, b, c):
             return a + b + c
         device, size = 'cpu', (4, 4)
         x = torch.rand(size, device=device)
         y = torch.rand(size, device=device)
@@ -94,6 +141,7 @@ graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
     def test_kernel_with_scalar_inputs(self):
         def f(a, b, c):
             return a + b + c
+
         x = torch.tensor(0.1, dtype=torch.float, device='cpu')
         y = torch.tensor(0.6, dtype=torch.float, device='cpu')
         z = torch.tensor(0.7, dtype=torch.float, device='cpu')