mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[custom_op] support default dtype values (#129189)
This PR: - moves some of the dtype-string utilities into ScalarType.{h, cpp} - adds a new utility to get a mapping from dtype name to the C++ dtype - the perser now checks if the string is a dtype name; if it is then it pulls the c++ dtype from the mapping. Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/129189 Approved by: https://github.com/albanD ghstack dependencies: #129177, #129178, #129179
This commit is contained in:
@ -128,4 +128,112 @@ ScalarType promoteTypes(ScalarType a, ScalarType b) {
|
||||
return _promoteTypesLookup[ix_a][ix_b];
|
||||
}
|
||||
|
||||
std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
|
||||
switch (scalarType) {
|
||||
case c10::ScalarType::UInt1:
|
||||
return std::make_pair("uint1", "bit");
|
||||
case c10::ScalarType::UInt2:
|
||||
return std::make_pair("uint2", "");
|
||||
case c10::ScalarType::UInt3:
|
||||
return std::make_pair("uint3", "");
|
||||
case c10::ScalarType::UInt4:
|
||||
return std::make_pair("uint4", "");
|
||||
case c10::ScalarType::UInt5:
|
||||
return std::make_pair("uint5", "");
|
||||
case c10::ScalarType::UInt6:
|
||||
return std::make_pair("uint6", "");
|
||||
case c10::ScalarType::UInt7:
|
||||
return std::make_pair("uint7", "");
|
||||
case c10::ScalarType::Byte:
|
||||
// no "byte" because byte is signed in numpy and we overload
|
||||
// byte to mean bool often
|
||||
return std::make_pair("uint8", "");
|
||||
case c10::ScalarType::UInt16:
|
||||
return std::make_pair("uint16", "");
|
||||
case c10::ScalarType::UInt32:
|
||||
return std::make_pair("uint32", "");
|
||||
case c10::ScalarType::UInt64:
|
||||
return std::make_pair("uint64", "");
|
||||
case c10::ScalarType::Char:
|
||||
// no "char" because it is not consistently signed or unsigned; we want
|
||||
// to move to int8
|
||||
return std::make_pair("int8", "");
|
||||
case c10::ScalarType::Double:
|
||||
return std::make_pair("float64", "double");
|
||||
case c10::ScalarType::Float:
|
||||
return std::make_pair("float32", "float");
|
||||
case c10::ScalarType::Int:
|
||||
return std::make_pair("int32", "int");
|
||||
case c10::ScalarType::Long:
|
||||
return std::make_pair("int64", "long");
|
||||
case c10::ScalarType::Short:
|
||||
return std::make_pair("int16", "short");
|
||||
case c10::ScalarType::Half:
|
||||
return std::make_pair("float16", "half");
|
||||
case c10::ScalarType::ComplexHalf:
|
||||
return std::make_pair("complex32", "chalf");
|
||||
case c10::ScalarType::ComplexFloat:
|
||||
return std::make_pair("complex64", "cfloat");
|
||||
case c10::ScalarType::ComplexDouble:
|
||||
return std::make_pair("complex128", "cdouble");
|
||||
case c10::ScalarType::Bool:
|
||||
return std::make_pair("bool", "");
|
||||
case c10::ScalarType::QInt8:
|
||||
return std::make_pair("qint8", "");
|
||||
case c10::ScalarType::QUInt8:
|
||||
return std::make_pair("quint8", "");
|
||||
case c10::ScalarType::QInt32:
|
||||
return std::make_pair("qint32", "");
|
||||
case c10::ScalarType::BFloat16:
|
||||
return std::make_pair("bfloat16", "");
|
||||
case c10::ScalarType::QUInt4x2:
|
||||
return std::make_pair("quint4x2", "");
|
||||
case c10::ScalarType::QUInt2x4:
|
||||
return std::make_pair("quint2x4", "");
|
||||
case c10::ScalarType::Bits1x8:
|
||||
return std::make_pair("bits1x8", "");
|
||||
case c10::ScalarType::Bits2x4:
|
||||
return std::make_pair("bits2x4", "");
|
||||
case c10::ScalarType::Bits4x2:
|
||||
return std::make_pair("bits4x2", "");
|
||||
case c10::ScalarType::Bits8:
|
||||
return std::make_pair("bits8", "");
|
||||
case c10::ScalarType::Bits16:
|
||||
return std::make_pair("bits16", "");
|
||||
case c10::ScalarType::Float8_e5m2:
|
||||
return std::make_pair("float8_e5m2", "");
|
||||
case c10::ScalarType::Float8_e4m3fn:
|
||||
return std::make_pair("float8_e4m3fn", "");
|
||||
case c10::ScalarType::Float8_e5m2fnuz:
|
||||
return std::make_pair("float8_e5m2fnuz", "");
|
||||
case c10::ScalarType::Float8_e4m3fnuz:
|
||||
return std::make_pair("float8_e4m3fnuz", "");
|
||||
default:
|
||||
throw std::runtime_error("Unimplemented scalar type");
|
||||
}
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, ScalarType>& getStringToDtypeMap() {
|
||||
static std::unordered_map<std::string, ScalarType> result;
|
||||
if (!result.empty()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
#define DEFINE_SCALAR_TYPE(_1, n) c10::ScalarType::n,
|
||||
|
||||
auto all_scalar_types = {
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
|
||||
|
||||
#undef DEFINE_SCALAR_TYPE
|
||||
|
||||
for (auto scalar_type : all_scalar_types) {
|
||||
auto names = getDtypeNames(scalar_type);
|
||||
result[std::get<0>(names)] = scalar_type;
|
||||
if (!std::get<1>(names).empty()) {
|
||||
result[std::get<1>(names)] = scalar_type;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include <limits>
|
||||
#include <ostream>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
@ -561,4 +562,12 @@ inline std::ostream& operator<<(
|
||||
return stream << toString(scalar_type);
|
||||
}
|
||||
|
||||
// Returns a pair of strings representing the names for each dtype.
|
||||
// The returned pair is (name, legacy_name_if_applicable)
|
||||
C10_API std::pair<std::string, std::string> getDtypeNames(
|
||||
c10::ScalarType scalarType);
|
||||
|
||||
// Returns a map of string name to dtype.
|
||||
C10_API const std::unordered_map<std::string, ScalarType>& getStringToDtypeMap();
|
||||
|
||||
} // namespace c10
|
||||
|
@ -2412,13 +2412,19 @@ class TestCustomOpAPI(TestCase):
|
||||
c: bool = True,
|
||||
d: int = 3,
|
||||
e: str = "foo",
|
||||
f: torch.dtype = torch.float,
|
||||
g: torch.dtype = torch.float32,
|
||||
h: torch.dtype = torch.int,
|
||||
) -> Tensor:
|
||||
defaults.extend([a, b, c, d, e])
|
||||
defaults.extend([a, b, c, d, e, f, g, h])
|
||||
return x.clone()
|
||||
|
||||
x = torch.randn(3)
|
||||
f(x)
|
||||
self.assertEqual(defaults, [None, 3.14, True, 3, "foo"])
|
||||
self.assertEqual(
|
||||
defaults,
|
||||
[None, 3.14, True, 3, "foo", torch.float, torch.float32, torch.int],
|
||||
)
|
||||
|
||||
def test_mutated_error(self):
|
||||
with self.assertRaisesRegex(
|
||||
|
@ -79,6 +79,11 @@ def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
|
||||
default_repr = str(param.default)
|
||||
elif isinstance(param.default, str):
|
||||
default_repr = f'"{param.default}"'
|
||||
elif isinstance(param.default, torch.dtype):
|
||||
dtype_repr = str(param.default)
|
||||
torch_dot = "torch."
|
||||
assert dtype_repr.startswith(torch_dot)
|
||||
default_repr = dtype_repr[len(torch_dot) :]
|
||||
else:
|
||||
error_fn(
|
||||
f"Parameter {name} has an unsupported default value type {type(param.default)}. "
|
||||
|
@ -194,7 +194,7 @@ static PyObject* THPIInfo_min(THPIInfo* self, void*) {
|
||||
|
||||
static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
|
||||
HANDLE_TH_ERRORS
|
||||
auto primary_name = torch::utils::getDtypeNames(self->type).first;
|
||||
auto primary_name = c10::getDtypeNames(self->type).first;
|
||||
return AT_DISPATCH_IINFO_TYPES(self->type, "dtype", [&primary_name] {
|
||||
return PyUnicode_FromString(primary_name.data());
|
||||
});
|
||||
@ -227,7 +227,7 @@ static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
|
||||
|
||||
static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
|
||||
HANDLE_TH_ERRORS
|
||||
auto primary_name = torch::utils::getDtypeNames(self->type).first;
|
||||
auto primary_name = c10::getDtypeNames(self->type).first;
|
||||
return _AT_DISPATCH_FINFO_TYPES(self->type, "dtype", [&primary_name] {
|
||||
return PyUnicode_FromString(primary_name.data());
|
||||
});
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <torch/csrc/jit/frontend/function_schema_parser.h>
|
||||
|
||||
#include <ATen/core/Reduction.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <torch/csrc/jit/frontend/lexer.h>
|
||||
@ -185,7 +186,8 @@ struct SchemaParser {
|
||||
name = L.expect(TK_IDENT).text();
|
||||
if (L.nextIf('=')) {
|
||||
// NB: this means we have to unswizzle default too
|
||||
default_value = parseDefaultValue(*fake_type, fake_type->kind(), N);
|
||||
default_value =
|
||||
parseDefaultValue(*fake_type, fake_type->kind(), *real_type, N);
|
||||
}
|
||||
}
|
||||
return Argument(
|
||||
@ -197,11 +199,29 @@ struct SchemaParser {
|
||||
!is_return && kwarg_only,
|
||||
std::move(alias_info));
|
||||
}
|
||||
IValue parseSingleConstant(const c10::Type& type, TypeKind kind) {
|
||||
|
||||
bool isPossiblyOptionalScalarType(const c10::Type& type) {
|
||||
if (type.kind() == at::ScalarTypeType::Kind) {
|
||||
return true;
|
||||
}
|
||||
if (type.kind() == at::OptionalType::Kind) {
|
||||
for (const auto& inner : type.containedTypes()) {
|
||||
if (isPossiblyOptionalScalarType(*inner))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
IValue parseSingleConstant(
|
||||
const c10::Type& type,
|
||||
TypeKind kind,
|
||||
const c10::Type& real_type) {
|
||||
if (kind == c10::TypeKind::DynamicType) {
|
||||
return parseSingleConstant(
|
||||
type, type.expectRef<c10::DynamicType>().dynamicKind());
|
||||
type, type.expectRef<c10::DynamicType>().dynamicKind(), real_type);
|
||||
}
|
||||
const auto& str2dtype = c10::getStringToDtypeMap();
|
||||
switch (L.cur().kind) {
|
||||
case TK_TRUE:
|
||||
L.next();
|
||||
@ -219,6 +239,9 @@ struct SchemaParser {
|
||||
case TK_IDENT: {
|
||||
auto tok = L.next();
|
||||
auto text = tok.text();
|
||||
// NB: float/complex/long are here for BC purposes. Other dtypes
|
||||
// are handled via str2dtype.
|
||||
// Please don't add more cases to this if-else block.
|
||||
if ("float" == text) {
|
||||
return static_cast<int64_t>(at::kFloat);
|
||||
} else if ("complex" == text) {
|
||||
@ -231,6 +254,10 @@ struct SchemaParser {
|
||||
return static_cast<int64_t>(at::Reduction::Mean);
|
||||
} else if ("contiguous_format" == text) {
|
||||
return static_cast<int64_t>(c10::MemoryFormat::Contiguous);
|
||||
} else if (
|
||||
isPossiblyOptionalScalarType(real_type) &&
|
||||
str2dtype.count(text) > 0) {
|
||||
return static_cast<int64_t>(str2dtype.at(text));
|
||||
} else {
|
||||
throw ErrorReport(L.cur().range) << "invalid numeric default value";
|
||||
}
|
||||
@ -277,12 +304,15 @@ struct SchemaParser {
|
||||
<< "lists are only supported for float, int and complex types";
|
||||
}
|
||||
}
|
||||
IValue parseConstantList(const c10::Type& type, TypeKind kind) {
|
||||
IValue parseConstantList(
|
||||
const c10::Type& type,
|
||||
TypeKind kind,
|
||||
const c10::Type& real_type) {
|
||||
auto tok = L.expect('[');
|
||||
std::vector<IValue> vs;
|
||||
if (L.cur().kind != ']') {
|
||||
do {
|
||||
vs.push_back(parseSingleConstant(type, kind));
|
||||
vs.push_back(parseSingleConstant(type, kind, real_type));
|
||||
} while (L.nextIf(','));
|
||||
}
|
||||
L.expect(']');
|
||||
@ -296,6 +326,7 @@ struct SchemaParser {
|
||||
IValue parseDefaultValue(
|
||||
const c10::Type& arg_type,
|
||||
TypeKind kind,
|
||||
const c10::Type& real_type,
|
||||
std::optional<int32_t> arg_N) {
|
||||
auto range = L.cur().range;
|
||||
switch (kind) {
|
||||
@ -311,7 +342,7 @@ struct SchemaParser {
|
||||
case TypeKind::BoolType:
|
||||
case TypeKind::FloatType:
|
||||
case TypeKind::ComplexType:
|
||||
return parseSingleConstant(arg_type, kind);
|
||||
return parseSingleConstant(arg_type, kind, real_type);
|
||||
break;
|
||||
case TypeKind::DeviceObjType: {
|
||||
auto device_text =
|
||||
@ -321,20 +352,24 @@ struct SchemaParser {
|
||||
}
|
||||
case TypeKind::ListType: {
|
||||
auto elem_type = arg_type.containedType(0);
|
||||
auto real_elem_type = real_type.containedType(0);
|
||||
if (L.cur().kind == TK_IDENT) {
|
||||
return parseTensorDefault(range);
|
||||
} else if (arg_N && L.cur().kind != '[') {
|
||||
IValue v = parseSingleConstant(*elem_type, elem_type->kind());
|
||||
IValue v = parseSingleConstant(
|
||||
*elem_type, elem_type->kind(), *real_elem_type);
|
||||
std::vector<IValue> repeated(*arg_N, v);
|
||||
return convertToList(*elem_type, elem_type->kind(), range, repeated);
|
||||
} else {
|
||||
return parseConstantList(*elem_type, elem_type->kind());
|
||||
return parseConstantList(
|
||||
*elem_type, elem_type->kind(), *real_elem_type);
|
||||
}
|
||||
} break;
|
||||
case TypeKind::DynamicType:
|
||||
return parseDefaultValue(
|
||||
arg_type,
|
||||
arg_type.expectRef<c10::DynamicType>().dynamicKind(),
|
||||
real_type,
|
||||
arg_N);
|
||||
default:
|
||||
throw ErrorReport(range) << "unexpected type, file a bug report";
|
||||
|
@ -168,6 +168,7 @@ std::optional<at::ScalarType> SchemaTypeParser::parseTensorDType(
|
||||
static std::unordered_map<std::string, at::ScalarType> type_map = {
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
|
||||
|
||||
#undef DEFINE_SCALAR_TYPE
|
||||
auto type = type_map.find(dtype);
|
||||
if (type != type_map.end()) {
|
||||
return type->second;
|
||||
|
@ -7,91 +7,6 @@
|
||||
|
||||
namespace torch::utils {
|
||||
|
||||
std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
|
||||
switch (scalarType) {
|
||||
case at::ScalarType::UInt1:
|
||||
return std::make_pair("uint1", "bit");
|
||||
case at::ScalarType::UInt2:
|
||||
return std::make_pair("uint2", "");
|
||||
case at::ScalarType::UInt3:
|
||||
return std::make_pair("uint3", "");
|
||||
case at::ScalarType::UInt4:
|
||||
return std::make_pair("uint4", "");
|
||||
case at::ScalarType::UInt5:
|
||||
return std::make_pair("uint5", "");
|
||||
case at::ScalarType::UInt6:
|
||||
return std::make_pair("uint6", "");
|
||||
case at::ScalarType::UInt7:
|
||||
return std::make_pair("uint7", "");
|
||||
case at::ScalarType::Byte:
|
||||
// no "byte" because byte is signed in numpy and we overload
|
||||
// byte to mean bool often
|
||||
return std::make_pair("uint8", "");
|
||||
case at::ScalarType::UInt16:
|
||||
return std::make_pair("uint16", "");
|
||||
case at::ScalarType::UInt32:
|
||||
return std::make_pair("uint32", "");
|
||||
case at::ScalarType::UInt64:
|
||||
return std::make_pair("uint64", "");
|
||||
case at::ScalarType::Char:
|
||||
// no "char" because it is not consistently signed or unsigned; we want
|
||||
// to move to int8
|
||||
return std::make_pair("int8", "");
|
||||
case at::ScalarType::Double:
|
||||
return std::make_pair("float64", "double");
|
||||
case at::ScalarType::Float:
|
||||
return std::make_pair("float32", "float");
|
||||
case at::ScalarType::Int:
|
||||
return std::make_pair("int32", "int");
|
||||
case at::ScalarType::Long:
|
||||
return std::make_pair("int64", "long");
|
||||
case at::ScalarType::Short:
|
||||
return std::make_pair("int16", "short");
|
||||
case at::ScalarType::Half:
|
||||
return std::make_pair("float16", "half");
|
||||
case at::ScalarType::ComplexHalf:
|
||||
return std::make_pair("complex32", "chalf");
|
||||
case at::ScalarType::ComplexFloat:
|
||||
return std::make_pair("complex64", "cfloat");
|
||||
case at::ScalarType::ComplexDouble:
|
||||
return std::make_pair("complex128", "cdouble");
|
||||
case at::ScalarType::Bool:
|
||||
return std::make_pair("bool", "");
|
||||
case at::ScalarType::QInt8:
|
||||
return std::make_pair("qint8", "");
|
||||
case at::ScalarType::QUInt8:
|
||||
return std::make_pair("quint8", "");
|
||||
case at::ScalarType::QInt32:
|
||||
return std::make_pair("qint32", "");
|
||||
case at::ScalarType::BFloat16:
|
||||
return std::make_pair("bfloat16", "");
|
||||
case at::ScalarType::QUInt4x2:
|
||||
return std::make_pair("quint4x2", "");
|
||||
case at::ScalarType::QUInt2x4:
|
||||
return std::make_pair("quint2x4", "");
|
||||
case at::ScalarType::Bits1x8:
|
||||
return std::make_pair("bits1x8", "");
|
||||
case at::ScalarType::Bits2x4:
|
||||
return std::make_pair("bits2x4", "");
|
||||
case at::ScalarType::Bits4x2:
|
||||
return std::make_pair("bits4x2", "");
|
||||
case at::ScalarType::Bits8:
|
||||
return std::make_pair("bits8", "");
|
||||
case at::ScalarType::Bits16:
|
||||
return std::make_pair("bits16", "");
|
||||
case at::ScalarType::Float8_e5m2:
|
||||
return std::make_pair("float8_e5m2", "");
|
||||
case at::ScalarType::Float8_e4m3fn:
|
||||
return std::make_pair("float8_e4m3fn", "");
|
||||
case at::ScalarType::Float8_e5m2fnuz:
|
||||
return std::make_pair("float8_e5m2fnuz", "");
|
||||
case at::ScalarType::Float8_e4m3fnuz:
|
||||
return std::make_pair("float8_e4m3fnuz", "");
|
||||
default:
|
||||
throw std::runtime_error("Unimplemented scalar type");
|
||||
}
|
||||
}
|
||||
|
||||
void initializeDtypes() {
|
||||
auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
|
||||
if (!torch_module)
|
||||
@ -102,8 +17,10 @@ void initializeDtypes() {
|
||||
auto all_scalar_types = {
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
|
||||
|
||||
#undef DEFINE_SCALAR_TYPE
|
||||
|
||||
for (at::ScalarType scalarType : all_scalar_types) {
|
||||
auto [primary_name, legacy_name] = getDtypeNames(scalarType);
|
||||
auto [primary_name, legacy_name] = c10::getDtypeNames(scalarType);
|
||||
PyObject* dtype = THPDtype_New(scalarType, primary_name);
|
||||
torch::registerDtypeObject((THPDtype*)dtype, scalarType);
|
||||
Py_INCREF(dtype);
|
||||
|
Reference in New Issue
Block a user