mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598 ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18598 Turn on F401: Unused import warning.** This was requested by someone at Facebook; this lint is turned on for Facebook by default. "Sure, why not." I had to noqa a number of imports in __init__. Hypothetically we're supposed to use __all__ in this case, but I was too lazy to fix it. Left for future work. Be careful! flake8-2 and flake8-3 behave differently with respect to import resolution for # type: comments. flake8-3 will report an import unused; flake8-2 will not. For now, I just noqa'd all these sites. All the changes were done by hand. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: D14687478 fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
54 lines
1.6 KiB
Python
import warnings
|
|
from .._jit_internal import weak_script
|
|
|
|
# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h
|
|
|
|
|
|
@weak_script
def get_enum(reduction):
    # type: (str) -> int
    # Translate a reduction name into its integer code, which must match the
    # enum in aten/src/ATen/core/Reduction.h:
    #   'none' -> 0, 'mean' -> 1, 'sum' -> 2.
    # 'elementwise_mean' is a deprecated spelling of 'mean': it also yields 1
    # but issues a warning first. Any other string raises ValueError.
    if reduction == 'sum':
        reduction_code = 2
    elif reduction == 'mean':
        reduction_code = 1
    elif reduction == 'none':
        reduction_code = 0
    elif reduction == 'elementwise_mean':
        warnings.warn("reduction='elementwise_mean' is deprecated, please use reduction='mean' instead.")
        reduction_code = 1
    else:
        # The script compiler needs reduction_code bound on every branch,
        # even the one that raises.
        reduction_code = -1  # TODO: remove once JIT exceptions support control flow
        raise ValueError(reduction + " is not a valid value for reduction")
    return reduction_code
|
|
|
|
# In order to support previous versions, accept boolean size_average and reduce
|
|
# and convert them into the new constants for now
|
|
|
|
|
|
# We use these functions in torch/legacy as well, in which case we'll silence the warning
|
|
@weak_script
def legacy_get_string(size_average, reduce, emit_warning=True):
    # type: (Optional[bool], Optional[bool], bool) -> str
    # Convert the legacy (size_average, reduce) flag pair into a reduction
    # name. A flag passed as None is treated as True.
    #   reduce falsy                     -> 'none'
    #   reduce truthy, size_average truthy -> 'mean'
    #   reduce truthy, size_average falsy  -> 'sum'
    # If emit_warning is true, warn that the flags are being deprecated and
    # name the equivalent reduction= value. Legacy callers pass
    # emit_warning=False to silence this.
    if size_average is None:
        size_average = True
    if reduce is None:
        reduce = True

    if not reduce:
        reduction_name = 'none'
    elif size_average:
        reduction_name = 'mean'
    else:
        reduction_name = 'sum'

    if emit_warning:
        template = "size_average and reduce args will be deprecated, please use reduction='{}' instead."
        warnings.warn(template.format(reduction_name))
    return reduction_name
|
|
|
|
|
|
@weak_script
def legacy_get_enum(size_average, reduce, emit_warning=True):
    # type: (Optional[bool], Optional[bool], bool) -> int
    # Integer-code version of legacy_get_string. It resolves the legacy
    # size_average/reduce flags to a reduction name (optionally warning),
    # then maps that name to its enum value via get_enum.
    reduction_name = legacy_get_string(size_average, reduce, emit_warning)
    return get_enum(reduction_name)
|