mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Follows #132209 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132411 Approved by: https://github.com/Skylion007
147 lines
4.2 KiB
C++
147 lines
4.2 KiB
C++
#pragma once
|
|
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
|
|
#include <ATen/ATen.h>
|
|
#include <torch/csrc/jit/resource_guard.h>
|
|
#include <torch/csrc/jit/tensorexpr/analysis.h>
|
|
#include <torch/csrc/jit/tensorexpr/codegen.h>
|
|
#include <torch/csrc/jit/tensorexpr/ir.h>
|
|
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
|
|
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
|
|
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>
|
|
|
|
namespace torch::jit::tensorexpr {
|
|
|
|
// A class that analyzes the given program relevant for Block backend.
|
|
class BlockAnalysis : public IRVisitor {
|
|
public:
|
|
bool is_buf_store_target(const BufPtr& buf) const {
|
|
return store_targets_.count(buf) > 0;
|
|
}
|
|
|
|
const std::unordered_set<BufPtr>& loads() const {
|
|
return loads_;
|
|
}
|
|
|
|
const std::unordered_set<BufPtr>& stores() const {
|
|
return store_targets_;
|
|
}
|
|
|
|
int64_t block_size() const {
|
|
return block_size_;
|
|
}
|
|
|
|
bool areBufsInMap(const std::unordered_set<BufPtr>& bufs) const;
|
|
|
|
BufPtr getMultiDimBuf(const BufPtr& buf) const;
|
|
|
|
std::string getInputName(const BufPtr& buf) const;
|
|
|
|
std::string getFlatInputName(const BufPtr& buf) const {
|
|
return getInputName(buf) + "_flat";
|
|
}
|
|
|
|
std::unordered_map<std::string, BufPtr> getBufferMap() const {
|
|
return map_input_to_tensor_bufs_;
|
|
}
|
|
|
|
private:
|
|
void visit(const StorePtr& v) override;
|
|
void visit(const LoadPtr& v) override;
|
|
void visit(const ForPtr& v) override;
|
|
|
|
std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
|
|
std::unordered_set<BufPtr> store_targets_;
|
|
std::unordered_set<BufPtr> loads_;
|
|
int64_t block_size_ = 32;
|
|
};
|
|
|
|
// A class that overrides the underlying IRPrinter to produce Block.
|
|
class BlockPrinter : public IRPrinter {
|
|
public:
|
|
BlockPrinter(std::ostream* os, BlockAnalysis* block_analysis)
|
|
: IRPrinter(*os), block_analysis_(block_analysis) {}
|
|
|
|
using IRPrinter::name_manager;
|
|
using IRPrinter::visit;
|
|
|
|
private:
|
|
BlockAnalysis* block_analysis_;
|
|
std::unordered_map<std::string, int> dim_values_map;
|
|
std::vector<std::string> dim_names = {"N", "H", "W", "C"};
|
|
std::vector<std::string> flat_dim_names = {"N", "NH", "NHW", "NHWC"};
|
|
void PrintTensorInfo(const std::unordered_set<BufPtr>& bufs);
|
|
void PrintArguments(const std::unordered_set<BufPtr>& bufs);
|
|
void PrintBufferInfo(const std::unordered_set<BufPtr>& bufs);
|
|
void PrintDistribution(const std::unordered_set<BufPtr>& bufs);
|
|
void PrintLoop(const std::unordered_set<BufPtr>& bufs, bool block_idx = true);
|
|
void PrintReshapeInfo(
|
|
const std::unordered_set<BufPtr>& bufs,
|
|
bool reverse = false);
|
|
void PrintDMAs(const std::unordered_set<BufPtr>& bufs);
|
|
void PrintAdjustBuffers(const std::unordered_set<BufPtr>& bufs);
|
|
|
|
void visit(const ForPtr& v) override;
|
|
void visit(const LoadPtr& v) override;
|
|
void visit(const StorePtr& v) override;
|
|
void visit(const BlockPtr& v) override;
|
|
void visit(const AddPtr& v) override;
|
|
void visit(const MulPtr& v) override;
|
|
};
|
|
|
|
// CodeGen implementation for the Block backend. Generated text is accumulated
// in an internal string stream and exposed through getCodeText().
class TORCH_API BlockCodeGen : public CodeGen {
 public:
  // Convenience constructor: wraps each of `ts` in a BufferArg and targets the
  // CPU device. `stmt` is taken by value and moved into the base class, the
  // same way the constructor below does, to avoid an extra pointer copy.
  template <typename... Ts>
  /* implicit */
  BlockCodeGen(StmtPtr stmt, Ts... ts)
      : CodeGen(
            std::move(stmt),
            std::vector<BufferArg>({BufferArg(ts)...}),
            at::Device(at::kCPU)) {
    Initialize();
  }

  // Full constructor. `device` defaults to CPU and `kernel_func_name` to
  // "func"; everything is forwarded to CodeGen before Initialize() runs.
  BlockCodeGen(
      StmtPtr stmt,
      const std::vector<BufferArg>& buffer_args,
      at::Device device = at::Device(at::kCPU),
      const std::string& kernel_func_name = "func")
      : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) {
    Initialize();
  }

  ~BlockCodeGen() override;

  // Defined out of line.
  void call(const std::vector<CallArg>& args) override;
  void call_raw(const std::vector<void*>& args) override;

  // Defined out of line; invoked by both constructors.
  // NOTE(review): presumably creates printer_ / block_analysis_ - confirm
  // against the definition.
  void Initialize();

  // Returns the text accumulated in the internal stream. `attr` is ignored.
  std::string getCodeText(const std::string& attr = "") override {
    return oss_.str();
  }

 private:
  // Name manager of the printer. Throws std::runtime_error when the printer
  // has not been created.
  UniqueNameManager* name_manager() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return printer_->name_manager();
  }

  // Output stream of the printer. Guarded the same way as name_manager() so a
  // missing printer raises an exception instead of dereferencing null.
  std::ostream& os() {
    if (!printer_) {
      throw std::runtime_error("Null IRPrinter is not expected");
    }
    return printer_->os();
  }

  std::ostringstream oss_;
  std::unique_ptr<BlockPrinter> printer_;
  std::unique_ptr<BlockAnalysis> block_analysis_;

  // Defined out of line.
  std::string GetUniqueFuncName(const std::string& func_prefix);
};
|
|
} // namespace torch::jit::tensorexpr
|