[TensorExpr] Python binding improvements (#59920)
Summary: Some minor quality-of-life improvements for the NNC Python bindings:
- expose `call_raw()`
- support passing integers to `call()` (for dynamic shapes)
- add implicit conversions, so `[BufferArg(x) for x in [A, B, C]]` can be written as just `[A, B, C]`
- don't silently default to "ir_eval" for an unknown mode (e.g. "LLVM")

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59920
Reviewed By: ZolotukhinM
Differential Revision: D29090904
Pulled By: jansel
fbshipit-source-id: 154ace82725ae2046cfe2e6eb324fd37f5d209a7
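To make the improvements concrete, here is a minimal usage sketch (an illustration only, not part of the commit; it assumes the `torch._C._te` API exactly as exercised by the tests in this diff, and that a `KernelScope` is kept alive while the kernel is built and run):

```python
import torch
import torch._C._te as te

scope = te.KernelScope()  # the kernel arena must outlive the codegen

n = 8
dN = te.ExprHandle.int(n)
A = te.Placeholder('A', te.Dtype.Float, [dN])
B = te.Placeholder('B', te.Dtype.Float, [dN])

def compute(i):
    return A.load([i]) + B.load([i])

C = te.Compute('C', [te.DimArg(dN, 'i')], compute)

loopnest = te.LoopNest([C])
loopnest.prepare_for_codegen()
stmt = te.simplify(loopnest.root_stmt())

# Implicit conversions: pass [A, B, C] directly instead of
# [te.BufferArg(x) for x in [A, B, C]].
cg = te.construct_codegen('ir_eval', stmt, [A, B, C])

tA, tB, tC = torch.randn(n), torch.randn(n), torch.empty(n)

# New: call_raw() takes raw data pointers instead of tensors.
cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])
torch.testing.assert_allclose(tA + tB, tC)

# New: an unrecognized mode is now an error rather than a silent
# fallback to "ir_eval":
# te.construct_codegen('LLVM', stmt, [A, B, C])  # raises instead of defaulting
```

The `test_call_raw` and `test_dynamic_shape` tests added in the diff below exercise these same entry points.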
Committed by: Facebook GitHub Bot
Parent: 68d690ffbd
Commit: 3d90c82a5c
@@ -1,5 +1,6 @@
 import torch
 import numpy as np
+import torch._C._te as te
 
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.jit_utils import JitTestCase
@@ -7,6 +8,7 @@ import unittest
 
 LLVM_ENABLED = torch._C._llvm_enabled()
 
+
 class kernel_arena_scope(object):
     def __enter__(self):
         self.scope = torch._C._te.KernelScope()
@@ -14,48 +16,62 @@ class kernel_arena_scope(object)
     def __exit__(self, typ, val, traceback):
         self.scope = None
 
 
+def construct_adder(n: int, dtype=te.Dtype.Float):
+    dN = te.ExprHandle.int(n)
+    A = te.Placeholder('A', dtype, [dN])
+    B = te.Placeholder('B', dtype, [dN])
+
+    def compute(i):
+        return A.load([i]) + B.load([i])
+
+    C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+
+    loopnest = te.LoopNest([C])
+    loopnest.prepare_for_codegen()
+    stmt = te.simplify(loopnest.root_stmt())
+
+    return te.construct_codegen('ir_eval', stmt, [A, B, C])
+
+
 class TestTensorExprPyBind(JitTestCase):
     def test_simple_sum(self):
         with kernel_arena_scope():
-            dtype = torch._C._te.Dtype.Float
-            N = 32
-            dN = torch._C._te.ExprHandle.int(N)
+            n = 32
+            cg = construct_adder(n)
 
-            A = torch._C._te.Placeholder('A', dtype, [dN])
-            B = torch._C._te.Placeholder('B', dtype, [dN])
-
-            def compute(i):
-                return A.load([i]) + B.load([i])
-            C = torch._C._te.Compute('C', [torch._C._te.DimArg(dN, 'i')], compute)
-
-            loopnest = torch._C._te.LoopNest([C])
-            loopnest.prepare_for_codegen()
-            stmt = torch._C._te.simplify(loopnest.root_stmt())
-
-            cg = torch._C._te.construct_codegen('ir_eval', stmt, [torch._C._te.BufferArg(x) for x in [A, B, C]])
-
-            tA = torch.rand(N) * 5
-            tB = torch.rand(N) * 6
-            tC = torch.empty(N)
+            tA = torch.randn(n)
+            tB = torch.randn(n)
+            tC = torch.empty(n)
             cg.call([tA, tB, tC])
             torch.testing.assert_allclose(tA + tB, tC)
 
+    def test_call_raw(self):
+        with kernel_arena_scope():
+            n = 16
+            cg = construct_adder(n, dtype=te.Dtype.Double)
+
+            tA = torch.randn(n, dtype=torch.float64)
+            tB = torch.randn(n, dtype=torch.float64)
+            tC = torch.empty(n, dtype=torch.float64)
+            cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])
+            torch.testing.assert_allclose(tA + tB, tC)
+
     def test_external_calls(self):
         with kernel_arena_scope():
-            dtype = torch._C._te.Dtype.Float
+            dtype = te.Dtype.Float
 
-            ZERO = torch._C._te.ExprHandle.int(0)
-            ONE = torch._C._te.ExprHandle.int(1)
-            FOUR = torch._C._te.ExprHandle.int(4)
-            A = torch._C._te.BufHandle('A', [ONE, FOUR], dtype)
-            B = torch._C._te.BufHandle('B', [FOUR, ONE], dtype)
-            C = torch._C._te.BufHandle('C', [ONE, ONE], dtype)
+            ONE = te.ExprHandle.int(1)
+            FOUR = te.ExprHandle.int(4)
+            A = te.BufHandle('A', [ONE, FOUR], dtype)
+            B = te.BufHandle('B', [FOUR, ONE], dtype)
+            C = te.BufHandle('C', [ONE, ONE], dtype)
 
-            s = torch._C._te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])
+            s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])
 
-            loopnest = torch._C._te.LoopNest(s, [C])
+            loopnest = te.LoopNest(s, [C])
             loopnest.prepare_for_codegen()
-            codegen = torch._C._te.construct_codegen('ir_eval', s, [torch._C._te.BufferArg(x) for x in [A, B, C]])
+            codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])
 
             tA = torch.ones(1, 4)
             tB = torch.ones(4, 1)
@@ -63,10 +79,41 @@ class TestTensorExprPyBind(JitTestCase):
             codegen.call([tA, tB, tC])
             torch.testing.assert_allclose(torch.matmul(tA, tB), tC)
 
+    def test_dynamic_shape(self):
+        with kernel_arena_scope():
+            dN = te.VarHandle("n", te.Dtype.Int)
+            A = te.Placeholder('A', te.Dtype.Double, [dN])
+            B = te.Placeholder('B', te.Dtype.Double, [dN])
+
+            def compute(i):
+                return A.load([i]) - B.load([i])
+
+            C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+
+            loopnest = te.LoopNest([C])
+            loopnest.prepare_for_codegen()
+            stmt = te.simplify(loopnest.root_stmt())
+
+            cg = te.construct_codegen(
+                'ir_eval',
+                stmt,
+                [A, B, C, dN])
+
+            def test_with_shape(n):
+                tA = torch.randn(n, dtype=torch.double)
+                tB = torch.randn(n, dtype=torch.double)
+                tC = torch.empty(n, dtype=torch.double)
+                cg.call([tA, tB, tC, n])
+                torch.testing.assert_allclose(tA - tB, tC)
+
+            test_with_shape(8)
+            test_with_shape(31)
+
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_kernel_with_tensor_inputs(self):
         def f(a, b, c):
             return a + b + c
 
         device, size = 'cpu', (4, 4)
         x = torch.rand(size, device=device)
         y = torch.rand(size, device=device)
@@ -94,6 +141,7 @@ graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
     def test_kernel_with_scalar_inputs(self):
         def f(a, b, c):
             return a + b + c
+
         x = torch.tensor(0.1, dtype=torch.float, device='cpu')
         y = torch.tensor(0.6, dtype=torch.float, device='cpu')
         z = torch.tensor(0.7, dtype=torch.float, device='cpu')