mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Re-land https://github.com/pytorch/pytorch/pull/76711 by fixing internal build errors. Generate class-level opkind as a static method instead of a static member. Pull Request resolved: https://github.com/pytorch/pytorch/pull/77102 Approved by: https://github.com/wconstab, https://github.com/JackCaoG, https://github.com/antoniojkim
41 lines
896 B
C++
41 lines
896 B
C++
#include <torch/csrc/lazy/ts_backend/ops/scalar.h>
|
|
|
|
#include <functional>
|
|
#include <sstream>
|
|
|
|
#include <ATen/core/Formatting.h>
|
|
|
|
namespace torch {
|
|
namespace lazy {
|
|
|
|
using at::operator<<;
|
|
|
|
Scalar::Scalar(const at::Scalar& value, Shape shape)
|
|
: TsNode(
|
|
ClassOpKind(),
|
|
std::move(shape),
|
|
/*num_outputs=*/1,
|
|
ScalarHash(value)),
|
|
value_(value) {}
|
|
|
|
Scalar::Scalar(const at::Scalar& value, c10::ScalarType type)
|
|
: TsNode(
|
|
ClassOpKind(),
|
|
{Shape(type, {})},
|
|
/*num_outputs=*/1,
|
|
ScalarHash(value)),
|
|
value_(value) {}
|
|
|
|
// Returns a human-readable description of this node: the base TsNode text
// followed by ", value=" and the stored scalar, formatted with
// at::operator<< (brought into scope by the using-declaration above).
std::string Scalar::ToString() const {
  std::ostringstream out;
  out << TsNode::ToString();
  out << ", value=";
  out << value_;
  return out.str();
}
|
|
|
|
// Computes a hash of a scalar's numeric payload. The Scalar node
// constructors use it as the hash seed passed to TsNode.
//
// - Complex values hash both the real and imaginary parts. Without this
//   branch they would fall through to toLong(), because complex scalars are
//   not "floating point" for at::Scalar. That checked conversion rejects a
//   nonzero imaginary part and truncates the real part, so e.g. 1.5+0j and
//   1.0+0j would receive the same seed.
// - Floating-point values hash their double value.
// - All other values (integral, bool) hash their int64 value.
//
// The element type is not mixed in here. NOTE(review): presumably type
// information reaches the node hash via the Shape given to TsNode — confirm.
//
// @param s  scalar to hash
// @return   hash of the scalar's value
hash_t ScalarHash(const at::Scalar& s) {
  if (s.isComplex()) {
    const auto z = s.toComplexDouble();
    return HashCombine(Hash(z.real()), Hash(z.imag()));
  }
  return s.isFloatingPoint() ? Hash(s.toDouble()) : Hash(s.toLong());
}
|
|
|
|
} // namespace lazy
|
|
} // namespace torch
|