[custom_op] support default dtype values (#129189)

This PR:
- moves some of the dtype-string utilities into ScalarType.{h, cpp}
- adds a new utility to get a mapping from dtype name to the C++ dtype
- the parser now checks whether the string is a dtype name; if it is, it
  pulls the C++ dtype from the mapping.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129189
Approved by: https://github.com/albanD
ghstack dependencies: #129177, #129178, #129179
This commit is contained in:
rzou
2024-06-21 09:34:19 -07:00
committed by PyTorch MergeBot
parent 3e02ecd740
commit 856541c701
8 changed files with 179 additions and 98 deletions

View File

@ -128,4 +128,112 @@ ScalarType promoteTypes(ScalarType a, ScalarType b) {
return _promoteTypesLookup[ix_a][ix_b];
}
// Maps a scalar type to its (primary name, legacy name) pair. The legacy name
// is the empty string for dtypes that have no legacy spelling. Throws
// std::runtime_error for any scalar type not listed below.
std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
  using ST = c10::ScalarType;
  switch (scalarType) {
    // Sub-byte and unsigned integer types.
    case ST::UInt1:
      return {"uint1", "bit"};
    case ST::UInt2:
      return {"uint2", ""};
    case ST::UInt3:
      return {"uint3", ""};
    case ST::UInt4:
      return {"uint4", ""};
    case ST::UInt5:
      return {"uint5", ""};
    case ST::UInt6:
      return {"uint6", ""};
    case ST::UInt7:
      return {"uint7", ""};
    case ST::Byte:
      // "byte" is deliberately not offered as a legacy name: numpy treats
      // byte as signed, and byte is often overloaded to mean bool.
      return {"uint8", ""};
    case ST::UInt16:
      return {"uint16", ""};
    case ST::UInt32:
      return {"uint32", ""};
    case ST::UInt64:
      return {"uint64", ""};
    case ST::Char:
      // "char" is deliberately not offered as a legacy name: its signedness
      // is not consistent, and int8 is the preferred spelling.
      return {"int8", ""};

    // Types that carry a legacy alias.
    case ST::Double:
      return {"float64", "double"};
    case ST::Float:
      return {"float32", "float"};
    case ST::Int:
      return {"int32", "int"};
    case ST::Long:
      return {"int64", "long"};
    case ST::Short:
      return {"int16", "short"};
    case ST::Half:
      return {"float16", "half"};
    case ST::ComplexHalf:
      return {"complex32", "chalf"};
    case ST::ComplexFloat:
      return {"complex64", "cfloat"};
    case ST::ComplexDouble:
      return {"complex128", "cdouble"};

    // Everything else has a primary name only.
    case ST::Bool:
      return {"bool", ""};
    case ST::QInt8:
      return {"qint8", ""};
    case ST::QUInt8:
      return {"quint8", ""};
    case ST::QInt32:
      return {"qint32", ""};
    case ST::BFloat16:
      return {"bfloat16", ""};
    case ST::QUInt4x2:
      return {"quint4x2", ""};
    case ST::QUInt2x4:
      return {"quint2x4", ""};
    case ST::Bits1x8:
      return {"bits1x8", ""};
    case ST::Bits2x4:
      return {"bits2x4", ""};
    case ST::Bits4x2:
      return {"bits4x2", ""};
    case ST::Bits8:
      return {"bits8", ""};
    case ST::Bits16:
      return {"bits16", ""};
    case ST::Float8_e5m2:
      return {"float8_e5m2", ""};
    case ST::Float8_e4m3fn:
      return {"float8_e4m3fn", ""};
    case ST::Float8_e5m2fnuz:
      return {"float8_e5m2fnuz", ""};
    case ST::Float8_e4m3fnuz:
      return {"float8_e4m3fnuz", ""};
    default:
      throw std::runtime_error("Unimplemented scalar type");
  }
}
const std::unordered_map<std::string, ScalarType>& getStringToDtypeMap() {
static std::unordered_map<std::string, ScalarType> result;
if (!result.empty()) {
return result;
}
#define DEFINE_SCALAR_TYPE(_1, n) c10::ScalarType::n,
auto all_scalar_types = {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
#undef DEFINE_SCALAR_TYPE
for (auto scalar_type : all_scalar_types) {
auto names = getDtypeNames(scalar_type);
result[std::get<0>(names)] = scalar_type;
if (!std::get<1>(names).empty()) {
result[std::get<1>(names)] = scalar_type;
}
}
return result;
}
} // namespace c10

