Fix ordered_dict.h for CUDA on Windows (#55275)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/55266

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55275

Reviewed By: mrshenli

Differential Revision: D27623887

Pulled By: malfet

fbshipit-source-id: 6dac357e21179a259ac95f0e1b7399b03dacc81d
This commit is contained in:
peter
2021-04-07 23:41:44 -07:00
committed by Facebook GitHub Bot
parent 0dff0d1537
commit 3517ee1bcb
3 changed files with 18 additions and 11 deletions

View File

@ -4,7 +4,6 @@ import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
from torch.testing._internal.common_utils import IS_WINDOWS
if sys.platform == 'win32':
vc_version = os.getenv('VCToolsVersion', '')
@ -40,15 +39,14 @@ if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None
'nvcc': ['-O2']})
ext_modules.append(extension)
if not IS_WINDOWS: # MSVC has bug compiling this example
if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None):
extension = CUDAExtension(
'torch_test_cpp_extension.torch_library', [
'torch_library.cu'
],
extra_compile_args={'cxx': CXX_FLAGS,
'nvcc': ['-O2']})
ext_modules.append(extension)
if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None):
extension = CUDAExtension(
'torch_test_cpp_extension.torch_library', [
'torch_library.cu'
],
extra_compile_args={'cxx': CXX_FLAGS,
'nvcc': ['-O2']})
ext_modules.append(extension)
setup(
name='torch_test_cpp_extension',

View File

@ -186,7 +186,6 @@ class TestRNGExtension(common.TestCase):
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(IS_WINDOWS, "MSVC have bug compiling this")
class TestTorchLibrary(common.TestCase):
def test_torch_library(self):

View File

@ -193,6 +193,16 @@ class OrderedDict<Key, Value>::Item {
/// Constructs a new item.
Item(Key key, Value value) : pair_(std::move(key), std::move(value)) {}
#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ < 11) && defined(_MSC_VER)
/// Related issue: https://github.com/pytorch/pytorch/issues/55266
/// Needs to define this function for CUDA < 11.0 on Windows,
/// although it usually won't be used actually.
Item& operator=(const Item& other) {
pair_ = other.pair_;
return *this;
}
#endif
/// Returns a reference to the value.
Value& operator*() {
return value();