mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add better device idx parse checks (#37376)
Summary: Fixes https://github.com/pytorch/pytorch/issues/32079 Pull Request resolved: https://github.com/pytorch/pytorch/pull/37376 Differential Revision: D21476036 Pulled By: zou3519 fbshipit-source-id: 86907083c23cbaf165b645307fb340f2656b814e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0a159b0a3a
commit
ae392a77a6
@ -9,6 +9,23 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
// Check if compiler has working std::regex implementation
|
||||
//
|
||||
// Test below is adapted from https://stackoverflow.com/a/41186162
|
||||
#if defined(_MSVC_LANG) && _MSVC_LANG >= 201103L
|
||||
// Compiler has working regex. MSVC has erroneous __cplusplus.
|
||||
#elif __cplusplus >= 201103L && \
|
||||
(!defined(__GLIBCXX__) || (__cplusplus >= 201402L) || \
|
||||
(defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \
|
||||
defined(_GLIBCXX_REGEX_STATE_LIMIT) || \
|
||||
(defined(_GLIBCXX_RELEASE) && \
|
||||
_GLIBCXX_RELEASE > 4)))
|
||||
// Compiler has working regex.
|
||||
#else
|
||||
static_assert(false, "Compiler does not have proper regex support.");
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
namespace {
|
||||
@ -38,49 +55,24 @@ DeviceType parse_type(const std::string& device_string) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// `std::regex` is still in a very incomplete state in GCC 4.8.x,
|
||||
// so we have to do our own parsing, like peasants.
|
||||
// https://stackoverflow.com/questions/12530406/is-gcc-4-8-or-earlier-buggy-about-regular-expressions
|
||||
//
|
||||
// Replace with the following code once we shed our GCC skin:
|
||||
//
|
||||
// static const std::regex regex(
|
||||
// "(cuda|cpu)|(cuda|cpu):([0-9]+)|([0-9]+)",
|
||||
// std::regex_constants::basic);
|
||||
// std::smatch match;
|
||||
// const bool ok = std::regex_match(device_string, match, regex);
|
||||
// TORCH_CHECK(ok, "Invalid device string: '", device_string, "'");
|
||||
// if (match[1].matched) {
|
||||
// type_ = parse_type_from_string(match[1].str());
|
||||
// } else {
|
||||
// if (match[2].matched) {
|
||||
// type_ = parse_type_from_string(match[1].str());
|
||||
// } else {
|
||||
// type_ = Type::CUDA;
|
||||
// }
|
||||
// AT_ASSERT(match[3].matched);
|
||||
// index_ = std::stoi(match[3].str());
|
||||
// }
|
||||
Device::Device(const std::string& device_string) : Device(Type::CPU) {
|
||||
TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
|
||||
auto index = device_string.find(':');
|
||||
if (index == std::string::npos) {
|
||||
type_ = parse_type(device_string);
|
||||
} else {
|
||||
std::string s;
|
||||
s = device_string.substr(0, index);
|
||||
TORCH_CHECK(!s.empty(), "Device string must not be empty");
|
||||
type_ = parse_type(s);
|
||||
|
||||
std::string device_index = device_string.substr(index + 1);
|
||||
// We assume gcc 5+, so we can use proper regex.
|
||||
static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
|
||||
std::smatch match;
|
||||
TORCH_CHECK(
|
||||
std::regex_match(device_string, match, regex),
|
||||
"Invalid device string: '", device_string, "'");
|
||||
type_ = parse_type(match[1].str());
|
||||
if (match[2].matched) {
|
||||
try {
|
||||
index_ = c10::stoi(device_index);
|
||||
index_ = c10::stoi(match[2].str());
|
||||
} catch (const std::exception &) {
|
||||
AT_ERROR("Could not parse device index '", device_index,
|
||||
"' in device string '", device_string, "'");
|
||||
AT_ERROR(
|
||||
"Could not parse device index '", match[2].str(),
|
||||
"' in device string '", device_string, "'");
|
||||
}
|
||||
TORCH_CHECK(index_ >= 0,
|
||||
"Device index must be non-negative, got ", index_);
|
||||
}
|
||||
validate();
|
||||
}
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
@ -22,6 +23,12 @@ inline int stoi(const std::string& str, std::size_t* pos = 0) {
|
||||
int n = 0;
|
||||
ss << str;
|
||||
ss >> n;
|
||||
if (ss.fail()) {
|
||||
// To mimic `std::stoi` and to avoid including `Exception.h`, throw
|
||||
// `std::invalid_argument`.
|
||||
// We can't easily detect out-of-range, so we don't use `std::out_of_range`.
|
||||
throw std::invalid_argument("Not an integer");
|
||||
}
|
||||
if (pos) {
|
||||
if (ss.tellg() == std::streampos(-1)) {
|
||||
*pos = str.size();
|
||||
@ -37,6 +44,12 @@ inline uint64_t stoull(const std::string& str) {
|
||||
uint64_t n = 0;
|
||||
ss << str;
|
||||
ss >> n;
|
||||
if (ss.fail()) {
|
||||
// To mimic `std::stoull` and to avoid including `Exception.h`, throw
|
||||
// `std::invalid_argument`.
|
||||
// We can't easily detect out-of-range, so we don't use `std::out_of_range`.
|
||||
throw std::invalid_argument("Not an unsigned 64-bit integer");
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
@ -45,6 +58,12 @@ inline double stod(const std::string& str, std::size_t* pos = 0) {
|
||||
ss << str;
|
||||
double val = 0;
|
||||
ss >> val;
|
||||
if (ss.fail()) {
|
||||
// To mimic `std::stod` and to avoid including `Exception.h`, throw
|
||||
// `std::invalid_argument`.
|
||||
// We can't easily detect out-of-range, so we don't use `std::out_of_range`.
|
||||
throw std::invalid_argument("Not a double-precision floating point number");
|
||||
}
|
||||
if (pos) {
|
||||
if (ss.tellg() == std::streampos(-1)) {
|
||||
*pos = str.size();
|
||||
@ -62,6 +81,12 @@ inline long long stoll(const std::string& str, std::size_t* pos = 0) {
|
||||
ss << str;
|
||||
long long result = 0;
|
||||
ss >> result;
|
||||
if (ss.fail()) {
|
||||
// To mimic `std::stoll` and to avoid including `Exception.h`, throw
|
||||
// `std::invalid_argument`.
|
||||
// We can't easily detect out-of-range, so we don't use `std::out_of_range`.
|
||||
throw std::invalid_argument("Not a long long integer");
|
||||
}
|
||||
if (pos) {
|
||||
if (ss.tellg() == std::streampos(-1)) {
|
||||
*pos = str.size();
|
||||
@ -75,18 +100,24 @@ inline long long stoll(const std::string& str, std::size_t* pos = 0) {
|
||||
inline long long stoll(const std::string& str, size_t pos, int base) {
|
||||
// std::stoll doesn't exist in our Android environment, we need to implement
|
||||
// it ourselves.
|
||||
std::stringstream s;
|
||||
std::stringstream ss;
|
||||
if (str.size() > 0 && str.at(0) == '0') {
|
||||
if (str.size() > 1 && (str.at(1) == 'x' || str.at(1) == 'X')) {
|
||||
s << std::hex << str;
|
||||
ss << std::hex << str;
|
||||
} else {
|
||||
s << std::oct << str;
|
||||
ss << std::oct << str;
|
||||
}
|
||||
} else {
|
||||
s << str;
|
||||
ss << str;
|
||||
}
|
||||
long long result = 0;
|
||||
s >> result;
|
||||
ss >> result;
|
||||
if (ss.fail()) {
|
||||
// To mimic `std::stoll` and to avoid including `Exception.h`, throw
|
||||
// `std::invalid_argument`.
|
||||
// We can't easily detect out-of-range, so we don't use `std::out_of_range`.
|
||||
throw std::invalid_argument("Not a long long integer");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -853,7 +853,7 @@ class TestCuda(TestCase):
|
||||
import os
|
||||
fname = "tempfile.pt"
|
||||
try:
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
|
||||
torch.save([torch.nn.Parameter(torch.randn(10, 10))], fname,
|
||||
_use_new_zipfile_serialization=True)
|
||||
torch.load(fname, 'cuda0')
|
||||
@ -863,7 +863,7 @@ class TestCuda(TestCase):
|
||||
|
||||
def test_get_device_index(self):
|
||||
from torch.cuda._utils import _get_device_index
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
|
||||
_get_device_index('cuda0', optional=True)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Expected a cuda device"):
|
||||
|
@ -790,11 +790,32 @@ class _TestTorchMixin(object):
|
||||
self.assertEqual('cuda', cuda1.type)
|
||||
self.assertEqual(1, cuda1.index)
|
||||
|
||||
cuda90 = torch.device('cuda', 90)
|
||||
self.assertEqual('cuda:90', str(cuda90))
|
||||
self.assertEqual('cuda', cuda90.type)
|
||||
self.assertEqual(90, cuda90.index)
|
||||
|
||||
cuda23333 = torch.device('cuda', 23333)
|
||||
self.assertEqual('cuda:23333', str(cuda23333))
|
||||
self.assertEqual('cuda', cuda23333.type)
|
||||
self.assertEqual(23333, cuda23333.index)
|
||||
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cpu:1'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cpu', -1))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cpu', 1))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 '))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda: 2'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 2'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2?'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda:?2'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda:'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.232'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 cuda:3'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2+cuda:3'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2cuda:3'))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device('cuda', -1))
|
||||
self.assertRaises(RuntimeError, lambda: torch.device(-1))
|
||||
|
||||
|
Reference in New Issue
Block a user