mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[1/N] Use "is" in python type comparison (#165037)
It generally recommended to use `is/is not` to compare types. Therefore this series of changes apply this suggestion in the code base, and it aims to finally enabling related linter checks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
960b0d5f0d
commit
70925bdf82
@ -2278,7 +2278,7 @@ class GuardBuilder(GuardBuilderBase):
|
||||
# don't support this in serialization because it uses unsupported FUNCTION_MATCH
|
||||
val = self.get(guard.name)
|
||||
# Strictly only want user-defined functions
|
||||
if type(val) == types.FunctionType and hasattr(val, "__code__"):
|
||||
if type(val) is types.FunctionType and hasattr(val, "__code__"):
|
||||
self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) # type: ignore[arg-type]
|
||||
self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) # type: ignore[arg-type]
|
||||
else:
|
||||
|
@ -197,7 +197,7 @@ class ConstDictVariable(VariableTracker):
|
||||
@staticmethod
|
||||
def _eq_impl(a, b):
|
||||
# TODO: Put this in utils and share it between variables/builtin.py and here
|
||||
if type(a) != type(b):
|
||||
if type(a) is not type(b):
|
||||
return False
|
||||
elif isinstance(a, tuple):
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
|
@ -3532,7 +3532,7 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
|
||||
from torch.distributed.tensor.experimental._func_map import _local_map_wrapped
|
||||
|
||||
# check is important to avoid subclass dispatch
|
||||
if type(value) != type(_local_map_wrapped):
|
||||
if type(value) is not type(_local_map_wrapped):
|
||||
return False
|
||||
|
||||
return value == _local_map_wrapped and cls._enabled
|
||||
|
@ -200,7 +200,7 @@ class BaseListVariable(VariableTracker):
|
||||
if kwargs or len(args) != 1:
|
||||
raise_args_mismatch(tx, name)
|
||||
|
||||
if type(self) != type(args[0]):
|
||||
if type(self) is not type(args[0]):
|
||||
tp_name = self.python_type_name()
|
||||
other = args[0].python_type_name()
|
||||
msg = ConstantVariable.create(
|
||||
|
@ -464,7 +464,7 @@ def _check_input_constraints_for_graph(
|
||||
)
|
||||
|
||||
elif isinstance(node_val, (int, float, str)):
|
||||
if type(arg) != type(node_val) or arg != node_val:
|
||||
if type(arg) is not type(node_val) or arg != node_val:
|
||||
raise RuntimeError(
|
||||
f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
|
||||
)
|
||||
|
@ -361,7 +361,7 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
|
||||
input_meta += get_input_meta(args[1])
|
||||
return input_meta
|
||||
for arg in args:
|
||||
if type(arg) == int or type(arg) == float:
|
||||
if type(arg) is int or type(arg) is float:
|
||||
input_meta.append((type(arg),))
|
||||
else:
|
||||
input_meta.append(
|
||||
|
@ -802,7 +802,7 @@ def _insert_fused_matmul_reduce_scatter(
|
||||
scatter_dim_after_reshape: int, # only used for reshape -> scaled_mm -> reshape pattern
|
||||
output_shape: list[int], # only used for reshape -> scaled_mm -> reshape pattern
|
||||
) -> torch.fx.Node:
|
||||
if type(matmul) == _Matmul:
|
||||
if type(matmul) is _Matmul:
|
||||
return graph.call_function(
|
||||
torch.ops.symm_mem.fused_matmul_reduce_scatter.default,
|
||||
args=(
|
||||
@ -813,7 +813,7 @@ def _insert_fused_matmul_reduce_scatter(
|
||||
group_name,
|
||||
),
|
||||
)
|
||||
elif type(matmul) == _ScaledMatmul:
|
||||
elif type(matmul) is _ScaledMatmul:
|
||||
return graph.call_function(
|
||||
torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default,
|
||||
args=(
|
||||
|
@ -98,7 +98,7 @@ class FakeTensorUpdater:
|
||||
return statically_known_true(sym_eq(new, old))
|
||||
|
||||
def is_fake_tensor_same(new, old, *, node):
|
||||
if type(new) != type(old):
|
||||
if type(new) is not type(old):
|
||||
return False
|
||||
if isinstance(new, (list, tuple)):
|
||||
if len(new) != len(old):
|
||||
|
@ -122,7 +122,7 @@ class SchemaCheckMode(TorchDispatchMode):
|
||||
|
||||
def parse_metadata(e):
|
||||
if isinstance(e, torch.Tensor):
|
||||
if type(e) != torch.Tensor:
|
||||
if type(e) is not torch.Tensor:
|
||||
try:
|
||||
current = e.elem
|
||||
return (
|
||||
|
@ -37,7 +37,7 @@ def _type(self, dtype=None, non_blocking=False, **kwargs):
|
||||
|
||||
if isinstance(dtype, str):
|
||||
dtype = _import_dotted_name(dtype)
|
||||
if dtype == type(self):
|
||||
if dtype is type(self):
|
||||
return self
|
||||
if self.is_sparse:
|
||||
if not dtype.is_sparse:
|
||||
|
@ -80,7 +80,7 @@ class ConvReLU1d(nnq.Conv1d):
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
|
||||
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
|
||||
if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
|
||||
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
|
||||
mod.weight, mod.bias = fuse_conv_bn_weights(
|
||||
mod.weight,
|
||||
@ -161,7 +161,7 @@ class ConvReLU2d(nnq.Conv2d):
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
|
||||
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
|
||||
if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
|
||||
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
|
||||
mod.weight, mod.bias = fuse_conv_bn_weights(
|
||||
mod.weight,
|
||||
@ -244,7 +244,7 @@ class ConvReLU3d(nnq.Conv3d):
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
|
||||
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
|
||||
if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
|
||||
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
|
||||
mod.weight, mod.bias = fuse_conv_bn_weights(
|
||||
mod.weight,
|
||||
|
@ -117,7 +117,7 @@ class Linear(nnq.Linear):
|
||||
+ str([float_mod.__name__ for float_mod in float_modules])
|
||||
)
|
||||
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
|
||||
if type(mod) == nni.LinearReLU:
|
||||
if type(mod) is nni.LinearReLU:
|
||||
mod = mod[0]
|
||||
# pyrefly: ignore # missing-attribute
|
||||
if mod.qconfig is not None and mod.qconfig.weight is not None:
|
||||
|
@ -1064,15 +1064,15 @@ class RNNCellBase(torch.nn.Module):
|
||||
|
||||
qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]
|
||||
|
||||
if type(mod) == torch.nn.LSTMCell:
|
||||
if type(mod) is torch.nn.LSTMCell:
|
||||
qRNNCellBase = LSTMCell(
|
||||
mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
|
||||
)
|
||||
elif type(mod) == torch.nn.GRUCell:
|
||||
elif type(mod) is torch.nn.GRUCell:
|
||||
qRNNCellBase = GRUCell(
|
||||
mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
|
||||
)
|
||||
elif type(mod) == torch.nn.RNNCell:
|
||||
elif type(mod) is torch.nn.RNNCell:
|
||||
qRNNCellBase = RNNCell(
|
||||
mod.input_size,
|
||||
mod.hidden_size,
|
||||
|
@ -20,7 +20,7 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
|
||||
@staticmethod
|
||||
def from_float(cls, mod, use_precomputed_fake_quant=False):
|
||||
activation_post_process = mod.activation_post_process
|
||||
if type(mod) == cls._NNI_BN_RELU_MODULE:
|
||||
if type(mod) is cls._NNI_BN_RELU_MODULE:
|
||||
mod = mod[0]
|
||||
scale, zero_point = activation_post_process.calculate_qparams()
|
||||
new_mod = cls(mod.num_features, mod.eps)
|
||||
|
@ -280,7 +280,7 @@ class _ConvNd(WeightedQuantizedModule):
|
||||
if hasattr(mod, "weight_fake_quant"):
|
||||
# assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
|
||||
# ".from_float only works for " + cls.__QAT_MODULE.__name__
|
||||
if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
|
||||
if type(mod) is cls._NNIQAT_CONV_BN_MODULE:
|
||||
mod.weight, mod.bias = fuse_conv_bn_weights(
|
||||
mod.weight,
|
||||
mod.bias,
|
||||
|
@ -149,7 +149,7 @@ class Linear(torch.nn.Module):
|
||||
# TODO: Need to add options to qconfig to avoid the calibration.
|
||||
# TODO: Add calibration for the sparsity
|
||||
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
|
||||
if type(mod) == nni.LinearReLU:
|
||||
if type(mod) is nni.LinearReLU:
|
||||
mod = mod[0]
|
||||
# pyrefly: ignore # missing-attribute
|
||||
if mod.qconfig is not None and mod.qconfig.weight is not None:
|
||||
|
@ -268,7 +268,7 @@ def extract_weight_from_node(
|
||||
mod = getattr_from_fqn(gm, node.target)
|
||||
module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
|
||||
for target_mod_type, weight_extraction_fn in module_mapping.items():
|
||||
if type(mod) == target_mod_type:
|
||||
if type(mod) is target_mod_type:
|
||||
weight = weight_extraction_fn(mod)
|
||||
return {
|
||||
"type": res_type,
|
||||
|
@ -792,7 +792,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
|
||||
for i in range(1, len(path)):
|
||||
prev_key = path[i - 1]
|
||||
key = path[i]
|
||||
def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) == str else []
|
||||
def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) is str else []
|
||||
|
||||
if isinstance(cur_container, Mapping):
|
||||
cur_container = cast(
|
||||
@ -806,7 +806,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
|
||||
cur_container = cur_container[prev_key]
|
||||
|
||||
key = path[-1]
|
||||
if type(key) == int:
|
||||
if type(key) is int:
|
||||
extend_list(cast(list[Any], cur_container), key)
|
||||
|
||||
cur_container[key] = value
|
||||
|
@ -121,7 +121,7 @@ def set_element(
|
||||
for i in range(1, len(path)):
|
||||
prev_key = path[i - 1]
|
||||
key = path[i]
|
||||
def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else [])
|
||||
def_val = cast(STATE_DICT_ITEM, {} if type(key) is str else [])
|
||||
|
||||
if isinstance(cur_container, Mapping):
|
||||
cur_container = cast(
|
||||
@ -135,7 +135,7 @@ def set_element(
|
||||
cur_container = cur_container[prev_key]
|
||||
|
||||
key = path[-1]
|
||||
if type(key) == int:
|
||||
if type(key) is int:
|
||||
extend_list(cast(list[STATE_DICT_ITEM], cur_container), key)
|
||||
|
||||
cur_container[key] = value
|
||||
|
@ -2203,7 +2203,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
|
||||
# alive until all works and hooks are done. The current implementation does the
|
||||
# latter. Therefore, we explicitly call _wait_for_pending_works() here to wait
|
||||
# for the pending hooks to finish.
|
||||
if type(pg) == ProcessGroup and pg._has_hooks():
|
||||
if type(pg) is ProcessGroup and pg._has_hooks():
|
||||
pg._wait_for_pending_works()
|
||||
|
||||
if group is None or group == GroupMember.WORLD:
|
||||
@ -2783,7 +2783,7 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
|
||||
key = "group_dst" if op.op == isend else "group_src"
|
||||
return {key: op.group_peer}
|
||||
|
||||
if type(group) == ProcessGroup and group._get_backend(device).supports_coalescing:
|
||||
if type(group) is ProcessGroup and group._get_backend(device).supports_coalescing:
|
||||
# NCCL style coalescing
|
||||
with _coalescing_manager(group, device, async_ops=True) as cm:
|
||||
for p2p_op in p2p_op_list:
|
||||
|
@ -149,9 +149,9 @@ class EtcdStore(Store):
|
||||
# In case of `str`, utf-8 encoding is assumed.
|
||||
#
|
||||
def _encode(self, value) -> str:
|
||||
if type(value) == bytes:
|
||||
if type(value) is bytes:
|
||||
return b64encode(value).decode()
|
||||
elif type(value) == str:
|
||||
elif type(value) is str:
|
||||
return b64encode(value.encode()).decode()
|
||||
raise ValueError("Value must be of type str or bytes")
|
||||
|
||||
@ -160,9 +160,9 @@ class EtcdStore(Store):
|
||||
# Return type is `bytes`, which is more convenient with the Store interface.
|
||||
#
|
||||
def _decode(self, value) -> bytes:
|
||||
if type(value) == bytes:
|
||||
if type(value) is bytes:
|
||||
return b64decode(value)
|
||||
elif type(value) == str:
|
||||
elif type(value) is str:
|
||||
return b64decode(value.encode())
|
||||
raise ValueError("Value must be of type str or bytes")
|
||||
|
||||
|
@ -110,7 +110,7 @@ class DownSampling(nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def init_weights(m):
|
||||
if type(m) == nn.Conv2d or type(m) == nn.Linear:
|
||||
if type(m) is nn.Conv2d or type(m) is nn.Linear:
|
||||
nn.init.ones_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
|
@ -280,7 +280,7 @@ def _kl_exponential_exponential(p, q):
|
||||
|
||||
@register_kl(ExponentialFamily, ExponentialFamily)
|
||||
def _kl_expfamily_expfamily(p, q):
|
||||
if not type(p) == type(q):
|
||||
if type(p) is not type(q):
|
||||
raise NotImplementedError(
|
||||
"The cross KL-divergence between different exponential families cannot \
|
||||
be computed using Bregman divergences"
|
||||
|
@ -872,7 +872,7 @@ class AdditionalInputs:
|
||||
]
|
||||
|
||||
def _mark_dynamism(v, *other_vs):
|
||||
if not all(type(v) == type(other) for other in other_vs):
|
||||
if not all(type(v) is type(other) for other in other_vs):
|
||||
raise ValueError(
|
||||
"The following inputs were found to have differing types, "
|
||||
f"so they cannot be marked as dynamic: {(v,) + other_vs}."
|
||||
|
@ -927,7 +927,7 @@ class Tracer(TracerBase):
|
||||
|
||||
return out
|
||||
# Union[int, bool] == bool in Python <= 3.6
|
||||
if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor:
|
||||
if type(x) is bool or type(x) in base_types and type(x) != torch.Tensor:
|
||||
torch._assert(
|
||||
out == x,
|
||||
f"{name} has been specialized to have value {x} but got another value",
|
||||
|
@ -112,7 +112,7 @@ def unify_object(u, v, s):
|
||||
>>> unify_object(f, g, {})
|
||||
{~x: 2}
|
||||
"""
|
||||
if type(u) != type(v):
|
||||
if type(u) is not type(v):
|
||||
return False
|
||||
if hasattr(u, "__slots__"):
|
||||
return unify(
|
||||
|
@ -311,7 +311,7 @@ class ScriptMeta(type):
|
||||
original_init(self, *args, **kwargs)
|
||||
added_methods_in_init = len(cls._methods) > num_methods
|
||||
|
||||
if type(self) == cls:
|
||||
if type(self) is cls:
|
||||
|
||||
def make_stubs(module):
|
||||
cls = type(module)
|
||||
@ -804,7 +804,7 @@ if _enabled:
|
||||
|
||||
@property
|
||||
def original_name(self):
|
||||
if type(self) == str(self._c._type().name()):
|
||||
if type(self) is str(self._c._type().name()):
|
||||
return ""
|
||||
return str(self._c._type().name())
|
||||
|
||||
|
@ -651,7 +651,7 @@ def analyze_ts_result_with_export_result(export, trace):
|
||||
# mkldnn is not supported for torch.allclose
|
||||
if orig.layout == torch._mkldnn: # type: ignore[attr-defined]
|
||||
return True
|
||||
if type(orig) != type(loaded):
|
||||
if type(orig) is not type(loaded):
|
||||
return False
|
||||
|
||||
if isinstance(orig, torch._subclasses.FakeTensor):
|
||||
|
@ -198,7 +198,7 @@ class MaskedTensor(torch.Tensor):
|
||||
def _validate_members(self):
|
||||
data = self._masked_data
|
||||
mask = self.get_mask()
|
||||
if type(data) != type(mask):
|
||||
if type(data) is not type(mask):
|
||||
raise TypeError(
|
||||
f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
|
||||
)
|
||||
|
@ -1049,7 +1049,7 @@ class Module:
|
||||
>>> @torch.no_grad()
|
||||
>>> def init_weights(m):
|
||||
>>> print(m)
|
||||
>>> if type(m) == nn.Linear:
|
||||
>>> if type(m) is nn.Linear:
|
||||
>>> m.weight.fill_(1.0)
|
||||
>>> print(m.weight)
|
||||
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
|
||||
|
@ -10,7 +10,7 @@ def is_from_package(obj: Any) -> bool:
|
||||
|
||||
Note: packaged objects from externed modules will return ``False``.
|
||||
"""
|
||||
if type(obj) == ModuleType:
|
||||
if type(obj) is ModuleType:
|
||||
return is_mangled(obj.__name__)
|
||||
else:
|
||||
return is_mangled(type(obj).__module__)
|
||||
|
@ -891,7 +891,7 @@ class TypedStorage:
|
||||
def _new_wrapped_storage(self, untyped_storage) -> Self:
|
||||
assert type(untyped_storage) == torch.UntypedStorage
|
||||
|
||||
if type(self) == TypedStorage:
|
||||
if type(self) is TypedStorage:
|
||||
return cast(
|
||||
Self,
|
||||
TypedStorage(
|
||||
@ -913,7 +913,7 @@ class TypedStorage:
|
||||
return 0
|
||||
|
||||
else:
|
||||
if type(idx) != int:
|
||||
if type(idx) is not int:
|
||||
raise TypeError(f"can't index a {type(self)} with {type(idx)}")
|
||||
if is_stop:
|
||||
if (idx > self._size()) or (idx < -self._size()):
|
||||
@ -1513,7 +1513,7 @@ class _LegacyStorageMeta(type):
|
||||
dtype: torch.dtype
|
||||
|
||||
def __instancecheck__(cls, instance):
|
||||
if type(instance) == TypedStorage:
|
||||
if type(instance) is TypedStorage:
|
||||
cls_device = _get_device_from_module(cls.__module__)
|
||||
return (cls_device == instance.device.type) and (
|
||||
cls.dtype == instance.dtype
|
||||
|
@ -1432,7 +1432,7 @@ class MultiThreadedTestCase(TestCase):
|
||||
logger.error("Caught exception: \n%s exiting thread %s", msg, rank)
|
||||
error_msg += f"Thread {rank} exited with exception:\n{msg}\n"
|
||||
elif isinstance(exc, SystemExit):
|
||||
if type(exc.code) == int and skip_code < 0:
|
||||
if type(exc.code) is int and skip_code < 0:
|
||||
skip_code = exc.code
|
||||
|
||||
# check exceptions
|
||||
|
@ -1247,7 +1247,7 @@ class QuantizationTestCase(TestCase):
|
||||
}
|
||||
"""
|
||||
# TODO: make img_data a single example instead of a list
|
||||
if type(inputs) == list:
|
||||
if type(inputs) is list:
|
||||
inputs = inputs[0]
|
||||
|
||||
if quant_type == QuantType.QAT:
|
||||
|
@ -365,7 +365,7 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs)
|
||||
for mask in _generate_masked_op_mask(
|
||||
sample_input.input.shape, device, **kwargs
|
||||
):
|
||||
if type(mask) != torch.Tensor:
|
||||
if type(mask) is not torch.Tensor:
|
||||
continue
|
||||
sample_input_args, sample_input_kwargs = (
|
||||
sample_input.args,
|
||||
|
@ -70,7 +70,7 @@ class WrapperSubclass(torch.Tensor):
|
||||
def __coerce_same_metadata_as_tangent__(
|
||||
self, expected_metadata: Any, expected_type: Optional[type] = None
|
||||
):
|
||||
if expected_type == type(self.a):
|
||||
if expected_type is type(self.a):
|
||||
return self.a
|
||||
elif expected_type is TwoTensor:
|
||||
return TwoTensor(self.a, self.a.clone())
|
||||
|
@ -187,7 +187,7 @@ class GraphPy:
|
||||
)
|
||||
|
||||
for key, node in self.nodes_io.items():
|
||||
if type(node) == NodeBase:
|
||||
if type(node) is NodeBase:
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
|
||||
if hasattr(node, "input_or_output"):
|
||||
|
Reference in New Issue
Block a user