mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
I've found that when using `torch.utils.cpp_extension.load` on Windows, a `UnicodeDecodeError` is raised when my .cpp/.cu files contain non-ASCII characters (for example, Chinese comments).
`test.py`:
```py
from torch.utils.cpp_extension import load
my_lib = load(name='my_cuda_kernel', sources=['my_cuda_kernel.cu'], extra_cuda_cflags=['-O2', '-std=c++17'])
# ......
```
`my_cuda_kernel.cu`:
```cpp
#include <torch/types.h>
#include <torch/extension.h>
// 向量化 <------ some chinese characters
// ......
```
Errors will be reported as:
```
Traceback (most recent call last):
File "E:\test\test.py", line 8, in <module>
my_lib = load(
^^^^^
File "C:\Users\XXX\AppData\Roaming\Python\Python311\site-packages\torch\utils\cpp_extension.py", line 1314, in load
return _jit_compile(
^^^^^^^^^^^^^
File "C:\Users\XXX\AppData\Roaming\Python\Python311\site-packages\torch\utils\cpp_extension.py", line 1680, in _jit_compile
version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\XXX\AppData\Roaming\Python\Python311\site-packages\torch\utils\_cpp_extension_versioner.py", line 46, in bump_version_if_changed
hash_value = hash_source_files(hash_value, source_files)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\XXX\AppData\Roaming\Python\Python311\site-packages\torch\utils\_cpp_extension_versioner.py", line 17, in hash_source_files
hash_value = update_hash(hash_value, file.read())
^^^^^^^^^^^
UnicodeDecodeError: 'gbk' codec can't decode byte 0x96 in position 141: illegal multibyte sequence
```
The root cause is that Python's `open()` in text mode decodes the file with the platform's default (locale) encoding. That decoding fails when the file contains byte sequences that are invalid in that encoding. PyTorch hashes the file contents to decide whether an extension must be rebuilt:
60c1433041/torch/utils/_cpp_extension_versioner.py (L16-L17)
On my Windows system the default encoding is `gbk`, but all of my C++ source files are encoded in UTF-8.
I think there is a simple solution: open the file in binary mode. The contents are then hashed as raw bytes and never decoded, so the result no longer depends on the file's encoding. This works correctly on my machine.
```diff
- with open(filename) as file:
+ with open(filename, 'rb') as file:
hash_value = update_hash(hash_value, file.read())
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138565
Approved by: https://github.com/malfet, https://github.com/janeyx99
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
60 lines
2.0 KiB
Python
60 lines
2.0 KiB
Python
# mypy: allow-untyped-defs
|
|
import collections
|
|
|
|
|
|
# Record kept per extension by ExtensionVersioner: ``version`` is the build
# counter handed back to callers, ``hash`` is the combined hash of the inputs
# that produced that version.
Entry = collections.namedtuple('Entry', ['version', 'hash'])
|
|
|
|
|
|
def update_hash(seed, value):
    """Fold ``hash(value)`` into ``seed`` and return the combined hash.

    This follows boost::hash_combine:
    https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html

    ``value`` must be hashable. Python ints do not wrap, so unlike the C++
    original the result can grow beyond machine-word size.

    NOTE(review): the built-in ``hash()`` of str/bytes is randomized per
    interpreter process (see PYTHONHASHSEED). The results are therefore
    presumably only comparable within one process; confirm no caller
    persists them.
    """
    magic = 0x9e3779b9
    shifted_left = seed << 6
    shifted_right = seed >> 2
    mixed = hash(value) + magic + shifted_left + shifted_right
    return seed ^ mixed
|
|
|
|
|
|
def hash_source_files(hash_value, source_files):
    """Fold the contents of each file in ``source_files`` into ``hash_value``.

    Files are processed in the order given, and the updated hash is returned.
    ``open`` errors (e.g. a missing file) propagate to the caller.
    """
    for path in source_files:
        # Read raw bytes: a text-mode open() would decode with the platform's
        # default encoding and could raise UnicodeDecodeError on non-ASCII
        # sources (e.g. UTF-8 files on a GBK-locale Windows machine).
        with open(path, 'rb') as handle:
            contents = handle.read()
        hash_value = update_hash(hash_value, contents)
    return hash_value
|
|
|
|
|
|
def hash_build_arguments(hash_value, build_arguments):
    """Fold every argument of every group in ``build_arguments`` into ``hash_value``.

    ``build_arguments`` is an iterable of groups. Falsy groups (``None`` or
    empty) are skipped. Each argument inside a group must be hashable.
    Returns the updated hash.
    """
    for arg_group in build_arguments:
        if not arg_group:
            continue
        for arg in arg_group:
            hash_value = update_hash(hash_value, arg)
    return hash_value
|
|
|
|
|
|
class ExtensionVersioner:
    """Assign an integer version to each named extension, bumping it when inputs change.

    ``self.entries`` maps an extension name to an ``Entry(version, hash)``.
    The first call to ``bump_version_if_changed`` for a name stores version 0.
    Later calls increment the version only if the combined hash of the build
    inputs differs from the stored one.
    """

    def __init__(self):
        self.entries = {}

    def get_version(self, name):
        """Return the stored version for ``name``, or ``None`` if there is no entry."""
        entry = self.entries.get(name)
        if entry is None:
            return None
        return entry.version

    def bump_version_if_changed(self,
                                name,
                                source_files,
                                build_arguments,
                                build_directory,
                                with_cuda,
                                is_python_module,
                                is_standalone):
        """Hash all build inputs for ``name`` and return its (possibly bumped) version.

        The hash covers the source file contents, every build argument, the
        build directory, and the ``with_cuda``, ``is_python_module`` and
        ``is_standalone`` settings. Each remaining argument must be hashable.

        Returns 0 for a newly seen name. Otherwise returns the previous
        version + 1 if the hash changed, or the previous version unchanged.
        """
        digest = hash_source_files(0, source_files)
        digest = hash_build_arguments(digest, build_arguments)
        for setting in (build_directory, with_cuda, is_python_module, is_standalone):
            digest = update_hash(digest, setting)

        previous = self.entries.get(name)
        if previous is None:
            current = Entry(0, digest)
        elif digest != previous.hash:
            current = Entry(previous.version + 1, digest)
        else:
            # Inputs unchanged: the stored entry stays as-is.
            return previous.version
        self.entries[name] = current
        return current.version
|