mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix ordered_dict.h for CUDA on Windows (#55275)
Summary: Fixes https://github.com/pytorch/pytorch/issues/55266 Pull Request resolved: https://github.com/pytorch/pytorch/pull/55275 Reviewed By: mrshenli Differential Revision: D27623887 Pulled By: malfet fbshipit-source-id: 6dac357e21179a259ac95f0e1b7399b03dacc81d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0dff0d1537
commit
3517ee1bcb
@ -4,7 +4,6 @@ import os
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
|
||||
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
|
||||
if sys.platform == 'win32':
|
||||
vc_version = os.getenv('VCToolsVersion', '')
|
||||
@ -40,15 +39,14 @@ if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None
|
||||
'nvcc': ['-O2']})
|
||||
ext_modules.append(extension)
|
||||
|
||||
if not IS_WINDOWS: # MSVC has bug compiling this example
|
||||
if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None):
|
||||
extension = CUDAExtension(
|
||||
'torch_test_cpp_extension.torch_library', [
|
||||
'torch_library.cu'
|
||||
],
|
||||
extra_compile_args={'cxx': CXX_FLAGS,
|
||||
'nvcc': ['-O2']})
|
||||
ext_modules.append(extension)
|
||||
if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None):
|
||||
extension = CUDAExtension(
|
||||
'torch_test_cpp_extension.torch_library', [
|
||||
'torch_library.cu'
|
||||
],
|
||||
extra_compile_args={'cxx': CXX_FLAGS,
|
||||
'nvcc': ['-O2']})
|
||||
ext_modules.append(extension)
|
||||
|
||||
setup(
|
||||
name='torch_test_cpp_extension',
|
||||
|
@ -186,7 +186,6 @@ class TestRNGExtension(common.TestCase):
|
||||
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
|
||||
@unittest.skipIf(IS_WINDOWS, "MSVC have bug compiling this")
|
||||
class TestTorchLibrary(common.TestCase):
|
||||
|
||||
def test_torch_library(self):
|
||||
|
@ -193,6 +193,16 @@ class OrderedDict<Key, Value>::Item {
|
||||
/// Constructs a new item.
|
||||
Item(Key key, Value value) : pair_(std::move(key), std::move(value)) {}
|
||||
|
||||
#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ < 11) && defined(_MSC_VER)
|
||||
/// Related issue: https://github.com/pytorch/pytorch/issues/55266
|
||||
/// Needs to define this function for CUDA < 11.0 on Windows,
|
||||
/// although it usually won't be used actually.
|
||||
Item& operator=(const Item& other) {
|
||||
pair_ = other.pair_;
|
||||
return *this;
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Returns a reference to the value.
|
||||
Value& operator*() {
|
||||
return value();
|
||||
|
Reference in New Issue
Block a user