Files
pytorch/torch/csrc/lazy/ts_backend/dynamic_ir.cpp
PyTorch MergeBot dbb55b448b Revert "[7/N] Fix Wextra-semi warning (#140225)"
This reverts commit ffb979032dc149b4c895526fe5b92d713ed7b1e1.

Reverted https://github.com/pytorch/pytorch/pull/140225 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/140225#issuecomment-2469312229))
2024-11-12 00:02:06 +00:00

117 lines
3.1 KiB
C++

#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
#include <utility>
// Downcast helper: view the node producing `output` as a DimensionNode.
// Yields nullptr when the producing node is not dimension-valued.
static const torch::lazy::DimensionNode* DimCast(torch::lazy::Output output) {
  const auto* producer = output.node;
  return dynamic_cast<const torch::lazy::DimensionNode*>(producer);
}
namespace torch::lazy {
TSOpVector SizeNode::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    TSLoweringContext* loctx) const {
  // Lower this node as a call to the aten::size builtin:
  // size(tensor, dim) -> int.
  std::vector<torch::jit::NamedValue> args;
  std::vector<torch::jit::NamedValue> kwargs;
  args.reserve(2);
  // The dimension index is baked into the graph as a constant.
  auto dim_const =
      loctx->graph()->insertConstant(static_cast<int64_t>(this->dim_));
  args.emplace_back(loctx->GetOutputOp(operand(0)));
  args.emplace_back(dim_const);
  auto lowered =
      torch::lazy::LowerTSBuiltin(function, op().op, args, kwargs);
  // aten::size with an explicit dim produces exactly one output.
  TORCH_CHECK_EQ(lowered.size(), 1);
  return lowered;
}
// SizeNode: IR node for aten::size(input, dim) — the (possibly symbolic)
// extent of dimension `dim` of `input`. `dim` participates in the node hash
// so nodes querying different dimensions are distinct.
// NOTE: the trailing `;` after the constructor body was an empty declaration
// (-Wextra-semi); it is removed here.
SizeNode::SizeNode(Value input, size_t dim)
    : TsNode(
          OpKind{c10::Symbol::fromQualString("aten::size")},
          {std::move(input)},
          std::vector<Shape>{},
          1,
          MHash(dim)),
      dim_(dim) {}
// Concrete extent of dimension dim_, read from the operand's inferred shape.
int64_t SizeNode::getStaticValue() const {
  const auto* producer = dynamic_cast<const TsNode*>(operand(0).node);
  return producer->shape(0).size(static_cast<int64_t>(dim_));
}
// Whether dimension dim_ of the operand is symbolic (not statically known).
bool SizeNode::isSymbolic() const {
  const auto* producer = dynamic_cast<const TsNode*>(operand(0).node);
  auto per_dim = producer->shape(0).is_symbolic();
  // No per-dimension information: conservatively report symbolic.
  return per_dim.has_value() ? per_dim->at(dim_) : true;
}
// Human-readable label used in IR dumps.
std::string SizeNode::ToString() const {
  return std::string{"SizeNode"};
}
// SizeAdd: IR node for the sum of two dimension values (aten::add).
// NOTE: the trailing `;` after the constructor body was an empty declaration
// (-Wextra-semi); it is removed here.
SizeAdd::SizeAdd(Value a, Value b)
    : TsNode(
          OpKind{c10::Symbol::fromQualString("aten::add")},
          {std::move(a), std::move(b)},
          std::vector<Shape>{},
          1) {}
// Static value of the sum: lhs + rhs.
int64_t SizeAdd::getStaticValue() const {
  const auto* lhs = DimCast(operand(0));
  const auto* rhs = DimCast(operand(1));
  return lhs->getStaticValue() + rhs->getStaticValue();
}
// The sum is symbolic when either input dimension is symbolic.
bool SizeAdd::isSymbolic() const {
  const auto* lhs = DimCast(operand(0));
  const auto* rhs = DimCast(operand(1));
  return lhs->isSymbolic() || rhs->isSymbolic();
}
// Human-readable label used in IR dumps.
std::string SizeAdd::ToString() const {
  return std::string{"SizeAdd"};
}
// SizeMul: IR node for the product of two dimension values (aten::mul).
// NOTE: the trailing `;` after the constructor body was an empty declaration
// (-Wextra-semi); it is removed here.
SizeMul::SizeMul(Value a, Value b)
    : TsNode(
          OpKind{c10::Symbol::fromQualString("aten::mul")},
          {std::move(a), std::move(b)},
          std::vector<Shape>{},
          1) {}
// Static value of the product: lhs * rhs.
int64_t SizeMul::getStaticValue() const {
  const auto* lhs = DimCast(operand(0));
  const auto* rhs = DimCast(operand(1));
  return lhs->getStaticValue() * rhs->getStaticValue();
}
// The product is symbolic when either input dimension is symbolic.
bool SizeMul::isSymbolic() const {
  const auto* lhs = DimCast(operand(0));
  const auto* rhs = DimCast(operand(1));
  return lhs->isSymbolic() || rhs->isSymbolic();
}
// Human-readable label used in IR dumps.
std::string SizeMul::ToString() const {
  return std::string{"SizeMul"};
}
// SizeDiv: IR node for the quotient of two dimension values (aten::div).
// NOTE: the trailing `;` after the constructor body was an empty declaration
// (-Wextra-semi); it is removed here.
SizeDiv::SizeDiv(Value a, Value b)
    : TsNode(
          OpKind{c10::Symbol::fromQualString("aten::div")},
          {std::move(a), std::move(b)},
          std::vector<Shape>{},
          1) {}
// Static value of the quotient: dividend / divisor.
// Raises (via TORCH_CHECK) when the divisor dimension is zero.
int64_t SizeDiv::getStaticValue() const {
  const auto* dividend = DimCast(operand(0));
  const auto* divisor = DimCast(operand(1));
  TORCH_CHECK(
      divisor->getStaticValue() != 0,
      "Can't divide a dimension by zero");
  return dividend->getStaticValue() / divisor->getStaticValue();
}
// The quotient is symbolic when either input dimension is symbolic.
bool SizeDiv::isSymbolic() const {
  const auto* lhs = DimCast(operand(0));
  const auto* rhs = DimCast(operand(1));
  return lhs->isSymbolic() || rhs->isSymbolic();
}
// Human-readable label used in IR dumps.
std::string SizeDiv::ToString() const {
  return std::string{"SizeDiv"};
}
} // namespace torch::lazy