Use Unicode friendly API on Win32 in THAllocator (#47905)

Summary:
This replaces the narrow-character Win32 APIs with their wide-character counterparts in `THAllocator.cpp`. This fixes potential crashes caused by passing file names that contain non-ASCII characters to `torch::from_file` on Windows.

See: https://github.com/pytorch/pytorch/issues/47422

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47905

Reviewed By: zhangguanheng66

Differential Revision: D25399146

Pulled By: ezyang

fbshipit-source-id: 0a183b65de171c48ed1718fa71e773224eaf196f
This commit is contained in:
Chester Liu
2020-12-14 14:21:32 -08:00
committed by Facebook GitHub Bot
parent 1e2d1d7242
commit 3a943e9f82
5 changed files with 106 additions and 30 deletions

View File

@ -11,6 +11,8 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh"
echo "Testing pytorch"
export LANG=C.UTF-8
if [[ "$BUILD_ENVIRONMENT" == *-slow-* ]]; then
export PYTORCH_TEST_WITH_SLOW=1
export PYTORCH_TEST_SKIP_FAST=1

View File

@ -6,6 +6,7 @@
#endif
#include <c10/core/CPUAllocator.h>
#include <c10/util/Unicode.h>
/* stuff for mapped files */
#ifdef _WIN32
@ -74,24 +75,26 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags,
#ifdef _WIN32
if (flags_ & TH_ALLOCATOR_MAPPED_SHAREDMEM) {
// Shadowing
const char *filename;
const char *eventname;
const wchar_t *filename;
const wchar_t *eventname;
const std::wstring wFilename = c10::u8u16(filename_);
const std::wstring wEventname = c10::u8u16(eventname_);
LARGE_INTEGER hfilesz;
if (filename_[0] == '/') {
filename = filename_.c_str() + 1;
eventname = eventname_.c_str() + 1;
filename = wFilename.c_str() + 1;
eventname = wEventname.c_str() + 1;
} else {
filename = filename_.c_str();
eventname = eventname_.c_str();
filename = wFilename.c_str();
eventname = wEventname.c_str();
}
hfilesz.QuadPart = size;
if (flags_ & TH_ALLOCATOR_MAPPED_EXCLUSIVE) {
event_ = CreateEvent(nullptr, FALSE, FALSE, eventname);
event_ = CreateEventW(nullptr, FALSE, FALSE, eventname);
} else if (flags_ & TH_ALLOCATOR_MAPPED_NOCREATE) {
event_ = OpenEvent(EVENT_ALL_ACCESS, FALSE, eventname);
event_ = OpenEventW(EVENT_ALL_ACCESS, FALSE, eventname);
} else {
AT_ERROR("Expected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE");
}
@ -101,9 +104,9 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags,
}
if (flags_ & TH_ALLOCATOR_MAPPED_EXCLUSIVE) {
handle_ = CreateFileMapping(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, filename);
handle_ = CreateFileMappingW(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, filename);
} else if (flags_ & TH_ALLOCATOR_MAPPED_NOCREATE) {
handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, filename);
handle_ = OpenFileMappingW(FILE_MAP_ALL_ACCESS, FALSE, filename);
} else {
AT_ERROR("Expected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE");
}
@ -136,15 +139,21 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags,
AT_ERROR("TH_ALLOCATOR_MAPPED_FROMFD not supported on Windows");
}
// Shadowing
const wchar_t *filename;
const std::wstring wFilename = c10::u8u16(filename_);
filename = wFilename.c_str();
/* open file */
/* FILE_FLAG_RANDOM_ACCESS ? */
if (flags_) {
hfile = CreateFileA(filename_.c_str(), GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0);
hfile = CreateFileW(filename, GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0);
if (hfile == INVALID_HANDLE_VALUE) {
AT_ERROR("could not open file <", filename_, "> in read-write mode; error code: <", GetLastError(), ">");
}
} else {
hfile = CreateFileA(filename_.c_str(), GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0);
hfile = CreateFileW(filename, GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0);
if (hfile == INVALID_HANDLE_VALUE) {
AT_ERROR("could not open file <", filename_, "> in read-only mode; error code: <", GetLastError(), ">");
}
@ -181,11 +190,11 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags,
/* get map handle */
if (flags_) {
if ( (hmfile = CreateFileMapping(hfile, NULL, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) {
if ( (hmfile = CreateFileMappingW(hfile, NULL, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) {
AT_ERROR("could not create a map on file <", filename_, ">; error code: <", GetLastError(), ">");
}
} else {
if ( (hmfile = CreateFileMapping(hfile, NULL, PAGE_WRITECOPY, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) {
if ( (hmfile = CreateFileMappingW(hfile, NULL, PAGE_WRITECOPY, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) {
AT_ERROR("could not create a map on file <", filename_, ">; error code: <", GetLastError(), ">");
}
}

29
c10/util/Unicode.h Normal file
View File

@ -0,0 +1,29 @@
#pragma once
#if defined(_WIN32)
#include <string>
#include <c10/util/win32-headers.h>
#include <c10/util/Exception.h>
#endif
namespace c10 {
#if defined(_WIN32)
// Converts a UTF-8 encoded std::string into a std::wstring suitable for the
// wide-character ("W") Win32 APIs, using MultiByteToWideChar with CP_UTF8.
//
// str: UTF-8 encoded input. An empty input yields an empty result.
// Returns the converted wide string (no extra terminator is stored in it,
// because the explicit input length excludes the trailing NUL).
// A TORCH_CHECK fails if the input is too long for the Win32 API or if either
// conversion pass reports an error.
inline std::wstring u8u16(const std::string& str) {
  if (str.empty()) {
    return std::wstring();
  }
  // MultiByteToWideChar takes the byte count as an int; make sure the
  // narrowing conversion round-trips instead of silently truncating.
  const int input_len = static_cast<int>(str.size());
  TORCH_CHECK(
      input_len > 0 &&
          static_cast<std::string::size_type>(input_len) == str.size(),
      "Error converting the content to Unicode");
  // First pass: query how many wide characters the result needs.
  const int size_needed =
      MultiByteToWideChar(CP_UTF8, 0, str.c_str(), input_len, nullptr, 0);
  TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode");
  std::wstring wstr(size_needed, 0);
  // Second pass: perform the conversion into the pre-sized buffer. The call
  // returns the number of wide characters written, or 0 on failure.
  const int written = MultiByteToWideChar(
      CP_UTF8, 0, str.c_str(), input_len, &wstr[0], size_needed);
  TORCH_CHECK(
      written == size_needed, "Error converting the content to Unicode");
  return wstr;
}
#endif
}

View File

@ -19,10 +19,10 @@ from itertools import product, combinations, permutations
from torch import multiprocessing as mp
from torch.testing._internal.common_utils import (
TestCase, TEST_WITH_ROCM, run_tests,
IS_WINDOWS, NO_MULTIPROCESSING_SPAWN,
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest,
skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext,
skipIfRocm, skipIfNoSciPy,
skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
wrapDeterministicFlagAPITest, DeterministicGuard)
from multiprocessing.reduction import ForkingPickler
from torch.testing._internal.common_device_type import (
@ -1852,15 +1852,14 @@ class AbstractTestCases:
self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
self.assertIs(complexdouble_storage.dtype, torch.complex128)
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
def test_from_file(self):
size = 10000
with tempfile.NamedTemporaryFile() as f:
s1 = torch.FloatStorage.from_file(f.name, True, size)
def assert_with_filename(filename):
size = 10000
s1 = torch.FloatStorage.from_file(filename, True, size)
t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
# check mapping
s2 = torch.FloatStorage.from_file(f.name, True, size)
s2 = torch.FloatStorage.from_file(filename, True, size)
t2 = torch.FloatTensor(s2)
self.assertEqual(t1, t2, atol=0, rtol=0)
@ -1874,15 +1873,24 @@ class AbstractTestCases:
t2.fill_(rnum)
self.assertEqual(t1, t2, atol=0, rtol=0)
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
# release the tensors
del s1, t1, s2, t2
with TemporaryFileName() as fname:
assert_with_filename(fname)
if IS_FILESYSTEM_UTF8_ENCODING:
with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname:
assert_with_filename(fname)
def test_torch_from_file(self):
size = 10000
with tempfile.NamedTemporaryFile() as f:
s1 = torch.from_file(f.name, True, size, dtype=torch.float)
def assert_with_filename(filename):
size = 10000
s1 = torch.from_file(filename, True, size, dtype=torch.float)
t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
# check mapping
s2 = torch.from_file(f.name, True, size, dtype=torch.float)
s2 = torch.from_file(filename, True, size, dtype=torch.float)
t2 = torch.FloatTensor(s2)
self.assertEqual(t1, t2, atol=0, rtol=0)
@ -1896,6 +1904,16 @@ class AbstractTestCases:
t2.fill_(rnum)
self.assertEqual(t1, t2, atol=0, rtol=0)
# release the tensors
del s1, t1, s2, t2
with TemporaryFileName() as fname:
assert_with_filename(fname)
if IS_FILESYSTEM_UTF8_ENCODING:
with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname:
assert_with_filename(fname)
def test_print(self):
default_type = torch.Tensor().type()
for t in torch._tensor_classes:

View File

@ -21,6 +21,7 @@ import unittest
import warnings
import random
import contextlib
import shutil
import socket
import subprocess
import time
@ -300,11 +301,11 @@ IS_PPC = platform.machine() == "ppc64le"
if IS_WINDOWS:
@contextmanager
def TemporaryFileName():
def TemporaryFileName(dir=None):
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
# close the file after creation and try to remove it manually
f = tempfile.NamedTemporaryFile(delete=False)
f = tempfile.NamedTemporaryFile(delete=False, dir=dir)
try:
f.close()
yield f.name
@ -312,10 +313,27 @@ if IS_WINDOWS:
os.unlink(f.name)
else:
@contextmanager # noqa: T484
def TemporaryFileName():
with tempfile.NamedTemporaryFile() as f:
def TemporaryFileName(dir=None):
with tempfile.NamedTemporaryFile(dir=dir) as f:
yield f.name
if IS_WINDOWS:
    @contextmanager
    def TemporaryDirectoryName(suffix=None):
        """Context manager yielding the path of a fresh temporary directory.

        Args:
            suffix: optional suffix forwarded to ``tempfile.mkdtemp``.

        Yields:
            The path of the created directory. The directory tree is removed
            with ``shutil.rmtree`` when the context exits.
        """
        # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely,
        # so we first create the directory using mkdtemp and then remove it manually
        #
        # Create the directory *before* entering the try block: if mkdtemp itself
        # fails there is nothing to clean up, and the original exception must not
        # be masked by an unbound `dir_name` in the finally clause.
        dir_name = tempfile.mkdtemp(suffix=suffix)
        try:
            yield dir_name
        finally:
            shutil.rmtree(dir_name)
else:
    @contextmanager  # noqa: T484
    def TemporaryDirectoryName(suffix=None):
        """Context manager yielding the path of a fresh temporary directory.

        Args:
            suffix: optional suffix forwarded to ``tempfile.TemporaryDirectory``.

        Yields:
            The path of the created directory; cleanup is delegated to
            ``tempfile.TemporaryDirectory``.
        """
        with tempfile.TemporaryDirectory(suffix=suffix) as d:
            yield d
IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8'
def _check_module_exists(name):
r"""Returns if a top-level module with :attr:`name` exists *without**