[1/N] Use "is" in python type comparison (#165037)

It is generally recommended to use `is`/`is not` when comparing types. This series of changes applies that suggestion across the code base, with the goal of eventually enabling the related linter checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037
Approved by: https://github.com/mlazos
Author: Yuanyuan Chen
Date: 2025-10-10 12:36:46 +00:00
Committed by: PyTorch MergeBot
Parent: 960b0d5f0d
Commit: 70925bdf82
37 changed files with 51 additions and 51 deletions
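
For reference, a minimal sketch (not from the diff) of the distinction this series standardizes on: `type(x) is T` compares class objects by identity and matches only the exact class, `isinstance(x, T)` also accepts subclasses, and `type(x) == T` routes through equality dispatch and is what type-comparison lint rules (presumably pycodestyle/flake8's E721) flag. The `Base`/`Sub` names below are illustrative only.

```python
class Base:
    pass

class Sub(Base):
    pass

x = Sub()

# Exact-type check: identity comparison of the class objects, no __eq__ dispatch.
print(type(x) is Sub)        # True
print(type(x) is Base)       # False -- subclasses do not match

# Equality-based check: same answer here, but it goes through __eq__ and is
# what the type-comparison lint rule warns about.
print(type(x) == Sub)        # True

# When subclasses should match, isinstance is the idiomatic alternative.
print(isinstance(x, Base))   # True
```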

View File

@ -2278,7 +2278,7 @@ class GuardBuilder(GuardBuilderBase):
# don't support this in serialization because it uses unsupported FUNCTION_MATCH
val = self.get(guard.name)
# Strictly only want user-defined functions
-if type(val) == types.FunctionType and hasattr(val, "__code__"):
+if type(val) is types.FunctionType and hasattr(val, "__code__"):
self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) # type: ignore[arg-type]
self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) # type: ignore[arg-type]
else:

View File

@ -197,7 +197,7 @@ class ConstDictVariable(VariableTracker):
@staticmethod
def _eq_impl(a, b):
# TODO: Put this in utils and share it between variables/builtin.py and here
-if type(a) != type(b):
+if type(a) is not type(b):
return False
elif isinstance(a, tuple):
Hashable = ConstDictVariable._HashableTracker

View File

@ -3532,7 +3532,7 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
from torch.distributed.tensor.experimental._func_map import _local_map_wrapped
# check is important to avoid subclass dispatch
-if type(value) != type(_local_map_wrapped):
+if type(value) is not type(_local_map_wrapped):
return False
return value == _local_map_wrapped and cls._enabled
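
The "avoid subclass dispatch" comment above is the other motivation for exact type checks: `==` can dispatch into a user-defined `__eq__` (or, for tensor subclasses, `__torch_function__`), whereas `type(a) is type(b)` only compares the class objects and never calls into the operands. A small illustrative sketch under that assumption, not taken from the diff:

```python
class Tracked:
    """Stand-in for a class that customizes equality (e.g. a Tensor subclass)."""
    calls = 0

    def __eq__(self, other):
        Tracked.calls += 1          # every `==` on an instance runs this
        return NotImplemented


def same_kind(a, b) -> bool:
    # Identity check on the classes themselves: cheap and dispatch-free.
    return type(a) is type(b)


assert same_kind(Tracked(), object()) is False and Tracked.calls == 0
Tracked() == object()               # falls back to identity, but calls is now 1
assert Tracked.calls == 1
```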

View File

@ -200,7 +200,7 @@ class BaseListVariable(VariableTracker):
if kwargs or len(args) != 1:
raise_args_mismatch(tx, name)
-if type(self) != type(args[0]):
+if type(self) is not type(args[0]):
tp_name = self.python_type_name()
other = args[0].python_type_name()
msg = ConstantVariable.create(

View File

@ -464,7 +464,7 @@ def _check_input_constraints_for_graph(
)
elif isinstance(node_val, (int, float, str)):
-if type(arg) != type(node_val) or arg != node_val:
+if type(arg) is not type(node_val) or arg != node_val:
raise RuntimeError(
f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
)

View File

@ -361,7 +361,7 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
input_meta += get_input_meta(args[1])
return input_meta
for arg in args:
-if type(arg) == int or type(arg) == float:
+if type(arg) is int or type(arg) is float:
input_meta.append((type(arg),))
else:
input_meta.append(

View File

@ -802,7 +802,7 @@ def _insert_fused_matmul_reduce_scatter(
scatter_dim_after_reshape: int, # only used for reshape -> scaled_mm -> reshape pattern
output_shape: list[int], # only used for reshape -> scaled_mm -> reshape pattern
) -> torch.fx.Node:
-if type(matmul) == _Matmul:
+if type(matmul) is _Matmul:
return graph.call_function(
torch.ops.symm_mem.fused_matmul_reduce_scatter.default,
args=(
@ -813,7 +813,7 @@ def _insert_fused_matmul_reduce_scatter(
group_name,
),
)
-elif type(matmul) == _ScaledMatmul:
+elif type(matmul) is _ScaledMatmul:
return graph.call_function(
torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default,
args=(

View File

@ -98,7 +98,7 @@ class FakeTensorUpdater:
return statically_known_true(sym_eq(new, old))
def is_fake_tensor_same(new, old, *, node):
-if type(new) != type(old):
+if type(new) is not type(old):
return False
if isinstance(new, (list, tuple)):
if len(new) != len(old):

View File

@ -122,7 +122,7 @@ class SchemaCheckMode(TorchDispatchMode):
def parse_metadata(e):
if isinstance(e, torch.Tensor):
-if type(e) != torch.Tensor:
+if type(e) is not torch.Tensor:
try:
current = e.elem
return (

View File

@ -37,7 +37,7 @@ def _type(self, dtype=None, non_blocking=False, **kwargs):
if isinstance(dtype, str):
dtype = _import_dotted_name(dtype)
-if dtype == type(self):
+if dtype is type(self):
return self
if self.is_sparse:
if not dtype.is_sparse:

View File

@ -80,7 +80,7 @@ class ConvReLU1d(nnq.Conv1d):
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
-if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
+if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
@ -161,7 +161,7 @@ class ConvReLU2d(nnq.Conv2d):
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
-if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
+if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
@ -244,7 +244,7 @@ class ConvReLU3d(nnq.Conv3d):
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
-if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
+if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,

View File

@ -117,7 +117,7 @@ class Linear(nnq.Linear):
+ str([float_mod.__name__ for float_mod in float_modules])
)
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
-if type(mod) == nni.LinearReLU:
+if type(mod) is nni.LinearReLU:
mod = mod[0]
# pyrefly: ignore # missing-attribute
if mod.qconfig is not None and mod.qconfig.weight is not None:

View File

@ -1064,15 +1064,15 @@ class RNNCellBase(torch.nn.Module):
qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]
-if type(mod) == torch.nn.LSTMCell:
+if type(mod) is torch.nn.LSTMCell:
qRNNCellBase = LSTMCell(
mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
)
-elif type(mod) == torch.nn.GRUCell:
+elif type(mod) is torch.nn.GRUCell:
qRNNCellBase = GRUCell(
mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
)
-elif type(mod) == torch.nn.RNNCell:
+elif type(mod) is torch.nn.RNNCell:
qRNNCellBase = RNNCell(
mod.input_size,
mod.hidden_size,

View File

@ -20,7 +20,7 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
@staticmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
activation_post_process = mod.activation_post_process
-if type(mod) == cls._NNI_BN_RELU_MODULE:
+if type(mod) is cls._NNI_BN_RELU_MODULE:
mod = mod[0]
scale, zero_point = activation_post_process.calculate_qparams()
new_mod = cls(mod.num_features, mod.eps)

View File

@ -280,7 +280,7 @@ class _ConvNd(WeightedQuantizedModule):
if hasattr(mod, "weight_fake_quant"):
# assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
# ".from_float only works for " + cls.__QAT_MODULE.__name__
-if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
+if type(mod) is cls._NNIQAT_CONV_BN_MODULE:
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
mod.bias,

View File

@ -149,7 +149,7 @@ class Linear(torch.nn.Module):
# TODO: Need to add options to qconfig to avoid the calibration.
# TODO: Add calibration for the sparsity
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
-if type(mod) == nni.LinearReLU:
+if type(mod) is nni.LinearReLU:
mod = mod[0]
# pyrefly: ignore # missing-attribute
if mod.qconfig is not None and mod.qconfig.weight is not None:

View File

@ -268,7 +268,7 @@ def extract_weight_from_node(
mod = getattr_from_fqn(gm, node.target)
module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
for target_mod_type, weight_extraction_fn in module_mapping.items():
-if type(mod) == target_mod_type:
+if type(mod) is target_mod_type:
weight = weight_extraction_fn(mod)
return {
"type": res_type,

View File

@ -792,7 +792,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
for i in range(1, len(path)):
prev_key = path[i - 1]
key = path[i]
-def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) == str else []
+def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) is str else []
if isinstance(cur_container, Mapping):
cur_container = cast(
@ -806,7 +806,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
cur_container = cur_container[prev_key]
key = path[-1]
-if type(key) == int:
+if type(key) is int:
extend_list(cast(list[Any], cur_container), key)
cur_container[key] = value

View File

@ -121,7 +121,7 @@ def set_element(
for i in range(1, len(path)):
prev_key = path[i - 1]
key = path[i]
-def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else [])
+def_val = cast(STATE_DICT_ITEM, {} if type(key) is str else [])
if isinstance(cur_container, Mapping):
cur_container = cast(
@ -135,7 +135,7 @@ def set_element(
cur_container = cur_container[prev_key]
key = path[-1]
-if type(key) == int:
+if type(key) is int:
extend_list(cast(list[STATE_DICT_ITEM], cur_container), key)
cur_container[key] = value
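
Both `set_element`-style helpers above pick the container to create from the type of the next path element: a `str` step addresses a dict, an `int` step addresses a list. A self-contained sketch of that behavior (simplified, illustrative names; not the torch.distributed.checkpoint implementation):

```python
from typing import Any, Union

def set_element(root: dict, path: tuple, value: Any) -> None:
    """Create nested containers along `path`: str keys -> dict, int keys -> list."""
    cur: Any = root
    for prev_key, key in zip(path, path[1:]):
        # The *next* key's type decides which container to create at this level,
        # mirroring `{} if type(key) is str else []` in the hunks above.
        default: Union[dict, list] = {} if type(key) is str else []
        if isinstance(cur, dict):
            cur = cur.setdefault(prev_key, default)
        else:
            while len(cur) <= prev_key:      # grow the list up to the index
                cur.append(None)
            if cur[prev_key] is None:
                cur[prev_key] = default
            cur = cur[prev_key]
    last = path[-1]
    if type(last) is int:
        while len(cur) <= last:
            cur.append(None)
    cur[last] = value

state: dict = {}
set_element(state, ("optim", "param_groups", 0, "lr"), 0.1)
assert state == {"optim": {"param_groups": [{"lr": 0.1}]}}
```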

View File

@ -2203,7 +2203,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
# alive until all works and hooks are done. The current implementation does the
# latter. Therefore, we explicitly call _wait_for_pending_works() here to wait
# for the pending hooks to finish.
-if type(pg) == ProcessGroup and pg._has_hooks():
+if type(pg) is ProcessGroup and pg._has_hooks():
pg._wait_for_pending_works()
if group is None or group == GroupMember.WORLD:
@ -2783,7 +2783,7 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
key = "group_dst" if op.op == isend else "group_src"
return {key: op.group_peer}
-if type(group) == ProcessGroup and group._get_backend(device).supports_coalescing:
+if type(group) is ProcessGroup and group._get_backend(device).supports_coalescing:
# NCCL style coalescing
with _coalescing_manager(group, device, async_ops=True) as cm:
for p2p_op in p2p_op_list:

View File

@ -149,9 +149,9 @@ class EtcdStore(Store):
# In case of `str`, utf-8 encoding is assumed.
#
def _encode(self, value) -> str:
-if type(value) == bytes:
+if type(value) is bytes:
return b64encode(value).decode()
-elif type(value) == str:
+elif type(value) is str:
return b64encode(value.encode()).decode()
raise ValueError("Value must be of type str or bytes")
@ -160,9 +160,9 @@ class EtcdStore(Store):
# Return type is `bytes`, which is more convenient with the Store interface.
#
def _decode(self, value) -> bytes:
-if type(value) == bytes:
+if type(value) is bytes:
return b64decode(value)
-elif type(value) == str:
+elif type(value) is str:
return b64decode(value.encode())
raise ValueError("Value must be of type str or bytes")

View File

@ -110,7 +110,7 @@ class DownSampling(nn.Module):
@torch.no_grad()
def init_weights(m):
-if type(m) == nn.Conv2d or type(m) == nn.Linear:
+if type(m) is nn.Conv2d or type(m) is nn.Linear:
nn.init.ones_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)

View File

@ -280,7 +280,7 @@ def _kl_exponential_exponential(p, q):
@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
-if not type(p) == type(q):
+if type(p) is not type(q):
raise NotImplementedError(
"The cross KL-divergence between different exponential families cannot \
be computed using Bregman divergences"

View File

@ -872,7 +872,7 @@ class AdditionalInputs:
]
def _mark_dynamism(v, *other_vs):
-if not all(type(v) == type(other) for other in other_vs):
+if not all(type(v) is type(other) for other in other_vs):
raise ValueError(
"The following inputs were found to have differing types, "
f"so they cannot be marked as dynamic: {(v,) + other_vs}."

View File

@ -927,7 +927,7 @@ class Tracer(TracerBase):
return out
# Union[int, bool] == bool in Python <= 3.6
-if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor:
+if type(x) is bool or type(x) in base_types and type(x) != torch.Tensor:
torch._assert(
out == x,
f"{name} has been specialized to have value {x} but got another value",

View File

@ -112,7 +112,7 @@ def unify_object(u, v, s):
>>> unify_object(f, g, {})
{~x: 2}
"""
-if type(u) != type(v):
+if type(u) is not type(v):
return False
if hasattr(u, "__slots__"):
return unify(

View File

@ -311,7 +311,7 @@ class ScriptMeta(type):
original_init(self, *args, **kwargs)
added_methods_in_init = len(cls._methods) > num_methods
-if type(self) == cls:
+if type(self) is cls:
def make_stubs(module):
cls = type(module)
@ -804,7 +804,7 @@ if _enabled:
@property
def original_name(self):
-if type(self) == str(self._c._type().name()):
+if type(self) is str(self._c._type().name()):
return ""
return str(self._c._type().name())

View File

@ -651,7 +651,7 @@ def analyze_ts_result_with_export_result(export, trace):
# mkldnn is not supported for torch.allclose
if orig.layout == torch._mkldnn: # type: ignore[attr-defined]
return True
-if type(orig) != type(loaded):
+if type(orig) is not type(loaded):
return False
if isinstance(orig, torch._subclasses.FakeTensor):

View File

@ -198,7 +198,7 @@ class MaskedTensor(torch.Tensor):
def _validate_members(self):
data = self._masked_data
mask = self.get_mask()
-if type(data) != type(mask):
+if type(data) is not type(mask):
raise TypeError(
f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
)

View File

@ -1049,7 +1049,7 @@ class Module:
>>> @torch.no_grad()
>>> def init_weights(m):
>>> print(m)
->>> if type(m) == nn.Linear:
+>>> if type(m) is nn.Linear:
>>> m.weight.fill_(1.0)
>>> print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

View File

@ -10,7 +10,7 @@ def is_from_package(obj: Any) -> bool:
Note: packaged objects from externed modules will return ``False``.
"""
-if type(obj) == ModuleType:
+if type(obj) is ModuleType:
return is_mangled(obj.__name__)
else:
return is_mangled(type(obj).__module__)

View File

@ -891,7 +891,7 @@ class TypedStorage:
def _new_wrapped_storage(self, untyped_storage) -> Self:
assert type(untyped_storage) == torch.UntypedStorage
-if type(self) == TypedStorage:
+if type(self) is TypedStorage:
return cast(
Self,
TypedStorage(
@ -913,7 +913,7 @@ class TypedStorage:
return 0
else:
-if type(idx) != int:
+if type(idx) is not int:
raise TypeError(f"can't index a {type(self)} with {type(idx)}")
if is_stop:
if (idx > self._size()) or (idx < -self._size()):
@ -1513,7 +1513,7 @@ class _LegacyStorageMeta(type):
dtype: torch.dtype
def __instancecheck__(cls, instance):
-if type(instance) == TypedStorage:
+if type(instance) is TypedStorage:
cls_device = _get_device_from_module(cls.__module__)
return (cls_device == instance.device.type) and (
cls.dtype == instance.dtype

View File

@ -1432,7 +1432,7 @@ class MultiThreadedTestCase(TestCase):
logger.error("Caught exception: \n%s exiting thread %s", msg, rank)
error_msg += f"Thread {rank} exited with exception:\n{msg}\n"
elif isinstance(exc, SystemExit):
-if type(exc.code) == int and skip_code < 0:
+if type(exc.code) is int and skip_code < 0:
skip_code = exc.code
# check exceptions

View File

@ -1247,7 +1247,7 @@ class QuantizationTestCase(TestCase):
}
"""
# TODO: make img_data a single example instead of a list
-if type(inputs) == list:
+if type(inputs) is list:
inputs = inputs[0]
if quant_type == QuantType.QAT:

View File

@ -365,7 +365,7 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs)
for mask in _generate_masked_op_mask(
sample_input.input.shape, device, **kwargs
):
-if type(mask) != torch.Tensor:
+if type(mask) is not torch.Tensor:
continue
sample_input_args, sample_input_kwargs = (
sample_input.args,

View File

@ -70,7 +70,7 @@ class WrapperSubclass(torch.Tensor):
def __coerce_same_metadata_as_tangent__(
self, expected_metadata: Any, expected_type: Optional[type] = None
):
-if expected_type == type(self.a):
+if expected_type is type(self.a):
return self.a
elif expected_type is TwoTensor:
return TwoTensor(self.a, self.a.clone())

View File

@ -187,7 +187,7 @@ class GraphPy:
)
for key, node in self.nodes_io.items():
-if type(node) == NodeBase:
+if type(node) is NodeBase:
# pyrefly: ignore # unsupported-operation
self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
if hasattr(node, "input_or_output"):