mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Replace std::runtime_error
with TORCH_CHECK
(#159344)
Fixes part of #148114 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159344 Approved by: https://github.com/cyyever, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
9aca0ba027
commit
e3783a9575
@ -62,7 +62,7 @@ static void setSignalHandler(
|
|||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
oss << "An error occurred while setting handler for " << strsignal(signal)
|
oss << "An error occurred while setting handler for " << strsignal(signal)
|
||||||
<< ".";
|
<< ".";
|
||||||
throw std::runtime_error(oss.str());
|
TORCH_CHECK(false, oss.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,29 +141,32 @@ static PyObject* THPModule_errorIfAnyWorkerFails(
|
|||||||
continue;
|
continue;
|
||||||
if (infop.si_code == CLD_EXITED &&
|
if (infop.si_code == CLD_EXITED &&
|
||||||
infop.si_status != EXIT_SUCCESS) { // exit with error
|
infop.si_status != EXIT_SUCCESS) { // exit with error
|
||||||
std::ostringstream oss;
|
auto error_msg = fmt::format(
|
||||||
oss << "DataLoader worker (pid " << worker_pid << ") exited "
|
"DataLoader worker (pid {}) exited unexpectedly with exit code {}. "
|
||||||
<< "unexpectedly with exit code " << infop.si_status << ". "
|
"Details are lost due to multiprocessing. Rerunning with "
|
||||||
<< "Details are lost due to multiprocessing. Rerunning with "
|
"num_workers=0 may give better error trace.",
|
||||||
<< "num_workers=0 may give better error trace.";
|
worker_pid,
|
||||||
|
infop.si_status);
|
||||||
// This is necessary. Otherwise, the runtime error will kill the other
|
// This is necessary. Otherwise, the runtime error will kill the other
|
||||||
// workers, and trigger this again.
|
// workers, and trigger this again.
|
||||||
pid_set.clear();
|
pid_set.clear();
|
||||||
throw std::runtime_error(oss.str());
|
TORCH_CHECK(false, error_msg);
|
||||||
} else if (
|
} else if (
|
||||||
infop.si_code == CLD_KILLED ||
|
infop.si_code == CLD_KILLED ||
|
||||||
infop.si_code == CLD_DUMPED) { // killed by signal
|
infop.si_code == CLD_DUMPED) { // killed by signal
|
||||||
std::ostringstream oss;
|
auto error_msg = fmt::format(
|
||||||
oss << "DataLoader worker (pid " << worker_pid << ") is killed "
|
"DataLoader worker (pid {}) is killed by signal: {}. ",
|
||||||
<< "by signal: " << strsignal(infop.si_status) << ". ";
|
worker_pid,
|
||||||
|
strsignal(infop.si_status));
|
||||||
if (infop.si_status == SIGBUS) {
|
if (infop.si_status == SIGBUS) {
|
||||||
oss << "It is possible that dataloader's workers are out of shared memory. "
|
error_msg +=
|
||||||
<< "Please try to raise your shared memory limit.";
|
"It is possible that dataloader's workers are out of shared memory. "
|
||||||
|
"Please try to raise your shared memory limit.";
|
||||||
}
|
}
|
||||||
// This is necessary. Otherwise, the runtime error will kill the other
|
// This is necessary. Otherwise, the runtime error will kill the other
|
||||||
// workers, and trigger this again.
|
// workers, and trigger this again.
|
||||||
pid_set.clear();
|
pid_set.clear();
|
||||||
throw std::runtime_error(oss.str());
|
TORCH_CHECK(false, error_msg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,8 @@ static PyObject* THPDevice_pynew(
|
|||||||
auto as_device = r.device(0); // this works, because device can take strings
|
auto as_device = r.device(0); // this works, because device can take strings
|
||||||
if (as_device.has_index()) {
|
if (as_device.has_index()) {
|
||||||
auto device_type = r.string(0);
|
auto device_type = r.string(0);
|
||||||
throw std::runtime_error(
|
TORCH_CHECK(
|
||||||
|
false,
|
||||||
"type (string) must not include an index because index "
|
"type (string) must not include an index because index "
|
||||||
"was passed explicitly: " +
|
"was passed explicitly: " +
|
||||||
device_type);
|
device_type);
|
||||||
|
65
torch/csrc/utils/generated_serialization_types.h
generated
65
torch/csrc/utils/generated_serialization_types.h
generated
@ -10,6 +10,7 @@
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
@ -190,7 +191,7 @@ inline std::string_view printEnum(const ArgumentKind& e) {
|
|||||||
case ArgumentKind::POSITIONAL: return "POSITIONAL";
|
case ArgumentKind::POSITIONAL: return "POSITIONAL";
|
||||||
case ArgumentKind::KEYWORD: return "KEYWORD";
|
case ArgumentKind::KEYWORD: return "KEYWORD";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,7 +199,7 @@ inline void parseEnum(std::string_view s, ArgumentKind& t) {
|
|||||||
if (s == "UNKNOWN") { t = ArgumentKind::UNKNOWN; return; }
|
if (s == "UNKNOWN") { t = ArgumentKind::UNKNOWN; return; }
|
||||||
if (s == "POSITIONAL") { t = ArgumentKind::POSITIONAL; return; }
|
if (s == "POSITIONAL") { t = ArgumentKind::POSITIONAL; return; }
|
||||||
if (s == "KEYWORD") { t = ArgumentKind::KEYWORD; return; }
|
if (s == "KEYWORD") { t = ArgumentKind::KEYWORD; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
enum class Layout {
|
enum class Layout {
|
||||||
@ -223,7 +224,7 @@ inline std::string_view printEnum(const Layout& e) {
|
|||||||
case Layout::_mkldnn: return "_mkldnn";
|
case Layout::_mkldnn: return "_mkldnn";
|
||||||
case Layout::Strided: return "Strided";
|
case Layout::Strided: return "Strided";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,7 +237,7 @@ inline void parseEnum(std::string_view s, Layout& t) {
|
|||||||
if (s == "SparseBsc") { t = Layout::SparseBsc; return; }
|
if (s == "SparseBsc") { t = Layout::SparseBsc; return; }
|
||||||
if (s == "_mkldnn") { t = Layout::_mkldnn; return; }
|
if (s == "_mkldnn") { t = Layout::_mkldnn; return; }
|
||||||
if (s == "Strided") { t = Layout::Strided; return; }
|
if (s == "Strided") { t = Layout::Strided; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
enum class MemoryFormat {
|
enum class MemoryFormat {
|
||||||
@ -255,7 +256,7 @@ inline std::string_view printEnum(const MemoryFormat& e) {
|
|||||||
case MemoryFormat::ChannelsLast3d: return "ChannelsLast3d";
|
case MemoryFormat::ChannelsLast3d: return "ChannelsLast3d";
|
||||||
case MemoryFormat::PreserveFormat: return "PreserveFormat";
|
case MemoryFormat::PreserveFormat: return "PreserveFormat";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -265,7 +266,7 @@ inline void parseEnum(std::string_view s, MemoryFormat& t) {
|
|||||||
if (s == "ChannelsLast") { t = MemoryFormat::ChannelsLast; return; }
|
if (s == "ChannelsLast") { t = MemoryFormat::ChannelsLast; return; }
|
||||||
if (s == "ChannelsLast3d") { t = MemoryFormat::ChannelsLast3d; return; }
|
if (s == "ChannelsLast3d") { t = MemoryFormat::ChannelsLast3d; return; }
|
||||||
if (s == "PreserveFormat") { t = MemoryFormat::PreserveFormat; return; }
|
if (s == "PreserveFormat") { t = MemoryFormat::PreserveFormat; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
enum class ScalarType {
|
enum class ScalarType {
|
||||||
@ -312,7 +313,7 @@ inline std::string_view printEnum(const ScalarType& e) {
|
|||||||
case ScalarType::FLOAT8E4M3FNUZ: return "FLOAT8E4M3FNUZ";
|
case ScalarType::FLOAT8E4M3FNUZ: return "FLOAT8E4M3FNUZ";
|
||||||
case ScalarType::FLOAT8E5M2FNUZ: return "FLOAT8E5M2FNUZ";
|
case ScalarType::FLOAT8E5M2FNUZ: return "FLOAT8E5M2FNUZ";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -336,7 +337,7 @@ inline void parseEnum(std::string_view s, ScalarType& t) {
|
|||||||
if (s == "FLOAT8E5M2") { t = ScalarType::FLOAT8E5M2; return; }
|
if (s == "FLOAT8E5M2") { t = ScalarType::FLOAT8E5M2; return; }
|
||||||
if (s == "FLOAT8E4M3FNUZ") { t = ScalarType::FLOAT8E4M3FNUZ; return; }
|
if (s == "FLOAT8E4M3FNUZ") { t = ScalarType::FLOAT8E4M3FNUZ; return; }
|
||||||
if (s == "FLOAT8E5M2FNUZ") { t = ScalarType::FLOAT8E5M2FNUZ; return; }
|
if (s == "FLOAT8E5M2FNUZ") { t = ScalarType::FLOAT8E5M2FNUZ; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -453,7 +454,7 @@ inline std::string_view printEnum(const SymExprHint::Tag& e) {
|
|||||||
case SymExprHint::Tag::AS_BOOL: return "AS_BOOL";
|
case SymExprHint::Tag::AS_BOOL: return "AS_BOOL";
|
||||||
case SymExprHint::Tag::AS_FLOAT: return "AS_FLOAT";
|
case SymExprHint::Tag::AS_FLOAT: return "AS_FLOAT";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -461,7 +462,7 @@ inline void parseEnum(std::string_view s, SymExprHint::Tag& t) {
|
|||||||
if (s == "AS_INT") { t = SymExprHint::Tag::AS_INT; return; }
|
if (s == "AS_INT") { t = SymExprHint::Tag::AS_INT; return; }
|
||||||
if (s == "AS_BOOL") { t = SymExprHint::Tag::AS_BOOL; return; }
|
if (s == "AS_BOOL") { t = SymExprHint::Tag::AS_BOOL; return; }
|
||||||
if (s == "AS_FLOAT") { t = SymExprHint::Tag::AS_FLOAT; return; }
|
if (s == "AS_FLOAT") { t = SymExprHint::Tag::AS_FLOAT; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -559,14 +560,14 @@ inline std::string_view printEnum(const SymInt::Tag& e) {
|
|||||||
case SymInt::Tag::AS_EXPR: return "AS_EXPR";
|
case SymInt::Tag::AS_EXPR: return "AS_EXPR";
|
||||||
case SymInt::Tag::AS_INT: return "AS_INT";
|
case SymInt::Tag::AS_INT: return "AS_INT";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void parseEnum(std::string_view s, SymInt::Tag& t) {
|
inline void parseEnum(std::string_view s, SymInt::Tag& t) {
|
||||||
if (s == "AS_EXPR") { t = SymInt::Tag::AS_EXPR; return; }
|
if (s == "AS_EXPR") { t = SymInt::Tag::AS_EXPR; return; }
|
||||||
if (s == "AS_INT") { t = SymInt::Tag::AS_INT; return; }
|
if (s == "AS_INT") { t = SymInt::Tag::AS_INT; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -637,14 +638,14 @@ inline std::string_view printEnum(const SymFloat::Tag& e) {
|
|||||||
case SymFloat::Tag::AS_EXPR: return "AS_EXPR";
|
case SymFloat::Tag::AS_EXPR: return "AS_EXPR";
|
||||||
case SymFloat::Tag::AS_FLOAT: return "AS_FLOAT";
|
case SymFloat::Tag::AS_FLOAT: return "AS_FLOAT";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void parseEnum(std::string_view s, SymFloat::Tag& t) {
|
inline void parseEnum(std::string_view s, SymFloat::Tag& t) {
|
||||||
if (s == "AS_EXPR") { t = SymFloat::Tag::AS_EXPR; return; }
|
if (s == "AS_EXPR") { t = SymFloat::Tag::AS_EXPR; return; }
|
||||||
if (s == "AS_FLOAT") { t = SymFloat::Tag::AS_FLOAT; return; }
|
if (s == "AS_FLOAT") { t = SymFloat::Tag::AS_FLOAT; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -715,14 +716,14 @@ inline std::string_view printEnum(const SymBool::Tag& e) {
|
|||||||
case SymBool::Tag::AS_EXPR: return "AS_EXPR";
|
case SymBool::Tag::AS_EXPR: return "AS_EXPR";
|
||||||
case SymBool::Tag::AS_BOOL: return "AS_BOOL";
|
case SymBool::Tag::AS_BOOL: return "AS_BOOL";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void parseEnum(std::string_view s, SymBool::Tag& t) {
|
inline void parseEnum(std::string_view s, SymBool::Tag& t) {
|
||||||
if (s == "AS_EXPR") { t = SymBool::Tag::AS_EXPR; return; }
|
if (s == "AS_EXPR") { t = SymBool::Tag::AS_EXPR; return; }
|
||||||
if (s == "AS_BOOL") { t = SymBool::Tag::AS_BOOL; return; }
|
if (s == "AS_BOOL") { t = SymBool::Tag::AS_BOOL; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -865,14 +866,14 @@ inline std::string_view printEnum(const SymIntArgument::Tag& e) {
|
|||||||
case SymIntArgument::Tag::AS_NAME: return "AS_NAME";
|
case SymIntArgument::Tag::AS_NAME: return "AS_NAME";
|
||||||
case SymIntArgument::Tag::AS_INT: return "AS_INT";
|
case SymIntArgument::Tag::AS_INT: return "AS_INT";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void parseEnum(std::string_view s, SymIntArgument::Tag& t) {
|
inline void parseEnum(std::string_view s, SymIntArgument::Tag& t) {
|
||||||
if (s == "AS_NAME") { t = SymIntArgument::Tag::AS_NAME; return; }
|
if (s == "AS_NAME") { t = SymIntArgument::Tag::AS_NAME; return; }
|
||||||
if (s == "AS_INT") { t = SymIntArgument::Tag::AS_INT; return; }
|
if (s == "AS_INT") { t = SymIntArgument::Tag::AS_INT; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -943,14 +944,14 @@ inline std::string_view printEnum(const SymFloatArgument::Tag& e) {
|
|||||||
case SymFloatArgument::Tag::AS_NAME: return "AS_NAME";
|
case SymFloatArgument::Tag::AS_NAME: return "AS_NAME";
|
||||||
case SymFloatArgument::Tag::AS_FLOAT: return "AS_FLOAT";
|
case SymFloatArgument::Tag::AS_FLOAT: return "AS_FLOAT";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void parseEnum(std::string_view s, SymFloatArgument::Tag& t) {
|
inline void parseEnum(std::string_view s, SymFloatArgument::Tag& t) {
|
||||||
if (s == "AS_NAME") { t = SymFloatArgument::Tag::AS_NAME; return; }
|
if (s == "AS_NAME") { t = SymFloatArgument::Tag::AS_NAME; return; }
|
||||||
if (s == "AS_FLOAT") { t = SymFloatArgument::Tag::AS_FLOAT; return; }
|
if (s == "AS_FLOAT") { t = SymFloatArgument::Tag::AS_FLOAT; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1021,14 +1022,14 @@ inline std::string_view printEnum(const SymBoolArgument::Tag& e) {
|
|||||||
case SymBoolArgument::Tag::AS_NAME: return "AS_NAME";
|
case SymBoolArgument::Tag::AS_NAME: return "AS_NAME";
|
||||||
case SymBoolArgument::Tag::AS_BOOL: return "AS_BOOL";
|
case SymBoolArgument::Tag::AS_BOOL: return "AS_BOOL";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void parseEnum(std::string_view s, SymBoolArgument::Tag& t) {
|
inline void parseEnum(std::string_view s, SymBoolArgument::Tag& t) {
|
||||||
if (s == "AS_NAME") { t = SymBoolArgument::Tag::AS_NAME; return; }
|
if (s == "AS_NAME") { t = SymBoolArgument::Tag::AS_NAME; return; }
|
||||||
if (s == "AS_BOOL") { t = SymBoolArgument::Tag::AS_BOOL; return; }
|
if (s == "AS_BOOL") { t = SymBoolArgument::Tag::AS_BOOL; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1135,14 +1136,14 @@ inline std::string_view printEnum(const OptionalTensorArgument::Tag& e) {
|
|||||||
case OptionalTensorArgument::Tag::AS_TENSOR: return "AS_TENSOR";
|
case OptionalTensorArgument::Tag::AS_TENSOR: return "AS_TENSOR";
|
||||||
case OptionalTensorArgument::Tag::AS_NONE: return "AS_NONE";
|
case OptionalTensorArgument::Tag::AS_NONE: return "AS_NONE";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void parseEnum(std::string_view s, OptionalTensorArgument::Tag& t) {
|
inline void parseEnum(std::string_view s, OptionalTensorArgument::Tag& t) {
|
||||||
if (s == "AS_TENSOR") { t = OptionalTensorArgument::Tag::AS_TENSOR; return; }
|
if (s == "AS_TENSOR") { t = OptionalTensorArgument::Tag::AS_TENSOR; return; }
|
||||||
if (s == "AS_NONE") { t = OptionalTensorArgument::Tag::AS_NONE; return; }
|
if (s == "AS_NONE") { t = OptionalTensorArgument::Tag::AS_NONE; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -1769,7 +1770,7 @@ inline std::string_view printEnum(const Argument::Tag& e) {
|
|||||||
case Argument::Tag::AS_OPTIONAL_TENSOR: return "AS_OPTIONAL_TENSOR";
|
case Argument::Tag::AS_OPTIONAL_TENSOR: return "AS_OPTIONAL_TENSOR";
|
||||||
case Argument::Tag::AS_COMPLEX: return "AS_COMPLEX";
|
case Argument::Tag::AS_COMPLEX: return "AS_COMPLEX";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1801,7 +1802,7 @@ inline void parseEnum(std::string_view s, Argument::Tag& t) {
|
|||||||
if (s == "AS_SYM_FLOATS") { t = Argument::Tag::AS_SYM_FLOATS; return; }
|
if (s == "AS_SYM_FLOATS") { t = Argument::Tag::AS_SYM_FLOATS; return; }
|
||||||
if (s == "AS_OPTIONAL_TENSOR") { t = Argument::Tag::AS_OPTIONAL_TENSOR; return; }
|
if (s == "AS_OPTIONAL_TENSOR") { t = Argument::Tag::AS_OPTIONAL_TENSOR; return; }
|
||||||
if (s == "AS_COMPLEX") { t = Argument::Tag::AS_COMPLEX; return; }
|
if (s == "AS_COMPLEX") { t = Argument::Tag::AS_COMPLEX; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -2127,7 +2128,7 @@ inline std::string_view printEnum(const ConstantValue::Tag& e) {
|
|||||||
case ConstantValue::Tag::AS_STRING: return "AS_STRING";
|
case ConstantValue::Tag::AS_STRING: return "AS_STRING";
|
||||||
case ConstantValue::Tag::AS_BOOL: return "AS_BOOL";
|
case ConstantValue::Tag::AS_BOOL: return "AS_BOOL";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2137,7 +2138,7 @@ inline void parseEnum(std::string_view s, ConstantValue::Tag& t) {
|
|||||||
if (s == "AS_FLOAT") { t = ConstantValue::Tag::AS_FLOAT; return; }
|
if (s == "AS_FLOAT") { t = ConstantValue::Tag::AS_FLOAT; return; }
|
||||||
if (s == "AS_STRING") { t = ConstantValue::Tag::AS_STRING; return; }
|
if (s == "AS_STRING") { t = ConstantValue::Tag::AS_STRING; return; }
|
||||||
if (s == "AS_BOOL") { t = ConstantValue::Tag::AS_BOOL; return; }
|
if (s == "AS_BOOL") { t = ConstantValue::Tag::AS_BOOL; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -2465,7 +2466,7 @@ inline std::string_view printEnum(const InputSpec::Tag& e) {
|
|||||||
case InputSpec::Tag::TOKEN: return "TOKEN";
|
case InputSpec::Tag::TOKEN: return "TOKEN";
|
||||||
case InputSpec::Tag::CONSTANT_INPUT: return "CONSTANT_INPUT";
|
case InputSpec::Tag::CONSTANT_INPUT: return "CONSTANT_INPUT";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2477,7 +2478,7 @@ inline void parseEnum(std::string_view s, InputSpec::Tag& t) {
|
|||||||
if (s == "CUSTOM_OBJ") { t = InputSpec::Tag::CUSTOM_OBJ; return; }
|
if (s == "CUSTOM_OBJ") { t = InputSpec::Tag::CUSTOM_OBJ; return; }
|
||||||
if (s == "TOKEN") { t = InputSpec::Tag::TOKEN; return; }
|
if (s == "TOKEN") { t = InputSpec::Tag::TOKEN; return; }
|
||||||
if (s == "CONSTANT_INPUT") { t = InputSpec::Tag::CONSTANT_INPUT; return; }
|
if (s == "CONSTANT_INPUT") { t = InputSpec::Tag::CONSTANT_INPUT; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -2851,7 +2852,7 @@ inline std::string_view printEnum(const OutputSpec::Tag& e) {
|
|||||||
case OutputSpec::Tag::TOKEN: return "TOKEN";
|
case OutputSpec::Tag::TOKEN: return "TOKEN";
|
||||||
case OutputSpec::Tag::PARAMETER_MUTATION: return "PARAMETER_MUTATION";
|
case OutputSpec::Tag::PARAMETER_MUTATION: return "PARAMETER_MUTATION";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Unknown enum value");
|
TORCH_CHECK(false, "Unknown enum value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2864,7 +2865,7 @@ inline void parseEnum(std::string_view s, OutputSpec::Tag& t) {
|
|||||||
if (s == "USER_INPUT_MUTATION") { t = OutputSpec::Tag::USER_INPUT_MUTATION; return; }
|
if (s == "USER_INPUT_MUTATION") { t = OutputSpec::Tag::USER_INPUT_MUTATION; return; }
|
||||||
if (s == "TOKEN") { t = OutputSpec::Tag::TOKEN; return; }
|
if (s == "TOKEN") { t = OutputSpec::Tag::TOKEN; return; }
|
||||||
if (s == "PARAMETER_MUTATION") { t = OutputSpec::Tag::PARAMETER_MUTATION; return; }
|
if (s == "PARAMETER_MUTATION") { t = OutputSpec::Tag::PARAMETER_MUTATION; return; }
|
||||||
throw std::runtime_error("Unknown enum value: " + std::string{s});
|
TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -131,9 +131,8 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
|
|||||||
size(0),
|
size(0),
|
||||||
default_scalar(0) {
|
default_scalar(0) {
|
||||||
auto space = fmt.find(' ');
|
auto space = fmt.find(' ');
|
||||||
if (space == std::string::npos) {
|
TORCH_CHECK(
|
||||||
throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
|
space != std::string::npos, "FunctionParameter(): missing type: " + fmt);
|
||||||
}
|
|
||||||
|
|
||||||
auto type_str = fmt.substr(0, space);
|
auto type_str = fmt.substr(0, space);
|
||||||
|
|
||||||
@ -154,10 +153,9 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
|
|||||||
|
|
||||||
auto name_str = fmt.substr(space + 1);
|
auto name_str = fmt.substr(space + 1);
|
||||||
auto it = type_map.find(type_str);
|
auto it = type_map.find(type_str);
|
||||||
if (it == type_map.end()) {
|
TORCH_CHECK(
|
||||||
throw std::runtime_error(
|
it != type_map.end(),
|
||||||
"FunctionParameter(): invalid type string: " + type_str);
|
"FunctionParameter(): invalid type string: " + type_str);
|
||||||
}
|
|
||||||
type_ = it->second;
|
type_ = it->second;
|
||||||
|
|
||||||
auto eq = name_str.find('=');
|
auto eq = name_str.find('=');
|
||||||
@ -1145,7 +1143,7 @@ auto FunctionParameter::_check(
|
|||||||
case ParameterType::DISPATCH_KEY_SET:
|
case ParameterType::DISPATCH_KEY_SET:
|
||||||
return py::isinstance<c10::DispatchKeySet>(py::handle(obj));
|
return py::isinstance<c10::DispatchKeySet>(py::handle(obj));
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("unknown parameter type");
|
TORCH_CHECK(false, "unknown parameter type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1202,7 +1200,7 @@ std::string FunctionParameter::type_name() const {
|
|||||||
case ParameterType::DISPATCH_KEY_SET:
|
case ParameterType::DISPATCH_KEY_SET:
|
||||||
return "DispatchKeySet";
|
return "DispatchKeySet";
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("unknown parameter type");
|
TORCH_CHECK(false, "unknown parameter type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1324,10 +1322,8 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|||||||
}
|
}
|
||||||
if (type_ == ParameterType::TENSOR ||
|
if (type_ == ParameterType::TENSOR ||
|
||||||
type_ == ParameterType::DISPATCH_KEY_SET) {
|
type_ == ParameterType::DISPATCH_KEY_SET) {
|
||||||
if (str != "None") {
|
TORCH_CHECK(
|
||||||
throw std::runtime_error(
|
str == "None", "default value for Tensor must be none, got: " + str);
|
||||||
"default value for Tensor must be none, got: " + str);
|
|
||||||
}
|
|
||||||
} else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
|
} else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
|
||||||
default_int = atol(str.c_str());
|
default_int = atol(str.c_str());
|
||||||
} else if (type_ == ParameterType::BOOL) {
|
} else if (type_ == ParameterType::BOOL) {
|
||||||
@ -1351,16 +1347,14 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|||||||
default_intlist = parse_intlist_args(str, size);
|
default_intlist = parse_intlist_args(str, size);
|
||||||
}
|
}
|
||||||
} else if (type_ == ParameterType::FLOAT_LIST) {
|
} else if (type_ == ParameterType::FLOAT_LIST) {
|
||||||
if (str != "None") {
|
TORCH_CHECK(str == "None", "Defaults not supported for float[]");
|
||||||
throw std::runtime_error("Defaults not supported for float[]");
|
|
||||||
}
|
|
||||||
} else if (type_ == ParameterType::SCALARTYPE) {
|
} else if (type_ == ParameterType::SCALARTYPE) {
|
||||||
if (str == "None") {
|
if (str == "None") {
|
||||||
default_scalartype = at::ScalarType::Undefined;
|
default_scalartype = at::ScalarType::Undefined;
|
||||||
} else if (str == "torch.int64") {
|
} else if (str == "torch.int64") {
|
||||||
default_scalartype = at::ScalarType::Long;
|
default_scalartype = at::ScalarType::Long;
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("invalid default value for ScalarType: " + str);
|
TORCH_CHECK(false, "invalid default value for ScalarType: " + str);
|
||||||
}
|
}
|
||||||
} else if (type_ == ParameterType::LAYOUT) {
|
} else if (type_ == ParameterType::LAYOUT) {
|
||||||
if (str == "None") {
|
if (str == "None") {
|
||||||
@ -1370,16 +1364,12 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|||||||
} else if (str == "torch.sparse_coo") {
|
} else if (str == "torch.sparse_coo") {
|
||||||
default_layout = at::Layout::Sparse;
|
default_layout = at::Layout::Sparse;
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("invalid default value for layout: " + str);
|
TORCH_CHECK(false, "invalid default value for layout: " + str);
|
||||||
}
|
}
|
||||||
} else if (type_ == ParameterType::DEVICE) {
|
} else if (type_ == ParameterType::DEVICE) {
|
||||||
if (str != "None") {
|
TORCH_CHECK(str == "None", "invalid device: " + str);
|
||||||
throw std::runtime_error("invalid device: " + str);
|
|
||||||
}
|
|
||||||
} else if (type_ == ParameterType::STREAM) {
|
} else if (type_ == ParameterType::STREAM) {
|
||||||
if (str != "None") {
|
TORCH_CHECK(str == "None", "invalid stream: " + str);
|
||||||
throw std::runtime_error("invalid stream: " + str);
|
|
||||||
}
|
|
||||||
} else if (type_ == ParameterType::STRING) {
|
} else if (type_ == ParameterType::STRING) {
|
||||||
if (str != "None") {
|
if (str != "None") {
|
||||||
default_string = parse_string_literal(str);
|
default_string = parse_string_literal(str);
|
||||||
@ -1408,7 +1398,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|||||||
} else if (type_ == ParameterType::QSCHEME) { // NOLINT
|
} else if (type_ == ParameterType::QSCHEME) { // NOLINT
|
||||||
// throw std::runtime_error("ParameterType::QSCHEME");
|
// throw std::runtime_error("ParameterType::QSCHEME");
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("unknown parameter type");
|
TORCH_CHECK(false, "unknown parameter type");
|
||||||
}
|
}
|
||||||
default_value = str;
|
default_value = str;
|
||||||
}
|
}
|
||||||
@ -1423,7 +1413,7 @@ FunctionSignature::FunctionSignature(const std::string& fmt, int index)
|
|||||||
deprecated(false) {
|
deprecated(false) {
|
||||||
auto open_paren = fmt.find('(');
|
auto open_paren = fmt.find('(');
|
||||||
if (open_paren == std::string::npos) {
|
if (open_paren == std::string::npos) {
|
||||||
throw std::runtime_error("missing opening parenthesis: " + fmt);
|
TORCH_CHECK(false, "missing opening parenthesis: " + fmt);
|
||||||
}
|
}
|
||||||
name = fmt.substr(0, open_paren);
|
name = fmt.substr(0, open_paren);
|
||||||
|
|
||||||
@ -1445,12 +1435,9 @@ FunctionSignature::FunctionSignature(const std::string& fmt, int index)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (offset == std::string::npos) {
|
TORCH_CHECK(
|
||||||
throw std::runtime_error("missing closing parenthesis: " + fmt);
|
offset != std::string::npos, "missing closing parenthesis: " + fmt);
|
||||||
}
|
TORCH_CHECK(offset != last_offset, "malformed signature: " + fmt);
|
||||||
if (offset == last_offset) {
|
|
||||||
throw std::runtime_error("malformed signature: " + fmt);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto param_str = fmt.substr(last_offset, offset - last_offset);
|
auto param_str = fmt.substr(last_offset, offset - last_offset);
|
||||||
last_offset = next_offset;
|
last_offset = next_offset;
|
||||||
|
@ -120,7 +120,7 @@ inline bool THPUtils_unpackBool(PyObject* obj) {
|
|||||||
} else if (obj == Py_False) {
|
} else if (obj == Py_False) {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("couldn't convert python object to boolean");
|
TORCH_CHECK(false, "couldn't convert python object to boolean");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,13 +199,11 @@ inline c10::DeviceIndex THPUtils_unpackDeviceIndex(PyObject* obj) {
|
|||||||
if (value == -1 && PyErr_Occurred()) {
|
if (value == -1 && PyErr_Occurred()) {
|
||||||
throw python_error();
|
throw python_error();
|
||||||
}
|
}
|
||||||
if (overflow != 0) {
|
TORCH_CHECK(overflow == 0, "Overflow when unpacking DeviceIndex");
|
||||||
throw std::runtime_error("Overflow when unpacking DeviceIndex");
|
TORCH_CHECK(
|
||||||
}
|
value <= std::numeric_limits<c10::DeviceIndex>::max() &&
|
||||||
if (value > std::numeric_limits<c10::DeviceIndex>::max() ||
|
value >= std::numeric_limits<c10::DeviceIndex>::min(),
|
||||||
value < std::numeric_limits<c10::DeviceIndex>::min()) {
|
"Overflow when unpacking DeviceIndex");
|
||||||
throw std::runtime_error("Overflow when unpacking DeviceIndex");
|
|
||||||
}
|
|
||||||
return (c10::DeviceIndex)value;
|
return (c10::DeviceIndex)value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
|
|||||||
at::convert<at::Float8_e8m0fnu, double>(THPUtils_unpackDouble(obj));
|
at::convert<at::Float8_e8m0fnu, double>(THPUtils_unpackDouble(obj));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("store_scalar: invalid type");
|
TORCH_CHECK(false, "store_scalar: invalid type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -165,7 +165,7 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) {
|
|||||||
return PyFloat_FromDouble(
|
return PyFloat_FromDouble(
|
||||||
at::convert<double, at::Float8_e8m0fnu>(*(at::Float8_e8m0fnu*)data));
|
at::convert<double, at::Float8_e8m0fnu>(*(at::Float8_e8m0fnu*)data));
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("load_scalar: invalid type");
|
TORCH_CHECK(false, "load_scalar: invalid type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,12 +26,10 @@ inline std::string THPUtils_unpackString(PyObject* obj) {
|
|||||||
if (PyUnicode_Check(obj)) {
|
if (PyUnicode_Check(obj)) {
|
||||||
Py_ssize_t size = 0;
|
Py_ssize_t size = 0;
|
||||||
const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
|
const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
|
||||||
if (!data) {
|
TORCH_CHECK(data, "error unpacking string as utf-8");
|
||||||
throw std::runtime_error("error unpacking string as utf-8");
|
|
||||||
}
|
|
||||||
return std::string(data, (size_t)size);
|
return std::string(data, (size_t)size);
|
||||||
}
|
}
|
||||||
throw std::runtime_error("unpackString: expected bytes or unicode object");
|
TORCH_CHECK(false, "unpackString: expected bytes or unicode object");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unpacks PyBytes (PyString) or PyUnicode as std::string_view
|
// Unpacks PyBytes (PyString) or PyUnicode as std::string_view
|
||||||
@ -50,12 +48,10 @@ inline std::string_view THPUtils_unpackStringView(PyObject* obj) {
|
|||||||
if (PyUnicode_Check(obj)) {
|
if (PyUnicode_Check(obj)) {
|
||||||
Py_ssize_t size = 0;
|
Py_ssize_t size = 0;
|
||||||
const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
|
const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
|
||||||
if (!data) {
|
TORCH_CHECK(data, "error unpacking string as utf-8");
|
||||||
throw std::runtime_error("error unpacking string as utf-8");
|
|
||||||
}
|
|
||||||
return std::string_view(data, (size_t)size);
|
return std::string_view(data, (size_t)size);
|
||||||
}
|
}
|
||||||
throw std::runtime_error("unpackString: expected bytes or unicode object");
|
TORCH_CHECK(false, "unpackString: expected bytes or unicode object");
|
||||||
}
|
}
|
||||||
|
|
||||||
inline PyObject* THPUtils_packString(const char* str) {
|
inline PyObject* THPUtils_packString(const char* str) {
|
||||||
|
@ -689,7 +689,7 @@ Tensor legacy_sparse_tensor_generic_ctor_new(
|
|||||||
return new_with_sizes(
|
return new_with_sizes(
|
||||||
options, scalar_type, deviceOptional, r.symintlist(0));
|
options, scalar_type, deviceOptional, r.symintlist(0));
|
||||||
}
|
}
|
||||||
throw std::runtime_error("new(): invalid arguments");
|
TORCH_CHECK(false, "new(): invalid arguments");
|
||||||
}
|
}
|
||||||
|
|
||||||
// NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs
|
// NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs
|
||||||
@ -808,7 +808,7 @@ static Tensor legacy_tensor_generic_ctor_new(
|
|||||||
return legacy_new_from_sequence(
|
return legacy_new_from_sequence(
|
||||||
options, scalar_type, deviceOptional, r.pyobject(0));
|
options, scalar_type, deviceOptional, r.pyobject(0));
|
||||||
}
|
}
|
||||||
throw std::runtime_error("new(): invalid arguments");
|
TORCH_CHECK(false, "new(): invalid arguments");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handles ONLY torch.Tensor
|
// Handles ONLY torch.Tensor
|
||||||
@ -1072,7 +1072,7 @@ static Tensor sparse_compressed_tensor_ctor_worker(
|
|||||||
values.options().layout(layout).pinned_memory(pin_memory))
|
values.options().layout(layout).pinned_memory(pin_memory))
|
||||||
.set_requires_grad(r.toBool(ARG_REQUIRES_GRAD1));
|
.set_requires_grad(r.toBool(ARG_REQUIRES_GRAD1));
|
||||||
}
|
}
|
||||||
throw std::runtime_error(name + ": invalid arguments");
|
TORCH_CHECK(false, name + ": invalid arguments");
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor sparse_compressed_tensor_ctor(
|
Tensor sparse_compressed_tensor_ctor(
|
||||||
@ -1274,7 +1274,7 @@ Tensor sparse_coo_tensor_ctor(
|
|||||||
inferred_options.dtype(inferred_scalar_type).layout(at::kSparse))
|
inferred_options.dtype(inferred_scalar_type).layout(at::kSparse))
|
||||||
.set_requires_grad(r.toBool(ARG_REQUIRES_GRAD2));
|
.set_requires_grad(r.toBool(ARG_REQUIRES_GRAD2));
|
||||||
}
|
}
|
||||||
throw std::runtime_error("sparse_coo_tensor(): invalid arguments");
|
TORCH_CHECK(false, "sparse_coo_tensor(): invalid arguments");
|
||||||
}
|
}
|
||||||
|
|
||||||
void _validate_sparse_coo_tensor_args(
|
void _validate_sparse_coo_tensor_args(
|
||||||
@ -1504,7 +1504,7 @@ Tensor tensor_ctor(
|
|||||||
new_tensor.set_requires_grad(args_requires_grad);
|
new_tensor.set_requires_grad(args_requires_grad);
|
||||||
return new_tensor;
|
return new_tensor;
|
||||||
}
|
}
|
||||||
throw std::runtime_error("tensor(): invalid arguments");
|
TORCH_CHECK(false, "tensor(): invalid arguments");
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor as_tensor(
|
Tensor as_tensor(
|
||||||
@ -1523,7 +1523,7 @@ Tensor as_tensor(
|
|||||||
/*copy_numpy=*/false,
|
/*copy_numpy=*/false,
|
||||||
/*type_inference=*/type_inference);
|
/*type_inference=*/type_inference);
|
||||||
}
|
}
|
||||||
throw std::runtime_error("tensor(): invalid arguments");
|
TORCH_CHECK(false, "tensor(): invalid arguments");
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor new_tensor(
|
Tensor new_tensor(
|
||||||
@ -1561,7 +1561,7 @@ Tensor new_tensor(
|
|||||||
new_tensor.set_requires_grad(args_requires_grad);
|
new_tensor.set_requires_grad(args_requires_grad);
|
||||||
return new_tensor;
|
return new_tensor;
|
||||||
}
|
}
|
||||||
throw std::runtime_error("new_tensor(): invalid arguments");
|
TORCH_CHECK(false, "new_tensor(): invalid arguments");
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor tensor_frombuffer(
|
Tensor tensor_frombuffer(
|
||||||
|
@ -9,32 +9,32 @@
|
|||||||
|
|
||||||
namespace torch::utils {
|
namespace torch::utils {
|
||||||
PyObject* tensor_to_numpy(const at::Tensor&, bool) {
|
PyObject* tensor_to_numpy(const at::Tensor&, bool) {
|
||||||
throw std::runtime_error("PyTorch was compiled without NumPy support");
|
TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
|
||||||
}
|
}
|
||||||
at::Tensor tensor_from_numpy(
|
at::Tensor tensor_from_numpy(
|
||||||
PyObject* obj,
|
PyObject* obj,
|
||||||
bool warn_if_not_writeable /*=true*/) {
|
bool warn_if_not_writeable /*=true*/) {
|
||||||
throw std::runtime_error("PyTorch was compiled without NumPy support");
|
TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_numpy_available() {
|
bool is_numpy_available() {
|
||||||
throw std::runtime_error("PyTorch was compiled without NumPy support");
|
TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_numpy_int(PyObject* obj) {
|
bool is_numpy_int(PyObject* obj) {
|
||||||
throw std::runtime_error("PyTorch was compiled without NumPy support");
|
TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
|
||||||
}
|
}
|
||||||
bool is_numpy_scalar(PyObject* obj) {
|
bool is_numpy_scalar(PyObject* obj) {
|
||||||
throw std::runtime_error("PyTorch was compiled without NumPy support");
|
TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
|
||||||
}
|
}
|
||||||
at::Tensor tensor_from_cuda_array_interface(
|
at::Tensor tensor_from_cuda_array_interface(
|
||||||
PyObject* obj,
|
PyObject* obj,
|
||||||
std::optional<c10::Device> device_opt) {
|
std::optional<c10::Device> device_opt) {
|
||||||
throw std::runtime_error("PyTorch was compiled without NumPy support");
|
TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
|
||||||
}
|
}
|
||||||
|
|
||||||
void warn_numpy_not_writeable() {
|
void warn_numpy_not_writeable() {
|
||||||
throw std::runtime_error("PyTorch was compiled without NumPy support");
|
TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
|
||||||
}
|
}
|
||||||
|
|
||||||
// No-op stubs.
|
// No-op stubs.
|
||||||
@ -215,9 +215,7 @@ void warn_numpy_not_writeable() {
|
|||||||
at::Tensor tensor_from_numpy(
|
at::Tensor tensor_from_numpy(
|
||||||
PyObject* obj,
|
PyObject* obj,
|
||||||
bool warn_if_not_writeable /*=true*/) {
|
bool warn_if_not_writeable /*=true*/) {
|
||||||
if (!is_numpy_available()) {
|
TORCH_CHECK(is_numpy_available(), "Numpy is not available");
|
||||||
throw std::runtime_error("Numpy is not available");
|
|
||||||
}
|
|
||||||
TORCH_CHECK_TYPE(
|
TORCH_CHECK_TYPE(
|
||||||
PyArray_Check(obj),
|
PyArray_Check(obj),
|
||||||
"expected np.ndarray (got ",
|
"expected np.ndarray (got ",
|
||||||
@ -385,9 +383,7 @@ bool is_numpy_scalar(PyObject* obj) {
|
|||||||
at::Tensor tensor_from_cuda_array_interface(
|
at::Tensor tensor_from_cuda_array_interface(
|
||||||
PyObject* obj,
|
PyObject* obj,
|
||||||
std::optional<c10::Device> device_opt) {
|
std::optional<c10::Device> device_opt) {
|
||||||
if (!is_numpy_available()) {
|
TORCH_CHECK(is_numpy_available(), "Numpy is not available");
|
||||||
throw std::runtime_error("Numpy is not available");
|
|
||||||
}
|
|
||||||
auto cuda_dict =
|
auto cuda_dict =
|
||||||
THPObjectPtr(PyObject_GetAttrString(obj, "__cuda_array_interface__"));
|
THPObjectPtr(PyObject_GetAttrString(obj, "__cuda_array_interface__"));
|
||||||
TORCH_INTERNAL_ASSERT(cuda_dict);
|
TORCH_INTERNAL_ASSERT(cuda_dict);
|
||||||
|
Reference in New Issue
Block a user