View File

@ -22,6 +22,7 @@
#include <limits>
#include <ostream>
#include <type_traits>
#include <unordered_map>
namespace c10 {
@ -561,4 +562,12 @@ inline std::ostream& operator<<(
return stream << toString(scalar_type);
}
// Returns the string names of the given dtype as a pair
// (primary_name, legacy_name). The legacy name is the empty string when the
// dtype has no legacy spelling.
// Throws std::runtime_error if the scalar type has no names defined.
C10_API std::pair<std::string, std::string> getDtypeNames(
c10::ScalarType scalarType);
// Returns a map from dtype name to dtype. For each dtype, the primary name
// and (when it is non-empty) the legacy name reported by getDtypeNames are
// both keys that map to that dtype.
C10_API const std::unordered_map<std::string, ScalarType>& getStringToDtypeMap();
} // namespace c10

View File

@ -2412,13 +2412,19 @@ class TestCustomOpAPI(TestCase):
c: bool = True,
d: int = 3,
e: str = "foo",
f: torch.dtype = torch.float,
g: torch.dtype = torch.float32,
h: torch.dtype = torch.int,
) -> Tensor:
defaults.extend([a, b, c, d, e])
defaults.extend([a, b, c, d, e, f, g, h])
return x.clone()
x = torch.randn(3)
f(x)
self.assertEqual(defaults, [None, 3.14, True, 3, "foo"])
self.assertEqual(
defaults,
[None, 3.14, True, 3, "foo", torch.float, torch.float32, torch.int],
)
def test_mutated_error(self):
with self.assertRaisesRegex(

View File

@ -79,6 +79,11 @@ def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
default_repr = str(param.default)
elif isinstance(param.default, str):
default_repr = f'"{param.default}"'
elif isinstance(param.default, torch.dtype):
dtype_repr = str(param.default)
torch_dot = "torch."
assert dtype_repr.startswith(torch_dot)
default_repr = dtype_repr[len(torch_dot) :]
else:
error_fn(
f"Parameter {name} has an unsupported default value type {type(param.default)}. "

View File

@ -194,7 +194,7 @@ static PyObject* THPIInfo_min(THPIInfo* self, void*) {
static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
HANDLE_TH_ERRORS
auto primary_name = torch::utils::getDtypeNames(self->type).first;
auto primary_name = c10::getDtypeNames(self->type).first;
return AT_DISPATCH_IINFO_TYPES(self->type, "dtype", [&primary_name] {
return PyUnicode_FromString(primary_name.data());
});
@ -227,7 +227,7 @@ static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
HANDLE_TH_ERRORS
auto primary_name = torch::utils::getDtypeNames(self->type).first;
auto primary_name = c10::getDtypeNames(self->type).first;
return _AT_DISPATCH_FINFO_TYPES(self->type, "dtype", [&primary_name] {
return PyUnicode_FromString(primary_name.data());
});

View File

@ -1,6 +1,7 @@
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <ATen/core/Reduction.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/type_factory.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/frontend/lexer.h>
@ -185,7 +186,8 @@ struct SchemaParser {
name = L.expect(TK_IDENT).text();
if (L.nextIf('=')) {
// NB: this means we have to unswizzle default too
default_value = parseDefaultValue(*fake_type, fake_type->kind(), N);
default_value =
parseDefaultValue(*fake_type, fake_type->kind(), *real_type, N);
}
}
return Argument(
@ -197,11 +199,29 @@ struct SchemaParser {
!is_return && kwarg_only,
std::move(alias_info));
}
IValue parseSingleConstant(const c10::Type& type, TypeKind kind) {
// Returns true when `type` is a ScalarType, or an Optional that (possibly
// through further nested Optionals) contains a ScalarType.
bool isPossiblyOptionalScalarType(const c10::Type& type) {
  const auto type_kind = type.kind();
  if (type_kind == at::ScalarTypeType::Kind) {
    return true;
  }
  if (type_kind != at::OptionalType::Kind) {
    return false;
  }
  // Optional: look through the wrapper at whatever it contains.
  for (const auto& contained : type.containedTypes()) {
    if (isPossiblyOptionalScalarType(*contained)) {
      return true;
    }
  }
  return false;
}
IValue parseSingleConstant(
const c10::Type& type,
TypeKind kind,
const c10::Type& real_type) {
if (kind == c10::TypeKind::DynamicType) {
return parseSingleConstant(
type, type.expectRef<c10::DynamicType>().dynamicKind());
type, type.expectRef<c10::DynamicType>().dynamicKind(), real_type);
}
const auto& str2dtype = c10::getStringToDtypeMap();
switch (L.cur().kind) {
case TK_TRUE:
L.next();
@ -219,6 +239,9 @@ struct SchemaParser {
case TK_IDENT: {
auto tok = L.next();
auto text = tok.text();
// NB: float/complex/long are here for BC purposes. Other dtypes
// are handled via str2dtype.
// Please don't add more cases to this if-else block.
if ("float" == text) {
return static_cast<int64_t>(at::kFloat);
} else if ("complex" == text) {
@ -231,6 +254,10 @@ struct SchemaParser {
return static_cast<int64_t>(at::Reduction::Mean);
} else if ("contiguous_format" == text) {
return static_cast<int64_t>(c10::MemoryFormat::Contiguous);
} else if (
isPossiblyOptionalScalarType(real_type) &&
str2dtype.count(text) > 0) {
return static_cast<int64_t>(str2dtype.at(text));
} else {
throw ErrorReport(L.cur().range) << "invalid numeric default value";
}
@ -277,12 +304,15 @@ struct SchemaParser {
<< "lists are only supported for float, int and complex types";
}
}
IValue parseConstantList(const c10::Type& type, TypeKind kind) {
IValue parseConstantList(
const c10::Type& type,
TypeKind kind,
const c10::Type& real_type) {
auto tok = L.expect('[');
std::vector<IValue> vs;
if (L.cur().kind != ']') {
do {
vs.push_back(parseSingleConstant(type, kind));
vs.push_back(parseSingleConstant(type, kind, real_type));
} while (L.nextIf(','));
}
L.expect(']');
@ -296,6 +326,7 @@ struct SchemaParser {
IValue parseDefaultValue(
const c10::Type& arg_type,
TypeKind kind,
const c10::Type& real_type,
std::optional<int32_t> arg_N) {
auto range = L.cur().range;
switch (kind) {
@ -311,7 +342,7 @@ struct SchemaParser {
case TypeKind::BoolType:
case TypeKind::FloatType:
case TypeKind::ComplexType:
return parseSingleConstant(arg_type, kind);
return parseSingleConstant(arg_type, kind, real_type);
break;
case TypeKind::DeviceObjType: {
auto device_text =
@ -321,20 +352,24 @@ struct SchemaParser {
}
case TypeKind::ListType: {
auto elem_type = arg_type.containedType(0);
auto real_elem_type = real_type.containedType(0);
if (L.cur().kind == TK_IDENT) {
return parseTensorDefault(range);
} else if (arg_N && L.cur().kind != '[') {
IValue v = parseSingleConstant(*elem_type, elem_type->kind());
IValue v = parseSingleConstant(
*elem_type, elem_type->kind(), *real_elem_type);
std::vector<IValue> repeated(*arg_N, v);
return convertToList(*elem_type, elem_type->kind(), range, repeated);
} else {
return parseConstantList(*elem_type, elem_type->kind());
return parseConstantList(
*elem_type, elem_type->kind(), *real_elem_type);
}
} break;
case TypeKind::DynamicType:
return parseDefaultValue(
arg_type,
arg_type.expectRef<c10::DynamicType>().dynamicKind(),
real_type,
arg_N);
default:
throw ErrorReport(range) << "unexpected type, file a bug report";

View File

@ -168,6 +168,7 @@ std::optional<at::ScalarType> SchemaTypeParser::parseTensorDType(
static std::unordered_map<std::string, at::ScalarType> type_map = {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
#undef DEFINE_SCALAR_TYPE
auto type = type_map.find(dtype);
if (type != type_map.end()) {
return type->second;

View File

@ -7,91 +7,6 @@
namespace torch::utils {
// Maps a scalar type to its (primary name, legacy name) pair. The legacy name
// is the empty string for dtypes that have no legacy spelling. Throws
// std::runtime_error for any scalar type not listed below.
std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
  using ST = at::ScalarType;
  switch (scalarType) {
    // Sub-byte and unsigned integer types.
    case ST::UInt1:
      return {"uint1", "bit"};
    case ST::UInt2:
      return {"uint2", ""};
    case ST::UInt3:
      return {"uint3", ""};
    case ST::UInt4:
      return {"uint4", ""};
    case ST::UInt5:
      return {"uint5", ""};
    case ST::UInt6:
      return {"uint6", ""};
    case ST::UInt7:
      return {"uint7", ""};
    case ST::Byte:
      // "byte" is deliberately not offered as a legacy name: numpy treats
      // byte as signed, and byte is often overloaded to mean bool.
      return {"uint8", ""};
    case ST::UInt16:
      return {"uint16", ""};
    case ST::UInt32:
      return {"uint32", ""};
    case ST::UInt64:
      return {"uint64", ""};
    case ST::Char:
      // "char" is deliberately not offered as a legacy name: its signedness
      // is not consistent, and int8 is the preferred spelling.
      return {"int8", ""};

    // Types that carry a legacy alias.
    case ST::Double:
      return {"float64", "double"};
    case ST::Float:
      return {"float32", "float"};
    case ST::Int:
      return {"int32", "int"};
    case ST::Long:
      return {"int64", "long"};
    case ST::Short:
      return {"int16", "short"};
    case ST::Half:
      return {"float16", "half"};
    case ST::ComplexHalf:
      return {"complex32", "chalf"};
    case ST::ComplexFloat:
      return {"complex64", "cfloat"};
    case ST::ComplexDouble:
      return {"complex128", "cdouble"};

    // Everything else has a primary name only.
    case ST::Bool:
      return {"bool", ""};
    case ST::QInt8:
      return {"qint8", ""};
    case ST::QUInt8:
      return {"quint8", ""};
    case ST::QInt32:
      return {"qint32", ""};
    case ST::BFloat16:
      return {"bfloat16", ""};
    case ST::QUInt4x2:
      return {"quint4x2", ""};
    case ST::QUInt2x4:
      return {"quint2x4", ""};
    case ST::Bits1x8:
      return {"bits1x8", ""};
    case ST::Bits2x4:
      return {"bits2x4", ""};
    case ST::Bits4x2:
      return {"bits4x2", ""};
    case ST::Bits8:
      return {"bits8", ""};
    case ST::Bits16:
      return {"bits16", ""};
    case ST::Float8_e5m2:
      return {"float8_e5m2", ""};
    case ST::Float8_e4m3fn:
      return {"float8_e4m3fn", ""};
    case ST::Float8_e5m2fnuz:
      return {"float8_e5m2fnuz", ""};
    case ST::Float8_e4m3fnuz:
      return {"float8_e4m3fnuz", ""};
    default:
      throw std::runtime_error("Unimplemented scalar type");
  }
}
void initializeDtypes() {
auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
if (!torch_module)
@ -102,8 +17,10 @@ void initializeDtypes() {
auto all_scalar_types = {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
#undef DEFINE_SCALAR_TYPE
for (at::ScalarType scalarType : all_scalar_types) {
auto [primary_name, legacy_name] = getDtypeNames(scalarType);
auto [primary_name, legacy_name] = c10::getDtypeNames(scalarType);
PyObject* dtype = THPDtype_New(scalarType, primary_name);
torch::registerDtypeObject((THPDtype*)dtype, scalarType);
Py_INCREF(dtype);