mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Original commit changeset: e1ea3b3630d1
Test Plan: Broke tests, e.g. T85754010
Reviewed By: ZolotukhinM
Differential Revision: D26727166
fbshipit-source-id: 42d2090bc55ec2982a4a08c38278c80617d5398a
41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
import torch
|
|
|
|
from torch.testing._internal.common_utils import run_tests
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
class kernel_arena_scope:
    """Context manager that holds a ``torch._C._te.KernelScope`` object.

    A fresh ``KernelScope`` is created when the ``with`` block is entered and
    stored on ``self.scope``; when the block is left, the attribute is reset
    to ``None`` so this object no longer references the scope.
    """

    def __enter__(self):
        # Store the scope on the instance so a reference to it is kept for
        # the duration of the ``with`` body. Nothing is returned, so
        # ``with kernel_arena_scope() as x`` would bind ``x`` to ``None``.
        self.scope = torch._C._te.KernelScope()

    def __exit__(self, typ, val, traceback):
        # Drop our reference to the scope. The implicit ``None`` return is
        # falsy, so an exception raised inside the ``with`` body propagates.
        self.scope = None
|
|
|
|
class TestTensorExprPyBind(JitTestCase):
    """Tests that drive the ``torch._C._te`` tensor-expression Python bindings."""

    def test_simple_sum(self):
        """Build an elementwise ``A + B`` kernel and compare it with eager mode.

        The kernel is described with ``Placeholder``/``Compute``, lowered
        through ``LoopNest``, simplified, and executed with the ``ir_eval``
        codegen on random 1-D inputs of length 32.
        """
        te = torch._C._te  # short alias for the bindings module

        with kernel_arena_scope():
            elem_dtype = te.Dtype.Float
            length = 32
            length_expr = te.ExprHandle.int(length)

            # Symbolic input buffers of ``length`` floats each.
            lhs = te.Placeholder('A', elem_dtype, [length_expr])
            rhs = te.Placeholder('B', elem_dtype, [length_expr])

            def add_at(idx):
                # Expression for one output element: A[idx] + B[idx].
                return lhs.load([idx]) + rhs.load([idx])

            out = te.Compute('C', [te.DimArg(length_expr, 'i')], add_at)

            # Lower the computation to a statement and simplify it.
            nest = te.LoopNest([out])
            nest.prepare_for_codegen()
            simplified = te.simplify(nest.root_stmt())

            # Buffer arguments are passed in the order A, B, C.
            buffer_args = [te.BufferArg(buf) for buf in [lhs, rhs, out]]
            codegen = te.construct_codegen('ir_eval', simplified, buffer_args)

            in_a = torch.rand(length) * 5
            in_b = torch.rand(length) * 6
            result = torch.empty(length)
            codegen.call([in_a, in_b, result])

            # The kernel writes into ``result``; it must match the eager sum.
            torch.testing.assert_allclose(in_a + in_b, result)
|
|
|
|
# When executed as a script, hand control to ``run_tests`` (imported from
# ``torch.testing._internal.common_utils``).
if __name__ == '__main__':
    run_tests()
|