mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33218
Test Plan: Imported from OSS
Differential Revision: D19848378
Pulled By: ZolotukhinM
fbshipit-source-id: 48399f8651324d5ad0607e08573d5d7b2026bb23
59 lines
1.7 KiB
C++
59 lines
1.7 KiB
C++
#include "test/cpp/tensorexpr/test_base.h"
|
|
|
|
#include "test/cpp/tensorexpr/test_utils.h"
|
|
#include "torch/csrc/jit/tensorexpr/buffer.h"
|
|
#include "torch/csrc/jit/tensorexpr/ir.h"
|
|
|
|
#include <cmath>
|
|
#include <sstream>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
using namespace torch::jit::tensorexpr;
|
|
|
|
void testExprVectorAdd01() {
|
|
KernelScope kernel_scope;
|
|
const int kVectorSize = 8;
|
|
const int kVectorCount = 128;
|
|
const int kTotalSize = kVectorSize * kVectorCount;
|
|
|
|
Buffer a_buf(Var("A", kHandle), kFloat32, {Expr(kTotalSize)});
|
|
Buffer b_buf(Var("B", kHandle), kFloat32, {Expr(kTotalSize)});
|
|
Buffer c_buf(Var("C", kHandle), kFloat32, {Expr(kTotalSize)});
|
|
|
|
/*
|
|
Build the following:
|
|
for (int index = 0; index < kVectorCount; index++) {
|
|
store(c_buf, ramp(index * 8, 1, 8),
|
|
load(a_buf, ramp(index * 8, 1, 8) +
|
|
load(b_buf, ramp(index * 8, 1, 8))))
|
|
}
|
|
*/
|
|
Var index = Var("index", kInt32);
|
|
Expr load_a = Load::make(
|
|
a_buf,
|
|
Ramp::make(index * kVectorSize, 1, kVectorSize),
|
|
Broadcast::make(1, kVectorSize));
|
|
Expr load_b = Load::make(
|
|
b_buf,
|
|
Ramp::make(index * kVectorSize, 1, kVectorSize),
|
|
Broadcast::make(1, kVectorSize));
|
|
Expr value = load_a + load_b;
|
|
Stmt store_c = Store::make(
|
|
c_buf,
|
|
Ramp::make(index * kVectorSize, 1, kVectorSize),
|
|
value,
|
|
Broadcast::make(1, kVectorSize));
|
|
Stmt stmt = For::make(index, 0, kVectorCount, store_c);
|
|
|
|
EXPECT_EQ(load_a.dtype(), Dtype(kFloat32, kVectorSize));
|
|
EXPECT_EQ(load_b.dtype(), Dtype(kFloat32, kVectorSize));
|
|
EXPECT_EQ(value.dtype(), Dtype(kFloat32, kVectorSize));
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|