mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 11:14:56 +08:00
Cache cufft plans (#8344)
* cache cufft plans * use an LRU cache * suffix CuFFTParams members with _ * import print_function for py2 * lint * fix potential race; add dummy impl for CPU only builds * cpp formatting; remove nccl makefile change * Use CUDA hooks instead * comments and doc * update the error message * move LRU cache to a separate file and native::detail namespace * update comment * specify NOTE location in CuFFTPlanCache.h * update disabled_features.yaml to make amd ci work * another fix for AMD CI in disabled_features.yaml * Wrap cufft_plan_cache_* methods in __HIP_PLATFORM_HCC__ * improve the notes * lint * revert onnx change * put back inlining for CUFFT_CHECK
This commit is contained in:
@ -112,7 +112,7 @@ SUPPORTED_RETURN_TYPES = {
|
||||
'std::tuple<Tensor,Tensor,Tensor,Tensor>',
|
||||
'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
|
||||
'std::vector<Tensor>',
|
||||
'Scalar', 'bool', 'int64_t', 'void*'
|
||||
'Scalar', 'bool', 'int64_t', 'void*', 'void'
|
||||
}
|
||||
|
||||
TENSOR_OPTIONS = CodeTemplate("""\
|
||||
@ -436,7 +436,11 @@ def create_python_bindings(python_functions, has_self, is_module=False):
|
||||
if requires_grad and not has_tensor_options:
|
||||
call_dispatch = PY_VARIABLE_SET_REQUIRES_GRAD.substitute(env, call_dispatch=call_dispatch,
|
||||
requires_grad=requires_grad)
|
||||
body.append(PY_VARIABLE_WRAP.substitute(env, call_dispatch=call_dispatch))
|
||||
if simple_return_type == 'void':
|
||||
body.append('{call_dispatch};'.format(call_dispatch=call_dispatch))
|
||||
body.append('Py_RETURN_NONE;')
|
||||
else:
|
||||
body.append(PY_VARIABLE_WRAP.substitute(env, call_dispatch=call_dispatch))
|
||||
py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
|
||||
return body
|
||||
|
||||
|
||||
Reference in New Issue
Block a user