mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Teach Python TS frontend to parse complex literals (#52881)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52881 **This PR adds:** 1. logic to parse complex constants (complex literals of the form `bj`) 2. logic to parse complex lists 3. support for complex constructors: `complex(tensor/int/float/bool, tensor/int/float/bool)` 4. Limited operator support - `add`, `sub`, `mul`, `torch.tensor`, `torch.as_tensor` **Follow-up work:** 1. Add complex support for unary and other registered ops. 2. support complex constructor with string as input (this is supported in Python eager mode). 3. Test all emitXYZ for all XYZ in `ir_emitter.cpp` (currently only emitConst, emitValueToTensor are tested). e.g., test loops etc. 4. onnx doesn't support complex tensors, so we should error out with a clear and descriptive error message. Test Plan: Imported from OSS Reviewed By: bdhirsh Differential Revision: D27245059 Pulled By: anjali411 fbshipit-source-id: af043b5159ae99a9cc8691b5a8401503fa8d6f05
This commit is contained in:
committed by
Facebook GitHub Bot
parent
2f5db68797
commit
f9ca0d87a7
@ -173,6 +173,8 @@ struct SchemaParser {
|
||||
auto text = tok.text();
|
||||
if ("float" == text) {
|
||||
return static_cast<int64_t>(at::kFloat);
|
||||
} else if ("complex" == text) {
|
||||
return static_cast<int64_t>(at::kComplexFloat);
|
||||
} else if ("long" == text) {
|
||||
return static_cast<int64_t>(at::kLong);
|
||||
} else if ("strided" == text) {
|
||||
@ -191,7 +193,12 @@ struct SchemaParser {
|
||||
n = "-" + L.expect(TK_NUMBER).text();
|
||||
else
|
||||
n = L.expect(TK_NUMBER).text();
|
||||
if (kind == TypeKind::FloatType || n.find('.') != std::string::npos ||
|
||||
|
||||
if (kind == TypeKind::ComplexType || n.find('j') != std::string::npos) {
|
||||
auto imag = c10::stod(n.substr(0, n.size() - 1));
|
||||
return c10::complex<double>(0, imag);
|
||||
} else if (
|
||||
kind == TypeKind::FloatType || n.find('.') != std::string::npos ||
|
||||
n.find('e') != std::string::npos) {
|
||||
return c10::stod(n);
|
||||
} else {
|
||||
@ -205,6 +212,8 @@ struct SchemaParser {
|
||||
const SourceRange& range,
|
||||
const std::vector<IValue>& vs) {
|
||||
switch (kind) {
|
||||
case TypeKind::ComplexType:
|
||||
return fmap(vs, [](const IValue& v) { return v.toComplexDouble(); });
|
||||
case TypeKind::FloatType:
|
||||
return fmap(vs, [](const IValue& v) { return v.toDouble(); });
|
||||
case TypeKind::IntType:
|
||||
@ -213,7 +222,7 @@ struct SchemaParser {
|
||||
return fmap(vs, [](const IValue& v) { return v.toBool(); });
|
||||
default:
|
||||
throw ErrorReport(range)
|
||||
<< "lists are only supported for float or int types";
|
||||
<< "lists are only supported for float, int and complex types";
|
||||
}
|
||||
}
|
||||
IValue parseConstantList(TypeKind kind) {
|
||||
@ -248,6 +257,7 @@ struct SchemaParser {
|
||||
case TypeKind::IntType:
|
||||
case TypeKind::BoolType:
|
||||
case TypeKind::FloatType:
|
||||
case TypeKind::ComplexType:
|
||||
return parseSingleConstant(arg_type->kind());
|
||||
break;
|
||||
case TypeKind::DeviceObjType: {
|
||||
|
Reference in New Issue
Block a user