Add better device idx parse checks (#37376)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/32079
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37376

Differential Revision: D21476036

Pulled By: zou3519

fbshipit-source-id: 86907083c23cbaf165b645307fb340f2656b814e
This commit is contained in:
SsnL
2020-05-14 09:03:52 -07:00
committed by Facebook GitHub Bot
parent 0a159b0a3a
commit ae392a77a6
4 changed files with 88 additions and 44 deletions

View File

@ -9,6 +9,23 @@
#include <string>
#include <tuple>
#include <vector>
#include <regex>
// Check if compiler has working std::regex implementation
//
// Test below is adapted from https://stackoverflow.com/a/41186162
#if defined(_MSVC_LANG) && _MSVC_LANG >= 201103L
// Compiler has working regex. MSVC has erroneous __cplusplus.
#elif __cplusplus >= 201103L && \
(!defined(__GLIBCXX__) || (__cplusplus >= 201402L) || \
(defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \
defined(_GLIBCXX_REGEX_STATE_LIMIT) || \
(defined(_GLIBCXX_RELEASE) && \
_GLIBCXX_RELEASE > 4)))
// Compiler has working regex.
#else
static_assert(false, "Compiler does not have proper regex support.");
#endif
namespace c10 {
namespace {
@ -38,49 +55,24 @@ DeviceType parse_type(const std::string& device_string) {
}
} // namespace
// `std::regex` is still in a very incomplete state in GCC 4.8.x,
// so we have to do our own parsing, like peasants.
// https://stackoverflow.com/questions/12530406/is-gcc-4-8-or-earlier-buggy-about-regular-expressions
//
// Replace with the following code once we shed our GCC skin:
//
// static const std::regex regex(
// "(cuda|cpu)|(cuda|cpu):([0-9]+)|([0-9]+)",
// std::regex_constants::basic);
// std::smatch match;
// const bool ok = std::regex_match(device_string, match, regex);
// TORCH_CHECK(ok, "Invalid device string: '", device_string, "'");
// if (match[1].matched) {
// type_ = parse_type_from_string(match[1].str());
// } else {
// if (match[2].matched) {
// type_ = parse_type_from_string(match[1].str());
// } else {
// type_ = Type::CUDA;
// }
// AT_ASSERT(match[3].matched);
// index_ = std::stoi(match[3].str());
// }
Device::Device(const std::string& device_string) : Device(Type::CPU) {
TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
auto index = device_string.find(':');
if (index == std::string::npos) {
type_ = parse_type(device_string);
} else {
std::string s;
s = device_string.substr(0, index);
TORCH_CHECK(!s.empty(), "Device string must not be empty");
type_ = parse_type(s);
std::string device_index = device_string.substr(index + 1);
// We assume gcc 5+, so we can use proper regex.
static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
std::smatch match;
TORCH_CHECK(
std::regex_match(device_string, match, regex),
"Invalid device string: '", device_string, "'");
type_ = parse_type(match[1].str());
if (match[2].matched) {
try {
index_ = c10::stoi(device_index);
index_ = c10::stoi(match[2].str());
} catch (const std::exception &) {
AT_ERROR("Could not parse device index '", device_index,
"' in device string '", device_string, "'");
AT_ERROR(
"Could not parse device index '", match[2].str(),
"' in device string '", device_string, "'");
}
TORCH_CHECK(index_ >= 0,
"Device index must be non-negative, got ", index_);
}
validate();
}

View File

@ -2,6 +2,7 @@
#include <sstream>
#include <string>
#include <stdexcept>
namespace c10 {
@ -22,6 +23,12 @@ inline int stoi(const std::string& str, std::size_t* pos = 0) {
int n = 0;
ss << str;
ss >> n;
if (ss.fail()) {
// To mimic `std::stoi` and to avoid including `Exception.h`, throw
// `std::invalid_argument`.
// We can't easily detect out-of-range, so we don't use `std::out_of_range`.
throw std::invalid_argument("Not an integer");
}
if (pos) {
if (ss.tellg() == std::streampos(-1)) {
*pos = str.size();
@ -37,6 +44,12 @@ inline uint64_t stoull(const std::string& str) {
uint64_t n = 0;
ss << str;
ss >> n;
if (ss.fail()) {
// To mimic `std::stoull` and to avoid including `Exception.h`, throw
// `std::invalid_argument`.
// We can't easily detect out-of-range, so we don't use `std::out_of_range`.
throw std::invalid_argument("Not an unsigned 64-bit integer");
}
return n;
}
@ -45,6 +58,12 @@ inline double stod(const std::string& str, std::size_t* pos = 0) {
ss << str;
double val = 0;
ss >> val;
if (ss.fail()) {
// To mimic `std::stod` and to avoid including `Exception.h`, throw
// `std::invalid_argument`.
// We can't easily detect out-of-range, so we don't use `std::out_of_range`.
throw std::invalid_argument("Not a double-precision floating point number");
}
if (pos) {
if (ss.tellg() == std::streampos(-1)) {
*pos = str.size();
@ -62,6 +81,12 @@ inline long long stoll(const std::string& str, std::size_t* pos = 0) {
ss << str;
long long result = 0;
ss >> result;
if (ss.fail()) {
// To mimic `std::stoll` and to avoid including `Exception.h`, throw
// `std::invalid_argument`.
// We can't easily detect out-of-range, so we don't use `std::out_of_range`.
throw std::invalid_argument("Not a long long integer");
}
if (pos) {
if (ss.tellg() == std::streampos(-1)) {
*pos = str.size();
@ -75,18 +100,24 @@ inline long long stoll(const std::string& str, std::size_t* pos = 0) {
inline long long stoll(const std::string& str, size_t pos, int base) {
// std::stoll doesn't exist in our Android environment, we need to implement
// it ourselves.
std::stringstream s;
std::stringstream ss;
if (str.size() > 0 && str.at(0) == '0') {
if (str.size() > 1 && (str.at(1) == 'x' || str.at(1) == 'X')) {
s << std::hex << str;
ss << std::hex << str;
} else {
s << std::oct << str;
ss << std::oct << str;
}
} else {
s << str;
ss << str;
}
long long result = 0;
s >> result;
ss >> result;
if (ss.fail()) {
// To mimic `std::stoll` and to avoid including `Exception.h`, throw
// `std::invalid_argument`.
// We can't easily detect out-of-range, so we don't use `std::out_of_range`.
throw std::invalid_argument("Not a long long integer");
}
return result;
}

View File

@ -853,7 +853,7 @@ class TestCuda(TestCase):
import os
fname = "tempfile.pt"
try:
with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
torch.save([torch.nn.Parameter(torch.randn(10, 10))], fname,
_use_new_zipfile_serialization=True)
torch.load(fname, 'cuda0')
@ -863,7 +863,7 @@ class TestCuda(TestCase):
def test_get_device_index(self):
from torch.cuda._utils import _get_device_index
with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
_get_device_index('cuda0', optional=True)
with self.assertRaisesRegex(ValueError, "Expected a cuda device"):

View File

@ -790,11 +790,32 @@ class _TestTorchMixin(object):
self.assertEqual('cuda', cuda1.type)
self.assertEqual(1, cuda1.index)
cuda90 = torch.device('cuda', 90)
self.assertEqual('cuda:90', str(cuda90))
self.assertEqual('cuda', cuda90.type)
self.assertEqual(90, cuda90.index)
cuda23333 = torch.device('cuda', 23333)
self.assertEqual('cuda:23333', str(cuda23333))
self.assertEqual('cuda', cuda23333.type)
self.assertEqual(23333, cuda23333.index)
self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1'))
self.assertRaises(RuntimeError, lambda: torch.device('cpu:1'))
self.assertRaises(RuntimeError, lambda: torch.device('cpu', -1))
self.assertRaises(RuntimeError, lambda: torch.device('cpu', 1))
self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1'))
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 '))
self.assertRaises(RuntimeError, lambda: torch.device('cuda: 2'))
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 2'))
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.'))
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2?'))
self.assertRaises(RuntimeError, lambda: torch.device('cuda:?2'))
self.assertRaises(RuntimeError, lambda: torch.device('cuda:'))
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.232'))
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 cuda:3'))
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2+cuda:3'))
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2cuda:3'))
self.assertRaises(RuntimeError, lambda: torch.device('cuda', -1))
self.assertRaises(RuntimeError, lambda: torch.device(-1))