mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[TensorExpr] Add python bindings. (#49698)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49698 Reincarnation of #47620 by jamesr66a. It's just an initial bunch of things that we're exposing to python, more is expected to come in future. Some things can probably be done better, but I'm putting this out anyway, since some other people were interested in using and/or developing this. Differential Revision: D25668694 Test Plan: Imported from OSS Reviewed By: bertmaher Pulled By: ZolotukhinM fbshipit-source-id: fb0fd1b31e851ef9ab724686b9ac2d172fa4905a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
9efe15313a
commit
e9dc8fc162
40
test/test_tensorexpr_pybind.py
Normal file
40
test/test_tensorexpr_pybind.py
Normal file
@ -0,0 +1,40 @@
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
class kernel_arena_scope(object):
|
||||
def __enter__(self):
|
||||
self.scope = torch._C.te.KernelScope()
|
||||
|
||||
def __exit__(self, typ, val, traceback):
|
||||
self.scope = None
|
||||
|
||||
class TestTensorExprPyBind(JitTestCase):
|
||||
def test_simple_sum(self):
|
||||
with kernel_arena_scope():
|
||||
dtype = torch._C.te.Dtype.Float
|
||||
N = 32
|
||||
dN = torch._C.te.ExprHandle.int(N)
|
||||
|
||||
A = torch._C.te.Placeholder('A', dtype, [dN])
|
||||
B = torch._C.te.Placeholder('B', dtype, [dN])
|
||||
|
||||
def compute(i):
|
||||
return A.load([i]) + B.load([i])
|
||||
C = torch._C.te.Compute('C', [torch._C.te.DimArg(dN, 'i')], compute)
|
||||
|
||||
loopnest = torch._C.te.LoopNest([C])
|
||||
loopnest.prepare_for_codegen()
|
||||
stmt = torch._C.te.simplify(loopnest.root_stmt())
|
||||
|
||||
cg = torch._C.te.construct_codegen('ir_eval', stmt, [torch._C.te.BufferArg(x) for x in [A, B, C]])
|
||||
|
||||
tA = torch.rand(N) * 5
|
||||
tB = torch.rand(N) * 6
|
||||
tC = torch.empty(N)
|
||||
cg.call([tA, tB, tC])
|
||||
torch.testing.assert_allclose(tA + tB, tC)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Reference in New Issue
Block a user