mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Use Unicode friendly API on Win32 in THAllocator (#47905)
Summary: This replaces the narrow character set APIs with the wide character set ones in `THAllocator.cpp`. This fixes the potential crashes caused by passing non-ASCII characters in `torch::from_file` on Windows. See: https://github.com/pytorch/pytorch/issues/47422 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47905 Reviewed By: zhangguanheng66 Differential Revision: D25399146 Pulled By: ezyang fbshipit-source-id: 0a183b65de171c48ed1718fa71e773224eaf196f
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1e2d1d7242
commit
3a943e9f82
@ -11,6 +11,8 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh"
|
||||
|
||||
echo "Testing pytorch"
|
||||
|
||||
export LANG=C.UTF-8
|
||||
|
||||
if [[ "$BUILD_ENVIRONMENT" == *-slow-* ]]; then
|
||||
export PYTORCH_TEST_WITH_SLOW=1
|
||||
export PYTORCH_TEST_SKIP_FAST=1
|
||||
|
@ -6,6 +6,7 @@
|
||||
#endif
|
||||
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#include <c10/util/Unicode.h>
|
||||
|
||||
/* stuff for mapped files */
|
||||
#ifdef _WIN32
|
||||
@ -74,24 +75,26 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags,
|
||||
#ifdef _WIN32
|
||||
if (flags_ & TH_ALLOCATOR_MAPPED_SHAREDMEM) {
|
||||
// Shadowing
|
||||
const char *filename;
|
||||
const char *eventname;
|
||||
const wchar_t *filename;
|
||||
const wchar_t *eventname;
|
||||
const std::wstring wFilename = c10::u8u16(filename_);
|
||||
const std::wstring wEventname = c10::u8u16(eventname_);
|
||||
LARGE_INTEGER hfilesz;
|
||||
|
||||
if (filename_[0] == '/') {
|
||||
filename = filename_.c_str() + 1;
|
||||
eventname = eventname_.c_str() + 1;
|
||||
filename = wFilename.c_str() + 1;
|
||||
eventname = wEventname.c_str() + 1;
|
||||
} else {
|
||||
filename = filename_.c_str();
|
||||
eventname = eventname_.c_str();
|
||||
filename = wFilename.c_str();
|
||||
eventname = wEventname.c_str();
|
||||
}
|
||||
|
||||
hfilesz.QuadPart = size;
|
||||
|
||||
if (flags_ & TH_ALLOCATOR_MAPPED_EXCLUSIVE) {
|
||||
event_ = CreateEvent(nullptr, FALSE, FALSE, eventname);
|
||||
event_ = CreateEventW(nullptr, FALSE, FALSE, eventname);
|
||||
} else if (flags_ & TH_ALLOCATOR_MAPPED_NOCREATE) {
|
||||
event_ = OpenEvent(EVENT_ALL_ACCESS, FALSE, eventname);
|
||||
event_ = OpenEventW(EVENT_ALL_ACCESS, FALSE, eventname);
|
||||
} else {
|
||||
AT_ERROR("Expected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE");
|
||||
}
|
||||
@ -101,9 +104,9 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags,
|
||||
}
|
||||
|
||||
if (flags_ & TH_ALLOCATOR_MAPPED_EXCLUSIVE) {
|
||||
handle_ = CreateFileMapping(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, filename);
|
||||
handle_ = CreateFileMappingW(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, filename);
|
||||
} else if (flags_ & TH_ALLOCATOR_MAPPED_NOCREATE) {
|
||||
handle_ = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, filename);
|
||||
handle_ = OpenFileMappingW(FILE_MAP_ALL_ACCESS, FALSE, filename);
|
||||
} else {
|
||||
AT_ERROR("Expected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE");
|
||||
}
|
||||
@ -136,15 +139,21 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags,
|
||||
AT_ERROR("TH_ALLOCATOR_MAPPED_FROMFD not supported on Windows");
|
||||
}
|
||||
|
||||
// Shadowing
|
||||
const wchar_t *filename;
|
||||
const std::wstring wFilename = c10::u8u16(filename_);
|
||||
|
||||
filename = wFilename.c_str();
|
||||
|
||||
/* open file */
|
||||
/* FILE_FLAG_RANDOM_ACCESS ? */
|
||||
if (flags_) {
|
||||
hfile = CreateFileA(filename_.c_str(), GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0);
|
||||
hfile = CreateFileW(filename, GENERIC_READ|GENERIC_WRITE, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0);
|
||||
if (hfile == INVALID_HANDLE_VALUE) {
|
||||
AT_ERROR("could not open file <", filename_, "> in read-write mode; error code: <", GetLastError(), ">");
|
||||
}
|
||||
} else {
|
||||
hfile = CreateFileA(filename_.c_str(), GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0);
|
||||
hfile = CreateFileW(filename, GENERIC_READ, FILE_SHARE_WRITE|FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0);
|
||||
if (hfile == INVALID_HANDLE_VALUE) {
|
||||
AT_ERROR("could not open file <", filename_, "> in read-only mode; error code: <", GetLastError(), ">");
|
||||
}
|
||||
@ -181,11 +190,11 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags,
|
||||
|
||||
/* get map handle */
|
||||
if (flags_) {
|
||||
if ( (hmfile = CreateFileMapping(hfile, NULL, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) {
|
||||
if ( (hmfile = CreateFileMappingW(hfile, NULL, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) {
|
||||
AT_ERROR("could not create a map on file <", filename_, ">; error code: <", GetLastError(), ">");
|
||||
}
|
||||
} else {
|
||||
if ( (hmfile = CreateFileMapping(hfile, NULL, PAGE_WRITECOPY, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) {
|
||||
if ( (hmfile = CreateFileMappingW(hfile, NULL, PAGE_WRITECOPY, hfilesz.HighPart, hfilesz.LowPart, NULL)) == NULL ) {
|
||||
AT_ERROR("could not create a map on file <", filename_, ">; error code: <", GetLastError(), ">");
|
||||
}
|
||||
}
|
||||
|
29
c10/util/Unicode.h
Normal file
29
c10/util/Unicode.h
Normal file
@ -0,0 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <string>
|
||||
#include <c10/util/win32-headers.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
#if defined(_WIN32)
|
||||
/// Converts a UTF-8 encoded narrow string into a wide (UTF-16) string using
/// the Win32 MultiByteToWideChar API with the CP_UTF8 code page.
///
/// The conversion is done in two passes: the first call asks the API how many
/// wide characters are required, the second call writes them into a buffer of
/// that size.
///
/// @param str  Bytes to convert, interpreted as UTF-8.
/// @return     The converted wide string; an empty string when `str` is empty.
///
/// Any failure (input too long for the API, or either API call reporting an
/// error) is reported through TORCH_CHECK.
inline std::wstring u8u16(const std::string& str) {
  if (str.empty()) {
    return std::wstring();
  }

  // MultiByteToWideChar takes the input length as an int. Reject inputs whose
  // length does not survive the narrowing instead of silently converting a
  // truncated (or negative) length.
  const int src_len = static_cast<int>(str.size());
  TORCH_CHECK(
      src_len > 0 &&
          static_cast<std::string::size_type>(src_len) == str.size(),
      "Error converting the content to Unicode");

  // First pass: no output buffer, the return value is the required size.
  const int size_needed =
      MultiByteToWideChar(CP_UTF8, 0, str.c_str(), src_len, nullptr, 0);
  TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode");

  std::wstring wstr(size_needed, 0);

  // Second pass: perform the conversion. A return value of 0 signals failure,
  // in which case the buffer contents must not be handed back to the caller.
  const int written = MultiByteToWideChar(
      CP_UTF8, 0, str.c_str(), src_len, &wstr[0], size_needed);
  TORCH_CHECK(written > 0, "Error converting the content to Unicode");

  // Keep exactly the characters the API reported as written.
  wstr.resize(written);
  return wstr;
}
|
||||
#endif
|
||||
}
|
@ -19,10 +19,10 @@ from itertools import product, combinations, permutations
|
||||
from torch import multiprocessing as mp
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, TEST_WITH_ROCM, run_tests,
|
||||
IS_WINDOWS, NO_MULTIPROCESSING_SPAWN,
|
||||
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
|
||||
do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest,
|
||||
skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext,
|
||||
skipIfRocm, skipIfNoSciPy,
|
||||
skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
|
||||
wrapDeterministicFlagAPITest, DeterministicGuard)
|
||||
from multiprocessing.reduction import ForkingPickler
|
||||
from torch.testing._internal.common_device_type import (
|
||||
@ -1852,15 +1852,14 @@ class AbstractTestCases:
|
||||
self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
|
||||
self.assertIs(complexdouble_storage.dtype, torch.complex128)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
|
||||
def test_from_file(self):
|
||||
def assert_with_filename(filename):
|
||||
size = 10000
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
s1 = torch.FloatStorage.from_file(f.name, True, size)
|
||||
s1 = torch.FloatStorage.from_file(filename, True, size)
|
||||
t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
|
||||
|
||||
# check mapping
|
||||
s2 = torch.FloatStorage.from_file(f.name, True, size)
|
||||
s2 = torch.FloatStorage.from_file(filename, True, size)
|
||||
t2 = torch.FloatTensor(s2)
|
||||
self.assertEqual(t1, t2, atol=0, rtol=0)
|
||||
|
||||
@ -1874,15 +1873,24 @@ class AbstractTestCases:
|
||||
t2.fill_(rnum)
|
||||
self.assertEqual(t1, t2, atol=0, rtol=0)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
|
||||
# release the tensors
|
||||
del s1, t1, s2, t2
|
||||
|
||||
with TemporaryFileName() as fname:
|
||||
assert_with_filename(fname)
|
||||
|
||||
if IS_FILESYSTEM_UTF8_ENCODING:
|
||||
with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname:
|
||||
assert_with_filename(fname)
|
||||
|
||||
def test_torch_from_file(self):
|
||||
def assert_with_filename(filename):
|
||||
size = 10000
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
s1 = torch.from_file(f.name, True, size, dtype=torch.float)
|
||||
s1 = torch.from_file(filename, True, size, dtype=torch.float)
|
||||
t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
|
||||
|
||||
# check mapping
|
||||
s2 = torch.from_file(f.name, True, size, dtype=torch.float)
|
||||
s2 = torch.from_file(filename, True, size, dtype=torch.float)
|
||||
t2 = torch.FloatTensor(s2)
|
||||
self.assertEqual(t1, t2, atol=0, rtol=0)
|
||||
|
||||
@ -1896,6 +1904,16 @@ class AbstractTestCases:
|
||||
t2.fill_(rnum)
|
||||
self.assertEqual(t1, t2, atol=0, rtol=0)
|
||||
|
||||
# release the tensors
|
||||
del s1, t1, s2, t2
|
||||
|
||||
with TemporaryFileName() as fname:
|
||||
assert_with_filename(fname)
|
||||
|
||||
if IS_FILESYSTEM_UTF8_ENCODING:
|
||||
with TemporaryDirectoryName(suffix='中文') as dname, TemporaryFileName(dir=dname) as fname:
|
||||
assert_with_filename(fname)
|
||||
|
||||
def test_print(self):
|
||||
default_type = torch.Tensor().type()
|
||||
for t in torch._tensor_classes:
|
||||
|
@ -21,6 +21,7 @@ import unittest
|
||||
import warnings
|
||||
import random
|
||||
import contextlib
|
||||
import shutil
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
@ -300,11 +301,11 @@ IS_PPC = platform.machine() == "ppc64le"
|
||||
|
||||
if IS_WINDOWS:
|
||||
@contextmanager
|
||||
def TemporaryFileName():
|
||||
def TemporaryFileName(dir=None):
|
||||
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
|
||||
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
|
||||
# close the file after creation and try to remove it manually
|
||||
f = tempfile.NamedTemporaryFile(delete=False)
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=dir)
|
||||
try:
|
||||
f.close()
|
||||
yield f.name
|
||||
@ -312,10 +313,27 @@ if IS_WINDOWS:
|
||||
os.unlink(f.name)
|
||||
else:
|
||||
@contextmanager # noqa: T484
|
||||
def TemporaryFileName():
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
def TemporaryFileName(dir=None):
|
||||
with tempfile.NamedTemporaryFile(dir=dir) as f:
|
||||
yield f.name
|
||||
|
||||
if IS_WINDOWS:
|
||||
@contextmanager
def TemporaryDirectoryName(suffix=None):
    """Context manager that yields the path of a freshly created temporary directory.

    The directory is created with ``tempfile.mkdtemp`` and removed with
    ``shutil.rmtree`` when the ``with`` block exits, whether or not the block raised.

    Args:
        suffix: optional suffix forwarded to ``tempfile.mkdtemp``.

    Yields:
        The path of the created directory, as returned by ``tempfile.mkdtemp``.
    """
    # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely,
    # so we first create the directory using mkdtemp and then remove it manually.
    #
    # The directory is created *before* entering the try block: if mkdtemp itself fails there is
    # nothing to clean up, and the cleanup must not replace that error with an UnboundLocalError.
    dir_name = tempfile.mkdtemp(suffix=suffix)
    try:
        yield dir_name
    finally:
        shutil.rmtree(dir_name)
|
||||
else:
|
||||
@contextmanager  # noqa: T484
def TemporaryDirectoryName(suffix=None):
    """Context manager that yields the path of a temporary directory.

    Thin wrapper around ``tempfile.TemporaryDirectory``: the directory is
    created on entry and its path is handed to the ``with`` block.

    Args:
        suffix: optional suffix forwarded to ``tempfile.TemporaryDirectory``.

    Yields:
        The path produced by entering the ``TemporaryDirectory`` context.
    """
    with tempfile.TemporaryDirectory(suffix=suffix) as tmp_dir_path:
        yield tmp_dir_path
|
||||
|
||||
IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8'
|
||||
|
||||
def _check_module_exists(name):
|
||||
r"""Returns if a top-level module with :attr:`name` exists *without**
|
||||
|
Reference in New Issue
Block a user