Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[TensorExpr] Add CUDA codegen. (#34227)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34227
This PR adds CUDA support to tensor expressions.
Test Plan: Imported from OSS
Differential Revision: D20251836
Pulled By: ZolotukhinM
fbshipit-source-id: ab36a55834cceff30c8371fef6cca1054a32f017
Commit 35e7efeb9a (parent 42b2c8c65d), committed by Facebook GitHub Bot.
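
For orientation before the diff itself: the new backend is driven like the existing IR evaluator, except that loop axes are bound to CUDA block/thread indices before code generation. The sketch below condenses the float vector-add test added in test/cpp/tensorexpr/test_cuda.cpp; it assumes a CUDA-enabled build of this branch, and the device pointers are assumed to be allocated and filled by the caller.

#ifdef USE_CUDA
#include <vector>

#include "torch/csrc/jit/tensorexpr/buffer.h"
#include "torch/csrc/jit/tensorexpr/cuda_codegen.h"
#include "torch/csrc/jit/tensorexpr/schedule.h"
#include "torch/csrc/jit/tensorexpr/tensor.h"

using namespace torch::jit::tensorexpr;
using namespace torch::jit::tensorexpr::schedule;

// Build c[n] = a[n] + b[n], split the loop into a block/thread pair,
// and launch it through the new CudaCodeGen (mirrors testCudaTestVectorAdd02).
void cudaVectorAddSketch(float* a_dev, float* b_dev, float* c_dev, int N) {
  KernelScope kernel_scope;
  Buffer a_buf("a", kFloat, {N});
  Buffer b_buf("b", kFloat, {N});
  Tensor* c = Compute(
      "c", {{N, "n"}}, [&](const VarHandle& n) { return a_buf(n) + b_buf(n); });

  LoopNest l({c});
  Stmt* n_outer;
  Stmt* n_inner;
  std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
  l.SplitWithMask(loops[0], 128, &n_outer, &n_inner);  // chunks of 128 elements
  l.SetGPUBlockIndex(n_outer, 0);   // outer loop -> blockIdx.x
  l.SetGPUThreadIndex(n_inner, 0);  // inner loop -> threadIdx.x

  // NVRTC-compiles the statement and launches it on the device pointers.
  CudaCodeGen cuda_cg(l.root_stmt(), c, a_buf, b_buf);
  cuda_cg(c_dev, a_dev, b_dev);
}
#endif  // USE_CUDA
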
@ -557,6 +557,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp
${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp
)
add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS})
target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB})
@ -574,6 +575,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp
${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp
)
if (USE_NCCL)
list(APPEND Caffe2_HIP_SRCS
@ -12,5 +12,14 @@ namespace jit {
TH_FORALL_TESTS(TENSOREXPR_GTEST)
#undef TENSOREXPR_GTEST

#ifdef USE_CUDA
#define TENSOREXPR_GTEST_CUDA(name) \
  TEST(TensorExprTest, name##_CUDA) { \
    test##name(); \
  }
TH_FORALL_TESTS_CUDA(TENSOREXPR_GTEST_CUDA)
#undef TENSOREXPR_GTEST_CUDA
#endif

} // namespace jit
} // namespace torch
test/cpp/tensorexpr/test_cuda.cpp (new file, 333 lines)
@ -0,0 +1,333 @@
|
||||
#ifdef USE_CUDA
|
||||
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include "test/cpp/tensorexpr/test_base.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "test/cpp/tensorexpr/padded_buffer.h"
|
||||
#include "torch/csrc/jit/tensorexpr/buffer.h"
|
||||
#include "torch/csrc/jit/tensorexpr/cuda_codegen.h"
|
||||
#include "torch/csrc/jit/tensorexpr/schedule.h"
|
||||
#include "torch/csrc/jit/tensorexpr/tensor.h"
|
||||
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/util/Half.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
using namespace torch::jit::tensorexpr;
|
||||
using namespace torch::jit::tensorexpr::schedule;
|
||||
|
||||
template <typename ctype>
|
||||
void testCudaTestVectorAdd01_impl() {
|
||||
KernelScope kernel_scope;
|
||||
const int num_iter = 3;
|
||||
const int block_count = 16;
|
||||
const int block_size = 128;
|
||||
Dtype dtype = ToDtype<ctype>();
|
||||
Buffer a_buf("a", dtype, {num_iter, block_count, block_size});
|
||||
Buffer b_buf("b", dtype, {num_iter, block_count, block_size});
|
||||
Tensor* c = Compute(
|
||||
"c",
|
||||
{
|
||||
{num_iter, "n"},
|
||||
{block_count, "b_id"},
|
||||
{block_size, "t_id"},
|
||||
},
|
||||
[&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
|
||||
return a_buf(n, b_id, t_id) + b_buf(n, b_id, t_id);
|
||||
});
|
||||
LoopNest l({c});
|
||||
std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
|
||||
l.SetGPUBlockIndex(loops[1], 0);
|
||||
l.SetGPUThreadIndex(loops[2], 0);
|
||||
Stmt* stmt = l.root_stmt();
|
||||
CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
|
||||
const int N = block_count * block_size * num_iter;
|
||||
PaddedBuffer<ctype> a_v(N);
|
||||
PaddedBuffer<ctype> b_v(N);
|
||||
PaddedBuffer<ctype> c_v(N);
|
||||
PaddedBuffer<ctype> c_ref(N);
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
a_v(i) = ctype(i);
|
||||
b_v(i) = ctype(i * 3 + 7);
|
||||
c_ref(i) = a_v(i) + b_v(i);
|
||||
}
|
||||
|
||||
// TODO: move gpu support into PaddedBuffer
|
||||
ctype* a_dev = nullptr;
|
||||
cudaMalloc(&a_dev, N * sizeof(ctype));
|
||||
ctype* b_dev = nullptr;
|
||||
cudaMalloc(&b_dev, N * sizeof(ctype));
|
||||
ctype* c_dev = nullptr;
|
||||
cudaMalloc(&c_dev, N * sizeof(ctype));
|
||||
cudaMemcpy(a_dev, a_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(b_dev, b_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(c_dev, c_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cuda_cg(c_dev, a_dev, b_dev);
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
cudaMemcpy(c_v.data(), c_dev, N * sizeof(ctype), cudaMemcpyDeviceToHost);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
ExpectAllNear(c_v, c_ref, 1e-5);
|
||||
|
||||
cudaFree(a_dev);
|
||||
cudaFree(b_dev);
|
||||
cudaFree(c_dev);
|
||||
}
|
||||
|
||||
void testCudaTestVectorAdd01() {
|
||||
// floating types.
|
||||
testCudaTestVectorAdd01_impl<float>();
|
||||
testCudaTestVectorAdd01_impl<at::Half>();
|
||||
testCudaTestVectorAdd01_impl<double>();
|
||||
|
||||
// integer types.
|
||||
testCudaTestVectorAdd01_impl<int8_t>();
|
||||
testCudaTestVectorAdd01_impl<uint8_t>();
|
||||
testCudaTestVectorAdd01_impl<int16_t>();
|
||||
testCudaTestVectorAdd01_impl<int32_t>();
|
||||
testCudaTestVectorAdd01_impl<int64_t>();
|
||||
}
|
||||
|
||||
static void testCudaTestVectorAdd02_impl(int N, int block_size) {
|
||||
KernelScope kernel_scope;
|
||||
Buffer a_buf("a", kFloat, {N});
|
||||
Buffer b_buf("b", kFloat, {N});
|
||||
Tensor* c = Compute(
|
||||
"c",
|
||||
{
|
||||
{N, "N"},
|
||||
},
|
||||
[&](const VarHandle& n) { return a_buf(n) + b_buf(n); });
|
||||
LoopNest l({c});
|
||||
Stmt* n_outer;
|
||||
Stmt* n_inner;
|
||||
std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
|
||||
l.SplitWithMask(loops[0], block_size, &n_outer, &n_inner);
|
||||
l.SetGPUBlockIndex(n_outer, 0);
|
||||
l.SetGPUThreadIndex(n_inner, 0);
|
||||
Stmt* stmt = l.root_stmt();
|
||||
CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
|
||||
PaddedBuffer<float> a_v(N);
|
||||
PaddedBuffer<float> b_v(N);
|
||||
PaddedBuffer<float> c_v(N);
|
||||
PaddedBuffer<float> c_ref(N);
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
a_v(i) = i;
|
||||
b_v(i) = i * 3 + 7;
|
||||
c_ref(i) = a_v(i) + b_v(i);
|
||||
}
|
||||
|
||||
// TODO: move gpu support into PaddedBuffer
|
||||
float* a_dev = nullptr;
|
||||
cudaMalloc(&a_dev, N * sizeof(float));
|
||||
float* b_dev = nullptr;
|
||||
cudaMalloc(&b_dev, N * sizeof(float));
|
||||
float* c_dev = nullptr;
|
||||
cudaMalloc(&c_dev, N * sizeof(float));
|
||||
cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cuda_cg(c_dev, a_dev, b_dev);
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
ExpectAllNear(c_v, c_ref, 1e-5);
|
||||
|
||||
cudaFree(a_dev);
|
||||
cudaFree(b_dev);
|
||||
cudaFree(c_dev);
|
||||
}
|
||||
|
||||
void testCudaTestVectorAdd02() {
|
||||
testCudaTestVectorAdd02_impl(1024, 128);
|
||||
testCudaTestVectorAdd02_impl(1030, 128);
|
||||
}
|
||||
|
||||
void testCudaDynamicShape2D() {
|
||||
KernelScope kernel_scope;
|
||||
auto testWithSize = [](int32_t M, int32_t N) {
|
||||
VarHandle m("m", kInt);
|
||||
VarHandle n("n", kInt);
|
||||
Buffer a(VarHandle("a", kHandle), kFloat, {m, n});
|
||||
Buffer b(VarHandle("b", kHandle), kFloat, {m, n});
|
||||
Tensor* c = Compute(
|
||||
"c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
|
||||
return a(i, j) + b(i, j);
|
||||
});
|
||||
LoopNest l({c});
|
||||
Stmt* s = l.root_stmt();
|
||||
CudaCodeGen cg(s, {a, b, c, m, n});
|
||||
|
||||
std::vector<float> aData(M * N, 1.0f);
|
||||
std::vector<float> bData(M * N, 2.0f);
|
||||
std::vector<float> cData(M * N, 0.0f);
|
||||
float* aDev = nullptr;
|
||||
float* bDev = nullptr;
|
||||
float* cDev = nullptr;
|
||||
cudaMalloc(&aDev, aData.size() * sizeof(aData[0]));
|
||||
cudaMalloc(&bDev, bData.size() * sizeof(bData[0]));
|
||||
cudaMalloc(&cDev, cData.size() * sizeof(cData[0]));
|
||||
cudaMemcpy(
|
||||
aDev,
|
||||
aData.data(),
|
||||
aData.size() * sizeof(aData[0]),
|
||||
cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(
|
||||
bDev,
|
||||
bData.data(),
|
||||
bData.size() * sizeof(bData[0]),
|
||||
cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(
|
||||
cDev,
|
||||
cData.data(),
|
||||
cData.size() * sizeof(cData[0]),
|
||||
cudaMemcpyHostToDevice);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cg.call({aDev, bDev, cDev, M, N});
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cudaMemcpy(
|
||||
cData.data(),
|
||||
cDev,
|
||||
cData.size() * sizeof(cData[0]),
|
||||
cudaMemcpyDeviceToHost);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
ExpectAllNear(cData, std::vector<float>(M * N, 3.0f), 1e-7);
|
||||
|
||||
cudaFree(aDev);
|
||||
cudaFree(bDev);
|
||||
cudaFree(cDev);
|
||||
};
|
||||
testWithSize(32, 32);
|
||||
testWithSize(1, 16);
|
||||
testWithSize(27, 13);
|
||||
}
|
||||
|
||||
void testCudaTestRand01() {
|
||||
KernelScope kernel_scope;
|
||||
const int num_iter = 3;
|
||||
const int block_count = 16;
|
||||
const int block_size = 128;
|
||||
Tensor* c = Compute(
|
||||
"c",
|
||||
{
|
||||
{num_iter, "n"},
|
||||
{block_count, "b_id"},
|
||||
{block_size, "t_id"},
|
||||
},
|
||||
[&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
|
||||
return Intrinsics::make(IntrinsicsOp::kRand, kFloat);
|
||||
});
|
||||
LoopNest l({c});
|
||||
std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
|
||||
l.SetGPUBlockIndex(loops[1], 0);
|
||||
l.SetGPUThreadIndex(loops[2], 0);
|
||||
Stmt* stmt = l.root_stmt();
|
||||
CudaCodeGen cuda_cg(stmt, c);
|
||||
const int N = block_count * block_size * num_iter;
|
||||
PaddedBuffer<float> c_v(N);
|
||||
|
||||
// TODO: move gpu support into PaddedBuffer
|
||||
float* c_dev = nullptr;
|
||||
cudaMalloc(&c_dev, N * sizeof(float));
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cuda_cg(c_dev);
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
float sum1 = 0;
|
||||
float sum2 = 0;
|
||||
float sum3 = 0;
|
||||
for (int i = 0; i < N; i++) {
|
||||
float v = c_v.data()[i];
|
||||
sum1 += v;
|
||||
sum2 += v * v;
|
||||
sum3 += v * v * v;
|
||||
EXPECT_TRUE(v >= 0 && v < 1) << "invalid value: " << i << ", " << v;
|
||||
}
|
||||
sum1 /= N;
|
||||
sum2 /= N;
|
||||
sum3 /= N;
|
||||
float sum1_mean = 1.f / 2;
|
||||
float sum2_mean = 1.f / 3;
|
||||
float sum3_mean = 1.f / 4;
|
||||
|
||||
EXPECT_NEAR(sum1, sum1_mean, 2e-2);
|
||||
EXPECT_NEAR(sum2, sum2_mean, 2e-2);
|
||||
EXPECT_NEAR(sum3, sum3_mean, 2e-2);
|
||||
cudaFree(c_dev);
|
||||
}
|
||||
|
||||
void testCudaDynamicShapeSplit() {
|
||||
KernelScope ks;
|
||||
constexpr int N = 4096;
|
||||
VarHandle n("n", kInt);
|
||||
Buffer a(VarHandle("a", kHandle), kFloat, {n});
|
||||
Tensor* b =
|
||||
Compute("b", {{n, "n"}}, [&](const VarHandle& i) { return a(i) * 2.0f; });
|
||||
LoopNest l({b});
|
||||
Stmt* outer;
|
||||
Stmt* inner;
|
||||
std::vector<Stmt*> loops = l.getLoopStmtsFor(b);
|
||||
l.SplitWithMask(loops[0], 1024, &outer, &inner);
|
||||
l.SetGPUBlockIndex(outer, 0);
|
||||
l.SetGPUThreadIndex(inner, 0);
|
||||
Stmt* s = l.root_stmt();
|
||||
CudaCodeGen cg(s, {a, b, n});
|
||||
|
||||
std::vector<float> aData(N, 1.0f);
|
||||
std::vector<float> bData(N, 1.0f);
|
||||
float* aDev = nullptr;
|
||||
float* bDev = nullptr;
|
||||
cudaMalloc(&aDev, aData.size() * sizeof(aData[0]));
|
||||
cudaMalloc(&bDev, bData.size() * sizeof(bData[0]));
|
||||
cudaMemcpy(
|
||||
aDev,
|
||||
aData.data(),
|
||||
aData.size() * sizeof(aData[0]),
|
||||
cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(
|
||||
bDev,
|
||||
bData.data(),
|
||||
bData.size() * sizeof(aData[0]),
|
||||
cudaMemcpyHostToDevice);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cg.call({aDev, bDev, N});
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cudaMemcpy(
|
||||
bData.data(),
|
||||
bDev,
|
||||
bData.size() * sizeof(aData[0]),
|
||||
cudaMemcpyDeviceToHost);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
ExpectAllNear(bData, std::vector<float>(N, 2.0f), 1e-7);
|
||||
|
||||
cudaFree(aDev);
|
||||
cudaFree(bDev);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
#endif
@ -87,6 +87,11 @@ namespace jit {
  _(ATenltInt)

#define TH_FORALL_TESTS_CUDA(_) \
  _(CudaTestVectorAdd01) \
  _(CudaTestVectorAdd02) \
  _(CudaDynamicShape2D) \
  _(CudaTestRand01) \
  _(CudaDynamicShapeSplit)

#define DECLARE_TENSOREXPR_TEST(name) void test##name();
TH_FORALL_TESTS(DECLARE_TENSOREXPR_TEST)
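
For reference, applying the TH_FORALL_TESTS_CUDA list above to the TENSOREXPR_GTEST_CUDA macro from the previous hunk turns each entry into an ordinary GTest case; written out by hand, the first entry expands to:

TEST(TensorExprTest, CudaTestVectorAdd01_CUDA) {
  testCudaTestVectorAdd01();
}
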
@ -34,6 +34,16 @@ class ExecutionCounter(object):
        return value - self.start_value


class CudaCodeGenCreated(ExecutionCounter):
    def __init__(self):
        super(CudaCodeGenCreated, self).__init__("cuda_codegen_created")


class CudaCodeGenExecuted(ExecutionCounter):
    def __init__(self):
        super(CudaCodeGenExecuted, self).__init__("cuda_codegen_executed")


class SimpleIREvalExecuted(ExecutionCounter):
    def __init__(self):
        super(SimpleIREvalExecuted, self).__init__("simple_ir_eval_executed")
@ -80,7 +90,7 @@ class TestTensorExprFuser(BaseTestClass):
            c = torch.addcmul(torch.add(x, y), z, w)
            return c

        device_options = ["cpu"]
        device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
        for dev in device_options:
            rand_a = torch.rand(1024, dtype=torch.float, device=dev)
            rand_b = torch.rand(1024, dtype=torch.float, device=dev)
@ -102,6 +112,79 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6)
|
||||
|
||||
|
||||
def test_three_arg_cuda(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
cuda_cg_executed = CudaCodeGenExecuted()
|
||||
cuda_cg_created = CudaCodeGenCreated()
|
||||
|
||||
def test(x, y, z):
|
||||
aaa = torch.add(x, y)
|
||||
bbb = torch.add(aaa, z)
|
||||
return bbb
|
||||
|
||||
M = 32
|
||||
N = 32
|
||||
traced = torch.jit.trace(
|
||||
test,
|
||||
(
|
||||
torch.rand(M, N, device="cuda"),
|
||||
torch.rand(M, N, device="cuda"),
|
||||
torch.rand(M, N, device="cuda"),
|
||||
),
|
||||
)
|
||||
|
||||
a = torch.rand(M, N, device="cuda")
|
||||
b = torch.rand(M, N, device="cuda")
|
||||
c = torch.rand(M, N, device="cuda")
|
||||
x = traced(a, b, c)
|
||||
npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
|
||||
np.testing.assert_allclose(npr, x.cpu().numpy())
|
||||
assert cuda_cg_executed.elapsed_value() >= 1
|
||||
assert cuda_cg_created.elapsed_value() >= 1
|
||||
|
||||
|
||||
def test_broadcast_cuda(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
def test_body(M, N, L, K):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
cuda_cg_executed = CudaCodeGenExecuted()
|
||||
cuda_cg_created = CudaCodeGenCreated()
|
||||
|
||||
def test(x, y, z):
|
||||
v1 = torch.add(x, y)
|
||||
v2 = torch.add(v1, z)
|
||||
return v2
|
||||
|
||||
a_shape = [M, N]
|
||||
b_shape = [L, M, 1]
|
||||
c_shape = [K, L, 1, 1]
|
||||
traced = torch.jit.trace(
|
||||
test,
|
||||
(
|
||||
torch.rand(*a_shape, device="cuda"),
|
||||
torch.rand(*b_shape, device="cuda"),
|
||||
torch.rand(*c_shape, device="cuda"),
|
||||
),
|
||||
)
|
||||
|
||||
a = torch.rand(*a_shape, device="cuda")
|
||||
b = torch.rand(*b_shape, device="cuda")
|
||||
c = torch.rand(*c_shape, device="cuda")
|
||||
x = traced(a, b, c)
|
||||
npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
|
||||
np.testing.assert_allclose(npr, x.cpu().numpy())
|
||||
assert cuda_cg_executed.elapsed_value() >= 1
|
||||
assert cuda_cg_created.elapsed_value() >= 1
|
||||
|
||||
test_configs = [[36, 17, 63, 33], [32, 32, 32, 32]]
|
||||
for test_config in test_configs:
|
||||
test_body(*test_config)
|
||||
|
||||
|
||||
def test_all_combos(self):
|
||||
def easy(x, y, z):
|
||||
a = torch.add(x, y)
|
||||
@ -426,7 +509,7 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
c = torch.lt(x, y)
|
||||
return c
|
||||
|
||||
device_options = ["cpu"]
|
||||
device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
|
||||
for dev in device_options:
|
||||
traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev)))
|
||||
a = torch.ones(1024, dtype=torch.int32, device=dev)
|
||||
@ -451,7 +534,7 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
def test(x):
|
||||
return torch.clamp(x + 3.0, 0.0, 6.0)
|
||||
|
||||
device_options = ["cpu"]
|
||||
device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
|
||||
|
||||
for dev in device_options:
|
||||
traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
|
||||
@ -463,7 +546,7 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
def test(x):
|
||||
return torch.clamp(F.relu(x), 0, 0.5)
|
||||
|
||||
device_options = ["cpu"]
|
||||
device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
|
||||
for dev in device_options:
|
||||
traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
|
||||
a = 20.0 * torch.rand(1024, device=dev) - 10.0
|
||||
@ -598,7 +681,7 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
# test_tanh_backward,
|
||||
test_type_as,
|
||||
}
|
||||
device_options = ["cpu"]
|
||||
device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
|
||||
for torch_fn in fns:
|
||||
for dev in device_options:
|
||||
rand_a = torch.rand(1024, device=dev)
|
||||
@ -776,7 +859,7 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
test_neg,
|
||||
test_relu,
|
||||
}
|
||||
device_options = ["cpu"]
|
||||
device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
|
||||
|
||||
for torch_fn in fns:
|
||||
for dev in device_options:
|
||||
@ -797,6 +880,26 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
|
||||
|
||||
|
||||
def test_rand_like(self):
|
||||
devices = ["cuda"] if torch.cuda.is_available() else []
|
||||
N = 1 << 16
|
||||
|
||||
def run_rand_like(x, y):
|
||||
return torch.rand_like(torch.add(x, y))
|
||||
|
||||
for device in devices:
|
||||
x = torch.rand(N, device=device)
|
||||
traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False)
|
||||
x_v = traced(x, x)
|
||||
x_np = x.cpu().numpy()
|
||||
x1_mean = np.mean(x_np)
|
||||
x2_mean = np.mean(x_np ** 2)
|
||||
x3_mean = np.mean(x_np ** 3)
|
||||
np.testing.assert_allclose(x1_mean, 1. / 2, rtol=2e-2)
|
||||
np.testing.assert_allclose(x2_mean, 1. / 3, rtol=2e-2)
|
||||
np.testing.assert_allclose(x3_mean, 1. / 4, rtol=2e-2)
|
||||
|
||||
|
||||
def test_nans(self):
|
||||
def test_max(x, y):
|
||||
return torch.max(2 * x, 2 * y)
|
||||
@ -898,6 +1001,10 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
def test_cat_cpu(self):
|
||||
self._test_cat('cpu')
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
|
||||
def test_cat_cuda(self):
|
||||
self._test_cat('cuda')
|
||||
|
||||
def test_scalar(self):
|
||||
@torch.jit.script
|
||||
def test_float(x, y, z, a, b):
|
||||
@ -1001,8 +1108,66 @@ class TestTensorExprFuser(BaseTestClass):
|
||||
assert interp.elapsed_value() == 1
|
||||
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
|
||||
@unittest.skip("dynamic shapes are not quite there yet")
|
||||
def test_dynamic_shape(self):
|
||||
with num_profiled_runs(2):
|
||||
@torch.jit.script
|
||||
def test(x, y, z):
|
||||
return x * y * z
|
||||
cuda = CudaCodeGenCreated()
|
||||
x, y, z = [torch.rand(4, 8).cuda() for _ in range(3)]
|
||||
ref = test(x, y, z)
|
||||
_ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
|
||||
res = test(x, y, z)
|
||||
np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())
|
||||
assert cuda.elapsed_value() == 1
|
||||
|
||||
# A wild broadcast appears.
|
||||
x = torch.rand(4, 8).cuda()
|
||||
y = torch.rand(1, 8).cuda()
|
||||
z = torch.rand(4, 1).cuda()
|
||||
res = test(x, y, z)
|
||||
xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
|
||||
np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
|
||||
assert cuda.elapsed_value() == 1
|
||||
|
||||
# Mismatched shapes shouldn't reach codegen.
|
||||
x = torch.rand(4, 8).cuda()
|
||||
y = torch.rand(4, 8).cuda()
|
||||
z = torch.rand(5, 8).cuda()
|
||||
try:
|
||||
res = test(x, y, z)
|
||||
except RuntimeError as e:
|
||||
assert "The size of tensor a (4) must match" in e.args[0]
|
||||
assert cuda.elapsed_value() == 1
|
||||
|
||||
# Changing a static dimension fails guards.
|
||||
# x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)]
|
||||
# xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
|
||||
# res = test(x, y, z)
|
||||
# print(test.graph_for(x, y, z))
|
||||
# np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
|
||||
# assert cuda.elapsed_value() == 1
|
||||
|
||||
@unittest.skip("guarding on static shapes is not working")
|
||||
def test_guard_fails(self):
|
||||
@torch.jit.script
|
||||
def test(x, y, z):
|
||||
return x * y * z
|
||||
cuda = CudaCodeGenExecuted()
|
||||
r1 = test(*[torch.rand(4).cuda() for _ in range(3)])
|
||||
assert cuda.elapsed_value() == 0
|
||||
r2 = test(*[torch.rand(4).cuda() for _ in range(3)])
|
||||
assert cuda.elapsed_value() == 1
|
||||
r3 = test(*[torch.rand(4).cuda() for _ in range(3)])
|
||||
assert cuda.elapsed_value() == 2
|
||||
r4 = test(*[torch.rand(7).cuda() for _ in range(3)])
|
||||
print(test.graph_for(*[torch.rand(7).cuda() for _ in range(3)]))
|
||||
assert cuda.elapsed_value() == 2
|
||||
|
||||
def test_bitwise_ops(self):
|
||||
devices = ["cpu"]
|
||||
devices = ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]
|
||||
|
||||
def run_and(x, y):
|
||||
return x & (x & y)
|
||||
|
@ -219,6 +219,7 @@ libtorch_cuda_sources = [
    "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
    "torch/csrc/autograd/profiler_cuda.cpp",
    "torch/csrc/autograd/functions/comm.cpp",
    "torch/csrc/jit/tensorexpr/cuda_codegen.cpp",
]

torch_cpp_srcs = [
@ -60,6 +60,7 @@
#include <torch/csrc/jit/python/python_tree_views.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <c10/macros/Export.h>
#include <caffe2/serialize/inline_container.h>
@ -414,6 +415,42 @@ void initJITBindings(PyObject* module) {
            ExecutionTriggerList::GetInstance().FindByName(trigger_name);
        return trigger->value();
      })
      .def(
          "_jit_get_te_cuda_pointwise_loop_levels",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return GetTECudaPointwiseLoopLevels();
          })
      .def(
          "_jit_set_te_cuda_pointwise_loop_levels",
          [](int level) {
            using namespace torch::jit::tensorexpr;
            return GetTECudaPointwiseLoopLevels() = level;
          })
      .def(
          "_jit_get_te_cuda_pointwise_block_count",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return GetTECudaPointwiseBlockCount();
          })
      .def(
          "_jit_set_te_cuda_pointwise_block_count",
          [](int block_count) {
            using namespace torch::jit::tensorexpr;
            return GetTECudaPointwiseBlockCount() = block_count;
          })
      .def(
          "_jit_get_te_cuda_pointwise_block_size",
          []() -> int {
            using namespace torch::jit::tensorexpr;
            return GetTECudaPointwiseBlockSize();
          })
      .def(
          "_jit_set_te_cuda_pointwise_block_size",
          [](int block_size) {
            using namespace torch::jit::tensorexpr;
            return GetTECudaPointwiseBlockSize() = block_size;
          })
      .def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled)
      .def(
          "_jit_fuser_get_fused_kernel_code",
torch/csrc/jit/tensorexpr/cuda_codegen.cpp (new file, 695 lines)
@ -0,0 +1,695 @@
|
||||
#include "torch/csrc/jit/tensorexpr/cuda_codegen.h"
|
||||
#include "torch/csrc/jit/tensorexpr/cuda_half_support.h"
|
||||
|
||||
#include "ATen/CUDAGenerator.h"
|
||||
#include "c10/cuda/CUDAFunctions.h"
|
||||
#include "torch/csrc/jit/tensorexpr/cuda_random.h"
|
||||
#include "torch/csrc/jit/tensorexpr/eval.h"
|
||||
#include "torch/csrc/jit/tensorexpr/execution_counter.h"
|
||||
|
||||
#define DEBUG_PRINT 0
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
DEFINE_TRIGGER(cuda_codegen_created);
|
||||
DEFINE_TRIGGER(cuda_codegen_executed);
|
||||
|
||||
// A RAII wrapper to manage a variable and name pair in the look-up table.
|
||||
// TODO: move this to a more shared place.
|
||||
class ScopedVarName {
|
||||
public:
|
||||
ScopedVarName(VarNameMap* mapping, const Var* var, const std::string& name)
|
||||
: mapping_(mapping), var_(var) {
|
||||
auto iter = mapping->find(var);
|
||||
if (iter != mapping->end()) {
|
||||
throw std::runtime_error("Duplicate var entry: " + var->name_hint());
|
||||
}
|
||||
mapping->insert(std::make_pair(var, name));
|
||||
}
|
||||
|
||||
ScopedVarName(
|
||||
UniqueNameManager* manager,
|
||||
const Var* var,
|
||||
const std::string& name)
|
||||
: ScopedVarName(&manager->unique_name_mapping_, var, name) {}
|
||||
|
||||
ScopedVarName(const ScopedVarName&) = delete;
|
||||
ScopedVarName& operator=(const ScopedVarName&) = delete;
|
||||
|
||||
~ScopedVarName() noexcept(false) {
|
||||
mapping_->erase(var_);
|
||||
}
|
||||
|
||||
private:
|
||||
VarNameMap* mapping_ = nullptr;
|
||||
const Var* var_ = nullptr;
|
||||
};
|
||||
|
||||
static int as_int(const Expr* expr) {
|
||||
auto v = dynamic_cast<const IntImm*>(expr);
|
||||
TORCH_CHECK(v, "Expression is not an integer constant");
|
||||
return v->value();
|
||||
}
|
||||
|
||||
static bool is_zero(const Expr* expr) {
|
||||
return as_int(expr) == 0;
|
||||
}
|
||||
|
||||
static const at::cuda::NVRTC& nvrtc() {
|
||||
return at::globalContext().getNVRTC();
|
||||
}
|
||||
|
||||
static void getMajorMinor(
|
||||
const cudaDeviceProp* const prop,
|
||||
int& major,
|
||||
int& minor) {
|
||||
using CudaVersion = std::pair<int, int>;
|
||||
CudaVersion nvrtc_version;
|
||||
AT_CUDA_NVRTC_CHECK(
|
||||
nvrtc().nvrtcVersion(&nvrtc_version.first, &nvrtc_version.second));
|
||||
|
||||
AT_ASSERT(nvrtc_version.first >= 6);
|
||||
|
||||
CudaVersion dev_version = CudaVersion(prop->major, prop->minor);
|
||||
CudaVersion max_dev_version(dev_version);
|
||||
if (nvrtc_version.first <= 7) { // 7 supports 2-5.x
|
||||
max_dev_version = CudaVersion(5, 0);
|
||||
} else if (nvrtc_version.first <= 8) { // 8 supports 2-6.x
|
||||
max_dev_version = CudaVersion(6, 0);
|
||||
} else if (nvrtc_version.first <= 9) { // 9 supports 3-7.2
|
||||
max_dev_version = CudaVersion(7, 2);
|
||||
} else if (nvrtc_version.first <= 10) { // 10 supports 3-7.5
|
||||
max_dev_version = CudaVersion(7, 5);
|
||||
}
|
||||
if (dev_version > max_dev_version) {
|
||||
dev_version = max_dev_version;
|
||||
}
|
||||
major = dev_version.first;
|
||||
minor = dev_version.second;
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const For* v) {
|
||||
const LoopOptions& loop_options = v->loop_options();
|
||||
if (loop_options.is_gpu_block_index()) {
|
||||
ScopedVarName var_name(
|
||||
name_manager(), v->var(), loop_options.gpu_block_index_str());
|
||||
v->body()->accept(this);
|
||||
int gpu_block_index = loop_options.gpu_block_index();
|
||||
if (gpu_block_extents_.size() <= gpu_block_index) {
|
||||
gpu_block_extents_.resize(gpu_block_index + 1);
|
||||
}
|
||||
if (!is_zero(v->start())) {
|
||||
throw std::runtime_error(
|
||||
"start must be zero for gpu_block_index: " +
|
||||
std::to_string(ExprHandle(v->start())));
|
||||
}
|
||||
gpu_block_extents_[gpu_block_index] = v->stop();
|
||||
} else if (loop_options.is_gpu_thread_index()) {
|
||||
ScopedVarName var_name(
|
||||
name_manager(), v->var(), loop_options.gpu_thread_index_str());
|
||||
v->body()->accept(this);
|
||||
int gpu_thread_index = loop_options.gpu_thread_index();
|
||||
if (gpu_thread_extents_.size() <= gpu_thread_index) {
|
||||
gpu_thread_extents_.resize(gpu_thread_index + 1);
|
||||
}
|
||||
if (!is_zero(v->start())) {
|
||||
throw std::runtime_error(
|
||||
"start must be zero for gpu_block_index: " +
|
||||
std::to_string(ExprHandle(v->start())));
|
||||
}
|
||||
gpu_thread_extents_[gpu_thread_index] = v->stop();
|
||||
} else {
|
||||
IRPrinter::visit(v);
|
||||
}
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const Intrinsics* v) {
|
||||
if (v->op_type() == IntrinsicsOp::kRand) {
|
||||
os() << "Uint32ToFloat(" << *rand_func_ << "())";
|
||||
return;
|
||||
}
|
||||
|
||||
std::string func_name = v->func_name();
|
||||
|
||||
// get type of resulting expression.
|
||||
ScalarType returnType = v->param(0)->dtype().scalar_type();
|
||||
for (int i = 1; i < v->nparams(); ++i) {
|
||||
returnType = promoteTypes(returnType, v->param(i)->dtype().scalar_type());
|
||||
}
|
||||
|
||||
if (returnType == ScalarType::Half || returnType == ScalarType::Float) {
|
||||
func_name = func_name + "f";
|
||||
}
|
||||
|
||||
os() << func_name << "(";
|
||||
for (int i = 0; i < v->nparams(); i++) {
|
||||
if (i > 0) {
|
||||
os() << ", ";
|
||||
}
|
||||
os() << *v->param(i);
|
||||
}
|
||||
os() << ")";
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const Load* v) {
|
||||
// TODO: find a better metric for deciding when to use __ldg. Support different dtypes.
|
||||
if (v->dtype().scalar_type() == ScalarType::Half) {
|
||||
os() << "__half2float(" << *v->base_handle() << "[" << *v->index() << "])";
|
||||
} else {
|
||||
os() << "__ldg(" << *v->base_handle() << " + " << *v->index() << ")";
|
||||
}
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const Store* v) {
|
||||
os() << *v->base_handle() << "[" << *v->index() << "] = ";
|
||||
if (v->value()->dtype().scalar_type() == ScalarType::Half) {
|
||||
os() << "__float2half(" << *v->value() << ");";
|
||||
} else {
|
||||
os() << *v->value() << ";";
|
||||
}
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const Max* v) {
|
||||
auto dtype = v->dtype().scalar_type();
|
||||
switch (dtype) {
|
||||
case ScalarType::Half:
|
||||
// doing Half math in float.
|
||||
case ScalarType::Float:
|
||||
os() << "fmaxf";
|
||||
break;
|
||||
case ScalarType::Double:
|
||||
os() << "fmax";
|
||||
break;
|
||||
default:
|
||||
os() << "max";
|
||||
break;
|
||||
}
|
||||
os() << "(";
|
||||
v->lhs()->accept(this);
|
||||
os() << ",";
|
||||
v->rhs()->accept(this);
|
||||
os() << ")";
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const Min* v) {
|
||||
auto dtype = v->dtype().scalar_type();
|
||||
switch (dtype) {
|
||||
case ScalarType::Half:
|
||||
// doing Half math in float.
|
||||
case ScalarType::Float:
|
||||
os() << "fminf";
|
||||
break;
|
||||
case ScalarType::Double:
|
||||
os() << "fmin";
|
||||
break;
|
||||
default:
|
||||
os() << "min";
|
||||
break;
|
||||
}
|
||||
os() << "(";
|
||||
v->lhs()->accept(this);
|
||||
os() << ",";
|
||||
v->rhs()->accept(this);
|
||||
os() << ")";
|
||||
}
|
||||
|
||||
std::string cudaDtypeCppString(const Dtype& dtype) {
|
||||
switch (dtype.scalar_type()) {
|
||||
case ScalarType::Half:
|
||||
return "half";
|
||||
case ScalarType::Char:
|
||||
return "char";
|
||||
case ScalarType::Byte:
|
||||
return "unsigned char";
|
||||
case ScalarType::Short:
|
||||
return "short";
|
||||
case ScalarType::Long:
|
||||
return "long";
|
||||
default:; /* nothing */
|
||||
}
|
||||
return dtype.ToCppString();
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const LetStmt* v) {
|
||||
const Var* var = v->var();
|
||||
if (var->dtype().scalar_type() == ScalarType::Half) {
|
||||
// we do math in floats so use that.
|
||||
os() << "float";
|
||||
} else {
|
||||
os() << cudaDtypeCppString(var->dtype());
|
||||
}
|
||||
os() << " " << *var << " = " << *v->value() << "; " << std::endl;
|
||||
v->body()->accept(this);
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const IfThenElse* v) {
|
||||
os() << "((";
|
||||
v->condition()->accept(this);
|
||||
os() << ") ? ";
|
||||
v->true_value()->accept(this);
|
||||
os() << " : ";
|
||||
v->false_value()->accept(this);
|
||||
os() << ")";
|
||||
}
|
||||
|
||||
class PrioritizeLoad : public IRMutator {
|
||||
public:
|
||||
const Expr* mutate(const Load* v) override {
|
||||
// Look at the declaration of this variable for more details.
|
||||
if (nested_if_then_else_ > 0) {
|
||||
return IRMutator::mutate(v);
|
||||
}
|
||||
MemLoadList& load_list = load_stack_.back();
|
||||
const Var* load_new_var = new Var("v", v->dtype());
|
||||
const Expr* new_value = IRMutator::mutate(v);
|
||||
load_list.push_back(std::make_pair(load_new_var, new_value));
|
||||
return load_new_var;
|
||||
}
|
||||
|
||||
// TODO: merge this with the IRMutator::mutate version.
|
||||
Stmt* mutate(const For* v) override {
|
||||
const Var* var = v->var();
|
||||
const Expr* start = v->start();
|
||||
const Expr* stop = v->stop();
|
||||
Stmt* body = v->body();
|
||||
LoopOptions loop_options = v->loop_options();
|
||||
const Var* var_new = dynamic_cast<const Var*>(var->accept_mutator(this));
|
||||
const Expr* start_new = start->accept_mutator(this);
|
||||
const Expr* stop_new = stop->accept_mutator(this);
|
||||
PushList();
|
||||
Stmt* body_new = body->accept_mutator(this);
|
||||
if (!body_new) {
|
||||
return nullptr;
|
||||
}
|
||||
Stmt* body_with_loads = AddMemLoadsFromList(body_new);
|
||||
PopList();
|
||||
if (var == var_new && start == start_new && stop == stop_new &&
|
||||
body == body_with_loads) {
|
||||
return (Stmt*)v;
|
||||
}
|
||||
return new For(var_new, start_new, stop_new, body_with_loads, loop_options);
|
||||
}
|
||||
|
||||
Stmt* mutate(const LetStmt* v) override {
|
||||
const Var* var = v->var();
|
||||
const Expr* value = v->value();
|
||||
Stmt* body = v->body();
|
||||
const Var* var_new = dynamic_cast<const Var*>(var->accept_mutator(this));
|
||||
if (var_new == nullptr) {
|
||||
throw std::runtime_error("LetStmt var must be variable");
|
||||
}
|
||||
const Expr* value_new = value->accept_mutator(this);
|
||||
PushList();
|
||||
Stmt* body_new = body->accept_mutator(this);
|
||||
Stmt* body_with_loads = AddMemLoadsFromList(body_new);
|
||||
PopList();
|
||||
if (var == var_new && value == value_new && body == body_with_loads) {
|
||||
return (Stmt*)v;
|
||||
}
|
||||
return new LetStmt(var_new, value_new, body_with_loads);
|
||||
}
|
||||
|
||||
Stmt* mutate(const Cond* v) override {
|
||||
const Expr* cond_old = v->condition();
|
||||
Stmt* true_old = v->true_stmt();
|
||||
Stmt* false_old = v->false_stmt();
|
||||
|
||||
const Expr* cond_new = cond_old->accept_mutator(this);
|
||||
PushList();
|
||||
Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old;
|
||||
Stmt* true_with_loads = AddMemLoadsFromList(true_new);
|
||||
PopList();
|
||||
PushList();
|
||||
Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old;
|
||||
Stmt* false_with_loads = AddMemLoadsFromList(false_new);
|
||||
PopList();
|
||||
|
||||
if (cond_old == cond_new && true_old == true_with_loads &&
|
||||
false_old == false_with_loads) {
|
||||
return (Stmt*)v;
|
||||
}
|
||||
return new Cond(cond_new, true_with_loads, false_with_loads);
|
||||
}
|
||||
|
||||
const Expr* mutate(const IfThenElse* v) override {
|
||||
nested_if_then_else_++;
|
||||
const Expr* new_v = IRMutator::mutate(v);
|
||||
nested_if_then_else_--;
|
||||
return new_v;
|
||||
}
|
||||
|
||||
Stmt* Process(Stmt* stmt) {
|
||||
this->PushList();
|
||||
Stmt* stmt_v = stmt;
|
||||
Stmt* stmt_new = stmt_v->accept_mutator(this);
|
||||
Stmt* stmt_with_loads = AddMemLoadsFromList(stmt_new);
|
||||
this->PopList();
|
||||
return stmt_with_loads;
|
||||
}
|
||||
|
||||
private:
|
||||
using MemLoadEntry = std::pair<const Var*, const Expr*>;
|
||||
using MemLoadList = std::vector<MemLoadEntry>;
|
||||
using MemoryLoadStack = std::vector<MemLoadList>;
|
||||
|
||||
void PushList() {
|
||||
load_stack_.push_back(MemLoadList());
|
||||
}
|
||||
|
||||
void PopList() {
|
||||
load_stack_.pop_back();
|
||||
}
|
||||
|
||||
Stmt* AddMemLoadsFromList(Stmt* stmt) {
|
||||
MemLoadList& load_list = load_stack_.back();
|
||||
Stmt* stmt_v = stmt;
|
||||
for (auto iter = load_list.rbegin(); iter != load_list.rend(); iter++) {
|
||||
const MemLoadEntry& entry = *iter;
|
||||
const Var* var_ptr = entry.first;
|
||||
stmt_v = new LetStmt(var_ptr, entry.second, stmt_v);
|
||||
}
|
||||
return stmt_v;
|
||||
}
|
||||
|
||||
MemoryLoadStack load_stack_;
|
||||
// TODO: For now, we are not moving the loads with the IfThenElse.
|
||||
// Eventually, we should switch to a more generic structure like:
|
||||
// int v2 = IfThenElse(cond, true_v, false_v) + 2 ->
|
||||
//
|
||||
// int v;
|
||||
// if (cond) {
|
||||
// v = true_v;
|
||||
// } else {
|
||||
// v = false_v;
|
||||
// }
|
||||
// int v2 = v + 2;
|
||||
int nested_if_then_else_ = 0;
|
||||
};
|
||||
|
||||
class HasRand : public IRVisitor {
|
||||
public:
|
||||
HasRand(Stmt* stmt) : stmt_(stmt) {
|
||||
stmt_->accept(this);
|
||||
}
|
||||
|
||||
bool has_rand() const {
|
||||
return has_rand_;
|
||||
}
|
||||
|
||||
private:
|
||||
void visit(const Intrinsics* v) override {
|
||||
if (v->op_type() == IntrinsicsOp::kRand) {
|
||||
has_rand_ = true;
|
||||
} else {
|
||||
IRVisitor::visit(v);
|
||||
}
|
||||
}
|
||||
Stmt* stmt_;
|
||||
bool has_rand_ = false;
|
||||
};
|
||||
|
||||
std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
|
||||
// We are using a global counter here to make sure different instances of
// CudaCodeGen produce different function names.
|
||||
static int64_t counter = 0;
|
||||
++counter;
|
||||
int64_t value = counter;
|
||||
return func_prefix + "_" + std::to_string(value);
|
||||
}
|
||||
|
||||
void CudaCodeGen::Initialize() {
|
||||
// TODO: handle multiple kernels.
|
||||
// TODO: handle dynamic dimension.
|
||||
// TODO: call nvrtc.
|
||||
HasRand has_rand_func(stmt());
|
||||
has_random_ = has_rand_func.has_rand();
|
||||
printer_ = std::make_unique<CudaPrinter>(&oss_, has_random_);
|
||||
|
||||
os() << "#define NAN __int_as_float(0x7fffffff)\n"
|
||||
"#define POS_INFINITY __int_as_float(0x7f800000)\n"
|
||||
"#define NEG_INFINITY __int_as_float(0xff800000)\n";
|
||||
if (has_random_) {
|
||||
os() << philox_random_string << std::endl;
|
||||
}
|
||||
|
||||
// Check whether the statement uses the Half type, if so add the
|
||||
// half_support_literal.
|
||||
CudaHalfChecker halfChecker;
|
||||
stmt()->accept(&halfChecker);
|
||||
if (halfChecker.hasHalf()) {
|
||||
os() << fuser::cuda::half_support_literal << std::endl;
|
||||
}
|
||||
|
||||
std::string func_name = GetUniqueFuncName("func");
|
||||
os() << "extern \"C\" __global__" << std::endl << "void " << func_name << "(";
|
||||
const std::vector<BufferArg> buffer_args = this->buffer_args();
|
||||
for (size_t i = 0; i < buffer_args.size(); i++) {
|
||||
if (i > 0) {
|
||||
os() << ", ";
|
||||
}
|
||||
const BufferArg& buffer_arg = buffer_args[i];
|
||||
const Var* var = buffer_arg.var();
|
||||
Dtype dtype = buffer_arg.dtype();
|
||||
|
||||
os() << cudaDtypeCppString(dtype) << (buffer_arg.isVar() ? " " : "* ")
|
||||
<< name_manager()->get_unique_name(var);
|
||||
}
|
||||
const Var* rand_seed;
|
||||
const Var* rand_offset;
|
||||
if (has_random_) {
|
||||
// TODO: switch to kUint64 when it is available.
|
||||
rand_seed = new Var("rand_seed", kInt);
|
||||
rand_offset = new Var("rand_offset", kInt);
|
||||
std::string uint64_str = "unsigned long long";
|
||||
os() << ", " << uint64_str << " " << *rand_seed << ", " << uint64_str << " "
|
||||
<< *rand_offset;
|
||||
}
|
||||
os() << ") {";
|
||||
os() << std::endl;
|
||||
|
||||
if (has_random_) {
|
||||
const Var* idx = new Var("idx", kInt);
|
||||
os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;"
|
||||
<< std::endl;
|
||||
const Var* rand_func = printer_->rand_func();
|
||||
os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", "
|
||||
<< *rand_offset << ");" << std::endl;
|
||||
os() << std::endl;
|
||||
}
|
||||
|
||||
Stmt* stmt_v = stmt();
|
||||
PrioritizeLoad prioritize_load;
|
||||
stmt_v = prioritize_load.Process(stmt_v);
|
||||
stmt_v->accept(printer_.get());
|
||||
os() << std::endl;
|
||||
os() << "}";
|
||||
|
||||
// Check that all block extents have been set.
|
||||
const std::vector<const Expr*>& gpu_block_extents =
|
||||
printer_->gpu_block_extents();
|
||||
const std::vector<const Expr*>& gpu_thread_extents =
|
||||
printer_->gpu_thread_extents();
|
||||
for (size_t i = 0; i < gpu_block_extents.size(); i++) {
|
||||
if (!gpu_block_extents[i]) {
|
||||
throw std::runtime_error("Missing gpu_block_index: " + std::to_string(i));
|
||||
}
|
||||
}
|
||||
|
||||
#if DEBUG_PRINT
|
||||
std::cout << "stmt: " << std::endl;
|
||||
std::cout << oss_.str() << std::endl;
|
||||
std::cout << "block(";
|
||||
for (size_t i = 0; i < gpu_block_extents.size(); i++) {
|
||||
if (i > 0) {
|
||||
std::cout << ", ";
|
||||
}
|
||||
std::cout << *gpu_block_extents[i];
|
||||
}
|
||||
std::cout << "), thread(";
|
||||
for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
|
||||
if (i > 0) {
|
||||
std::cout << ", ";
|
||||
}
|
||||
std::cout << *gpu_thread_extents[i];
|
||||
}
|
||||
std::cout << ")" << std::endl;
|
||||
;
|
||||
#endif
|
||||
|
||||
CompileToNVRTC(oss_.str(), func_name);
|
||||
USE_TRIGGER(cuda_codegen_created);
|
||||
}
|
||||
|
||||
void CudaCodeGen::call(const std::vector<CallArg>& args) {
|
||||
CHECK_EQ(args.size(), buffer_args().size());
|
||||
|
||||
// TODO: move as much of this into the constructors.
|
||||
const std::vector<const Expr*>& gpu_block_extents =
|
||||
printer_->gpu_block_extents();
|
||||
const std::vector<const Expr*>& gpu_thread_extents =
|
||||
printer_->gpu_thread_extents();
|
||||
CHECK(gpu_block_extents.size() <= 3);
|
||||
CHECK(gpu_thread_extents.size() <= 3);
|
||||
std::vector<int> gpu_block_extents_v(3, 1);
|
||||
std::vector<int> gpu_thread_extents_v(3, 1);
|
||||
// evaluate all the block/thread extents into values
|
||||
// TODO: eventually, codegen these calculations and make them part of the
|
||||
// module.
|
||||
for (size_t i = 0; i < gpu_block_extents.size(); i++) {
|
||||
ExprEval<SimpleIREvaluator> eval(
|
||||
ExprHandle(gpu_block_extents[i]), buffer_args());
|
||||
gpu_block_extents_v[i] = eval.value<int>(args);
|
||||
}
|
||||
for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
|
||||
ExprEval<SimpleIREvaluator> eval(
|
||||
ExprHandle(gpu_thread_extents[i]), buffer_args());
|
||||
gpu_thread_extents_v[i] = eval.value<int>(args);
|
||||
}
|
||||
|
||||
// Skip launching the kernel if there are no elements to process.
|
||||
for (int extent : gpu_block_extents_v) {
|
||||
if (extent == 0) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Bind the buffer addresses into arguments
|
||||
auto const& buffer_args = this->buffer_args();
|
||||
int ptr_count = buffer_args.size();
|
||||
if (has_random_) {
|
||||
ptr_count += 2;
|
||||
}
|
||||
std::vector<void*> args_data(buffer_args.size());
|
||||
std::vector<void*> ptr_to_args(ptr_count);
|
||||
uint64_t rand_seed = uint64_t(-1);
|
||||
uint64_t rand_offset = uint64_t(-1);
|
||||
for (size_t i = 0; i < buffer_args.size(); i++) {
|
||||
auto const& bufferArg = buffer_args[i];
|
||||
if (bufferArg.isVar()) {
|
||||
auto stype = bufferArg.dtype().scalar_type();
|
||||
switch (stype) {
|
||||
#define TYPE_CASE(Type, Name) \
|
||||
case ScalarType::Name: \
|
||||
ptr_to_args[i] = args[i].Name##Ptr(); \
|
||||
break;
|
||||
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
|
||||
#undef TYPE_CASE
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled dtype in argument";
|
||||
}
|
||||
} else {
|
||||
args_data[i] = args[i].data();
|
||||
ptr_to_args[i] = &args_data[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (has_random_) {
|
||||
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
|
||||
// TODO: total hack. Switch to numel when it is available.
|
||||
int64_t total_elements_per_thread = (1LL << 28);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
auto philox_engine_inputs =
|
||||
gen->philox_engine_inputs(total_elements_per_thread);
|
||||
rand_seed = philox_engine_inputs.first;
|
||||
rand_offset = philox_engine_inputs.second;
|
||||
}
|
||||
ptr_to_args[buffer_args.size()] = &rand_seed;
|
||||
ptr_to_args[buffer_args.size() + 1] = &rand_offset;
|
||||
}
|
||||
|
||||
// Launch the kernels
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
|
||||
function_,
|
||||
gpu_block_extents_v[0],
|
||||
gpu_block_extents_v[1],
|
||||
gpu_block_extents_v[2],
|
||||
gpu_thread_extents_v[0],
|
||||
gpu_thread_extents_v[1],
|
||||
gpu_thread_extents_v[2],
|
||||
0,
|
||||
stream,
|
||||
ptr_to_args.data(),
|
||||
nullptr));
|
||||
USE_TRIGGER(cuda_codegen_executed);
|
||||
}
|
||||
|
||||
void CudaCodeGen::CompileToNVRTC(
|
||||
const std::string& code,
|
||||
const std::string& func_name) {
|
||||
// Initializes driver's API context (if necessary)
|
||||
CUdevice device = 0;
|
||||
CUcontext pctx = 0;
|
||||
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
|
||||
if (!pctx) {
|
||||
std::unique_lock<std::mutex> cudaFreeMutexLock(
|
||||
*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
|
||||
cudaFree(0);
|
||||
}
|
||||
|
||||
// Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
|
||||
// properly in some scenarios
|
||||
const auto prior_device = at::cuda::current_device();
|
||||
at::cuda::set_device(device);
|
||||
|
||||
// Acquires device and NVRTC properties (for compile arch and occupancy
|
||||
// calculations)
|
||||
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
int major, minor;
|
||||
getMajorMinor(prop, major, minor);
|
||||
|
||||
#if DEBUG_PRINT
|
||||
std::cout << "major: " << major << ", "
|
||||
<< "minor: " << minor << std::endl;
|
||||
#endif
|
||||
|
||||
// Creates the NVRTC program
|
||||
nvrtcProgram program;
|
||||
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
|
||||
&program, code.c_str(), nullptr, 0, nullptr, nullptr));
|
||||
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
std::vector<const char*> args = {};
|
||||
#else
|
||||
const std::string compute = "--gpu-architecture=compute_" +
|
||||
std::to_string(major) + std::to_string(minor);
|
||||
const std::vector<const char*> args = {
|
||||
"--std=c++14", compute.c_str(), "-default-device"};
|
||||
#endif
|
||||
|
||||
const auto result =
|
||||
nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
|
||||
if (result != NVRTC_SUCCESS) {
|
||||
size_t logsize;
|
||||
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize));
|
||||
std::vector<char> log(logsize);
|
||||
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data()));
|
||||
std::stringstream cu;
|
||||
cu << log.data() << std::endl;
|
||||
cu << "nvrtc compilation failed: " << std::endl;
|
||||
cu << code << std::endl;
|
||||
throw std::runtime_error(cu.str());
|
||||
}
|
||||
ResourceGuard holdProgram(
|
||||
[&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
|
||||
AT_CUDA_NVRTC_CHECK(result);
|
||||
size_t ptx_size;
|
||||
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size));
|
||||
std::vector<char> ptx;
|
||||
ptx.resize(ptx_size);
|
||||
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx.data()));
|
||||
|
||||
CUmodule module;
|
||||
AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module, ptx.data()));
|
||||
AT_CUDA_DRIVER_CHECK(
|
||||
nvrtc().cuModuleGetFunction(&function_, module, func_name.c_str()));
|
||||
}
|
||||
|
||||
RegisterCodeGen<CudaCodeGen> cuda_codegen_reg("cuda_codegen");
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
torch/csrc/jit/tensorexpr/cuda_codegen.h (new file, 123 lines)
@ -0,0 +1,123 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ATen/ATen.h"
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
#include "ATen/cuda/nvrtc_stub/ATenNVRTC.h"
|
||||
#include "c10/cuda/CUDACachingAllocator.h"
|
||||
#include "c10/cuda/CUDAGuard.h"
|
||||
#include "torch/csrc/jit/resource_guard.h"
|
||||
#include "torch/csrc/jit/tensorexpr/codegen.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir_printer.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir_visitor.h"
|
||||
#include "torch/csrc/jit/tensorexpr/unique_name_manager.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
// A class that overrides the underlying IRPrinter to produce Cuda C.
|
||||
class CudaPrinter : public IRPrinter {
|
||||
public:
|
||||
explicit CudaPrinter(std::ostream* os, bool has_random) : IRPrinter(*os) {
|
||||
if (has_random) {
|
||||
rand_func_ = new Var("rand", kHandle);
|
||||
}
|
||||
}
|
||||
|
||||
void visit(const Cast* v) override {
|
||||
auto dtype = v->dtype();
|
||||
if (dtype == kHalf) {
|
||||
os() << "half";
|
||||
} else {
|
||||
os() << dtype;
|
||||
}
|
||||
os() << "(";
|
||||
v->src_value()->accept(this);
|
||||
os() << ")";
|
||||
}
|
||||
|
||||
void visit(const Intrinsics* v) override;
|
||||
void visit(const For* v) override;
|
||||
|
||||
void visit(const Load* v) override;
|
||||
void visit(const Store* v) override;
|
||||
void visit(const Max* v) override;
|
||||
void visit(const Min* v) override;
|
||||
void visit(const LetStmt* v) override;
|
||||
void visit(const IfThenElse* v) override;
|
||||
|
||||
const std::vector<const Expr*>& gpu_block_extents() const {
|
||||
return gpu_block_extents_;
|
||||
}
|
||||
|
||||
const std::vector<const Expr*>& gpu_thread_extents() const {
|
||||
return gpu_thread_extents_;
|
||||
}
|
||||
|
||||
const Var* rand_func() const {
|
||||
return rand_func_;
|
||||
}
|
||||
|
||||
using IRPrinter::name_manager;
|
||||
using IRPrinter::visit;
|
||||
|
||||
private:
|
||||
std::vector<const Expr*> gpu_block_extents_;
|
||||
std::vector<const Expr*> gpu_thread_extents_;
|
||||
const Var* rand_func_;
|
||||
};
|
||||
|
||||
// Construct Cuda C from the buffer and tensor input, and invoke the kernel
|
||||
// when real arguments are provided.
|
||||
class TORCH_CUDA_API CudaCodeGen : public CodeGen {
|
||||
public:
|
||||
template <typename... Ts>
|
||||
CudaCodeGen(Stmt* stmt, Ts... ts) : CodeGen(stmt, std::forward<Ts>(ts)...) {
|
||||
Initialize();
|
||||
}
|
||||
|
||||
CudaCodeGen(Stmt* stmt, const std::vector<BufferArg>& buffer_args)
|
||||
: CodeGen(stmt, buffer_args) {
|
||||
Initialize();
|
||||
}
|
||||
|
||||
~CudaCodeGen() override {}
|
||||
|
||||
void call(const std::vector<CallArg>& args) override;
|
||||
|
||||
template <typename... Ts>
|
||||
void operator()(const Ts&... ts) {
|
||||
call(std::vector<CallArg>({CallArg(ts)...}));
|
||||
}
|
||||
|
||||
private:
|
||||
void Initialize();
|
||||
|
||||
void CompileToNVRTC(const std::string& code, const std::string& func_name);
|
||||
|
||||
UniqueNameManager* name_manager() {
|
||||
if (!printer_) {
|
||||
throw std::runtime_error("Null IRPrinter is not expected");
|
||||
}
|
||||
return printer_->name_manager();
|
||||
}
|
||||
|
||||
std::ostream& os() {
|
||||
return printer_->os();
|
||||
}
|
||||
|
||||
std::ostringstream oss_;
|
||||
std::unique_ptr<CudaPrinter> printer_;
|
||||
CUfunction function_;
|
||||
bool has_random_ = false;
|
||||
|
||||
std::string GetUniqueFuncName(const std::string& func_prefix);
|
||||
};
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
torch/csrc/jit/tensorexpr/cuda_half_support.h (new file, 30 lines)
@ -0,0 +1,30 @@
#pragma once

#include "torch/csrc/jit/codegen/fuser/cuda/resource_strings.h"
#include "torch/csrc/jit/tensorexpr/cuda_codegen.h"

namespace torch {
namespace jit {
namespace tensorexpr {

// Walk the Statement looking for Half-sized loads/stores.
class CudaHalfChecker : public IRVisitor {
 public:
  bool hasHalf() {
    return hasHalf_;
  }

  void visit(const Load* v) override {
    hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
  }
  void visit(const Store* v) override {
    hasHalf_ |= v->value()->dtype().scalar_type() == ScalarType::Half;
  }

 private:
  bool hasHalf_{false};
};

} // namespace tensorexpr
} // namespace jit
} // namespace torch
torch/csrc/jit/tensorexpr/cuda_random.h (new file, 104 lines)
@ -0,0 +1,104 @@
|
||||
#pragma once
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
constexpr auto philox_random_string = R"(
|
||||
|
||||
class Philox {
|
||||
public:
|
||||
__device__ inline Philox(unsigned long long seed,
|
||||
unsigned long long subsequence,
|
||||
unsigned long long offset) {
|
||||
key.x = (unsigned int)seed;
|
||||
key.y = (unsigned int)(seed >> 32);
|
||||
counter = make_uint4(0, 0, 0, 0);
|
||||
counter.z = (unsigned int)(subsequence);
|
||||
counter.w = (unsigned int)(subsequence >> 32);
|
||||
STATE = 0;
|
||||
incr_n(offset / 4);
|
||||
}
|
||||
|
||||
__device__ inline unsigned long operator()() {
|
||||
if(STATE == 0) {
|
||||
uint4 counter_ = counter;
|
||||
uint2 key_ = key;
|
||||
for(int i = 0; i < 9; i++) {
|
||||
counter_ = single_round(counter_, key_);
|
||||
key_.x += (kPhilox10A); key_.y += (kPhilox10B);
|
||||
}
|
||||
output = single_round(counter_, key_);
|
||||
incr();
|
||||
}
|
||||
unsigned long ret;
|
||||
switch(STATE) {
|
||||
case 0: ret = output.x; break;
|
||||
case 1: ret = output.y; break;
|
||||
case 2: ret = output.z; break;
|
||||
case 3: ret = output.w; break;
|
||||
}
|
||||
STATE = (STATE + 1) % 4;
|
||||
return ret;
|
||||
}
|
||||
|
||||
private:
|
||||
uint4 counter;
|
||||
uint4 output;
|
||||
uint2 key;
|
||||
unsigned int STATE;
|
||||
__device__ inline void incr_n(unsigned long long n) {
|
||||
unsigned int nlo = (unsigned int)(n);
|
||||
unsigned int nhi = (unsigned int)(n >> 32);
|
||||
counter.x += nlo;
|
||||
if (counter.x < nlo)
|
||||
nhi++;
|
||||
counter.y += nhi;
|
||||
if (nhi <= counter.y)
|
||||
return;
|
||||
if (++counter.z)
|
||||
return;
|
||||
++counter.w;
|
||||
}
|
||||
__device__ inline void incr() {
|
||||
if (++counter.x)
|
||||
return;
|
||||
if (++counter.y)
|
||||
return;
|
||||
if (++counter.z)
|
||||
return;
|
||||
++counter.w;
|
||||
}
|
||||
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
|
||||
unsigned int *result_high) {
|
||||
*result_high = __umulhi(a, b);
|
||||
return a*b;
|
||||
}
|
||||
|
||||
__device__ inline uint4 single_round(uint4 ctr, uint2 key) {
|
||||
unsigned int hi0;
|
||||
unsigned int hi1;
|
||||
unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
|
||||
unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
|
||||
|
||||
uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
|
||||
return ret;
|
||||
}
|
||||
|
||||
static const unsigned long kPhilox10A = 0x9E3779B9;
|
||||
static const unsigned long kPhilox10B = 0xBB67AE85;
|
||||
static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||
static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
};
|
||||
|
||||
// Inverse of 2^32.
|
||||
#define M_RAN_INVM32 2.3283064e-10f
|
||||
__device__ __inline__ float Uint32ToFloat(unsigned int x) {
|
||||
return x * M_RAN_INVM32;
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
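
To show how this header is consumed: when the statement contains a kRand intrinsic, CudaCodeGen::Initialize pastes philox_random_string into the generated source, appends the rand_seed/rand_offset kernel arguments, and instantiates one Philox generator per thread. The snippet below is a hand-written illustration of what a generated kernel for the rand test roughly looks like; the function and generator names are placeholders, not the exact identifiers the codegen emits.

// Illustrative only: approximate shape of a generated kernel that uses kRand
// (num_iter = 3, block_count = 16, block_size = 128, as in testCudaTestRand01).
extern "C" __global__
void func_1(float* c, unsigned long long rand_seed, unsigned long long rand_offset) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  Philox rand_1(rand_seed, idx, rand_offset);
  for (int n = 0; n < 3; n++) {
    c[(n * 16 + blockIdx.x) * 128 + threadIdx.x] = Uint32ToFloat(rand_1());
  }
}
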
@ -5,6 +5,30 @@
using namespace torch::jit;
using namespace torch::jit::tensorexpr;

namespace torch {
namespace jit {
namespace tensorexpr {

static int te_cuda_pointwise_loop_levels = -1;
static int te_cuda_pointwise_block_count = -1;
static int te_cuda_pointwise_block_size = -1;

int& GetTECudaPointwiseLoopLevels() {
  return te_cuda_pointwise_loop_levels;
}

int& GetTECudaPointwiseBlockCount() {
  return te_cuda_pointwise_block_count;
}

int& GetTECudaPointwiseBlockSize() {
  return te_cuda_pointwise_block_size;
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch

static at::ScalarType tensorType(Tensor* t) {
  return static_cast<at::ScalarType>(t->body()->dtype().scalar_type());
}
@ -883,12 +907,96 @@ Tensor* TensorExprKernel::ComputeValue(const torch::jit::Value* v) {
|
||||
void TensorExprKernel::LowerToBackend(BackendType backend_type) {
|
||||
std::vector<Tensor*> tensor_outputs(tensor_outputs_);
|
||||
|
||||
if (backend_type == BackendType::kCudaCodeGen) {
|
||||
for (size_t tensor_idx = 0; tensor_idx < tensor_outputs_.size();
|
||||
tensor_idx++) {
|
||||
Tensor* tensor = tensor_outputs_[tensor_idx];
|
||||
ExprHandle total_count = ExprHandle(tensor->dim(0));
|
||||
for (int i = 1; i < tensor->ndim(); i++) {
|
||||
const IntImm* total_count_i = total_count.AsNode<IntImm>();
|
||||
const IntImm* tensor_dim_i =
|
||||
dynamic_cast<const IntImm*>(tensor->dim(i));
|
||||
if (total_count_i && tensor_dim_i) {
|
||||
// TODO: switch to real constant folding when it is available.
|
||||
total_count =
|
||||
ExprHandle(total_count_i->value() * tensor_dim_i->value());
|
||||
} else {
|
||||
total_count = total_count * ExprHandle(tensor->dim(i));
|
||||
}
|
||||
}
|
||||
// Flatten the index for GPU kernels.
|
||||
// TODO: move this to fusing axis when it is ready.
|
||||
Tensor* new_out = Compute(
|
||||
tensor->func_var()->name_hint() + "_flat",
|
||||
{total_count},
|
||||
[tensor](const VarHandle& index) -> ExprHandle {
|
||||
std::vector<ExprHandle> dims;
|
||||
ExprHandle value = index;
|
||||
for (int i = tensor->ndim() - 1; i >= 0; i--) {
|
||||
ExprHandle idx = value;
|
||||
if (i > 0) {
|
||||
idx = Mod::make(value, ExprHandle(tensor->dim(i)));
|
||||
}
|
||||
dims.push_back(idx);
|
||||
value = value / ExprHandle(tensor->dim(i));
|
||||
}
|
||||
std::reverse(dims.begin(), dims.end());
|
||||
return tensor->call(dims);
|
||||
});
|
||||
tensor_outputs[tensor_idx] = new_out;
|
||||
}
|
||||
}
|
||||
|
||||
torch::jit::tensorexpr::schedule::LoopNest l(tensor_outputs);
|
||||
|
||||
// Compute non-output tensors_ inline
|
||||
for (auto& p : tensors_) {
|
||||
l.ComputeInline(l.getLoopBodyFor(p.second));
|
||||
}
|
||||
if (backend_type == kCudaCodeGen) {
|
||||
for (size_t i = 0; i < tensor_outputs_.size(); i++) {
|
||||
l.ComputeInline(l.getLoopBodyFor(tensor_outputs_[i]));
|
||||
|
||||
Tensor* tensor = tensor_outputs[i];
|
||||
const Var* index = tensor->arg(0);
|
||||
int loop_levels = GetTECudaPointwiseLoopLevels();
|
||||
const int kDefaultLoopLevels = 2;
|
||||
loop_levels = (loop_levels > 0) ? loop_levels : kDefaultLoopLevels;
|
||||
int block_count = GetTECudaPointwiseBlockCount();
|
||||
int block_size = GetTECudaPointwiseBlockSize();
|
||||
|
||||
if (loop_levels == 2) {
|
||||
Stmt* outer;
|
||||
Stmt* inner;
|
||||
const int kDefaultBlockSize = 512;
|
||||
if (block_size < 0) {
|
||||
block_size = kDefaultBlockSize;
|
||||
}
|
||||
std::vector<Stmt*> loops = l.getLoopStmtsFor(tensor);
|
||||
l.SplitWithMask(loops[0], block_size, &outer, &inner);
|
||||
l.SetGPUBlockIndex(outer, 0);
|
||||
l.SetGPUThreadIndex(inner, 0);
|
||||
} else if (loop_levels == 3) {
|
||||
Stmt* outer;
|
||||
Stmt* inner;
|
||||
Stmt* inner_1;
|
||||
Stmt* inner_2;
|
||||
// TODO: change the number of microprocessors
|
||||
const int kDefaultBlockCount = 1280;
|
||||
const int kDefaultBlockSize = 256;
|
||||
block_count = (block_count > 0) ? block_count : kDefaultBlockCount;
|
||||
block_size = (block_size > 0) ? block_size : kDefaultBlockSize;
|
||||
std::vector<Stmt*> loops = l.getLoopStmtsFor(tensor);
|
||||
l.SplitWithMask(loops[0], block_count * block_size, &outer, &inner);
|
||||
l.SplitWithMask(inner, block_size, &inner_1, &inner_2);
|
||||
l.SetGPUBlockIndex(inner_1, 0);
|
||||
l.SetGPUThreadIndex(inner_2, 0);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Invalid loop-level: " + std::to_string(loop_levels));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
l.ApplyInlines();
|
||||
Stmt* stmt = l.root_stmt();
|
||||
@ -911,6 +1019,9 @@ void TensorExprKernel::LowerToBackend(BackendType backend_type) {
  // Generate code.
  std::string codegen_name;
  switch (backend_type_) {
    case kCudaCodeGen:
      codegen_name = "cuda_codegen";
      break;
    case kSimpleIREval:
      codegen_name = "simple_ir_eval";
      break;
@ -933,7 +1044,9 @@ void TensorExprKernel::PickAndCheckBackendType(
    throw std::runtime_error("No tensor inputs");
  }();
  BackendType backend_type = BackendType::kUninitialized;
  if (device.type() == at::kCPU) {
  if (device.type() == at::kCUDA) {
    backend_type = kCudaCodeGen;
  } else if (device.type() == at::kCPU) {
    backend_type = kSimpleIREval;
  } else {
    throw std::runtime_error("Invalid device type");
@ -956,6 +1069,7 @@ void TensorExprKernel::CodeGenRun(
    const std::vector<CodeGen::CallArg>& run_args) {
  switch (backend_type_) {
    case kSimpleIREval:
    case kCudaCodeGen:
      codegen_->call(run_args);
      break;
    default:
@ -52,6 +52,7 @@ class TensorExprKernel {
  enum BackendType {
    kUninitialized,
    kSimpleIREval,
    kCudaCodeGen,
  };

  ExprHandle constant(const torch::jit::Value* v);
@ -205,6 +206,10 @@ class TensorExprKernel {
  at::Device device_ = at::kCPU;
};

TORCH_API int& GetTECudaPointwiseLoopLevels();
TORCH_API int& GetTECudaPointwiseBlockCount();
TORCH_API int& GetTECudaPointwiseBlockSize();

} // namespace tensorexpr
} // namespace jit
} // namespace torch