diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index dc7eda8cfa53..622493070239 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -2278,7 +2278,7 @@ class GuardBuilder(GuardBuilderBase):
         # don't support this in serialization because it uses unsupported FUNCTION_MATCH
         val = self.get(guard.name)
         # Strictly only want user-defined functions
-        if type(val) == types.FunctionType and hasattr(val, "__code__"):
+        if type(val) is types.FunctionType and hasattr(val, "__code__"):
             self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR)  # type: ignore[arg-type]
             self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH)  # type: ignore[arg-type]
         else:
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index 1691c4161889..36af33eaa944 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -197,7 +197,7 @@ class ConstDictVariable(VariableTracker):
     @staticmethod
     def _eq_impl(a, b):
         # TODO: Put this in utils and share it between variables/builtin.py and here
-        if type(a) != type(b):
+        if type(a) is not type(b):
             return False
         elif isinstance(a, tuple):
             Hashable = ConstDictVariable._HashableTracker
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 69cdbd450157..aa7792642f83 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -3532,7 +3532,7 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
         from torch.distributed.tensor.experimental._func_map import _local_map_wrapped
 
         # check is important to avoid subclass dispatch
-        if type(value) != type(_local_map_wrapped):
+        if type(value) is not type(_local_map_wrapped):
             return False
 
         return value == _local_map_wrapped and cls._enabled
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
index f51ba102342c..c6b9434b6f05 100644
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -200,7 +200,7 @@ class BaseListVariable(VariableTracker):
             if kwargs or len(args) != 1:
                 raise_args_mismatch(tx, name)
 
-            if type(self) != type(args[0]):
+            if type(self) is not type(args[0]):
                 tp_name = self.python_type_name()
                 other = args[0].python_type_name()
                 msg = ConstantVariable.create(
diff --git a/torch/_export/utils.py b/torch/_export/utils.py
index dfe3c8a09da2..bfe98daff2d5 100644
--- a/torch/_export/utils.py
+++ b/torch/_export/utils.py
@@ -464,7 +464,7 @@ def _check_input_constraints_for_graph(
                 )
 
         elif isinstance(node_val, (int, float, str)):
-            if type(arg) != type(node_val) or arg != node_val:
+            if type(arg) is not type(node_val) or arg != node_val:
                 raise RuntimeError(
                     f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
                 )
diff --git a/torch/_functorch/compilers.py b/torch/_functorch/compilers.py
index 303281f85608..8070e47153ca 100644
--- a/torch/_functorch/compilers.py
+++ b/torch/_functorch/compilers.py
@@ -361,7 +361,7 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
             input_meta += get_input_meta(args[1])
             return input_meta
         for arg in args:
-            if type(arg) == int or type(arg) == float:
+            if type(arg) is int or type(arg) is float:
                 input_meta.append((type(arg),))
             else:
                 input_meta.append(
diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py
index 635be342a2f0..4a4b3456f4a3 100644
--- a/torch/_inductor/fx_passes/micro_pipeline_tp.py
+++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py
@@ -802,7 +802,7 @@ def _insert_fused_matmul_reduce_scatter(
     scatter_dim_after_reshape: int,  # only used for reshape -> scaled_mm -> reshape pattern
     output_shape: list[int],  # only used for reshape -> scaled_mm -> reshape pattern
 ) -> torch.fx.Node:
-    if type(matmul) == _Matmul:
+    if type(matmul) is _Matmul:
         return graph.call_function(
             torch.ops.symm_mem.fused_matmul_reduce_scatter.default,
             args=(
@@ -813,7 +813,7 @@ def _insert_fused_matmul_reduce_scatter(
                 group_name,
             ),
         )
-    elif type(matmul) == _ScaledMatmul:
+    elif type(matmul) is _ScaledMatmul:
         return graph.call_function(
             torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default,
             args=(
diff --git a/torch/_inductor/fx_utils.py b/torch/_inductor/fx_utils.py
index fdc60e19efb6..4c0a2ff35e18 100644
--- a/torch/_inductor/fx_utils.py
+++ b/torch/_inductor/fx_utils.py
@@ -98,7 +98,7 @@ class FakeTensorUpdater:
             return statically_known_true(sym_eq(new, old))
 
         def is_fake_tensor_same(new, old, *, node):
-            if type(new) != type(old):
+            if type(new) is not type(old):
                 return False
             if isinstance(new, (list, tuple)):
                 if len(new) != len(old):
diff --git a/torch/_subclasses/schema_check_mode.py b/torch/_subclasses/schema_check_mode.py
index a2165c8945d1..cf49f1e212d8 100644
--- a/torch/_subclasses/schema_check_mode.py
+++ b/torch/_subclasses/schema_check_mode.py
@@ -122,7 +122,7 @@ class SchemaCheckMode(TorchDispatchMode):
 
         def parse_metadata(e):
             if isinstance(e, torch.Tensor):
-                if type(e) != torch.Tensor:
+                if type(e) is not torch.Tensor:
                     try:
                         current = e.elem
                         return (
diff --git a/torch/_utils.py b/torch/_utils.py
index bf431553abc8..c7b63525073a 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -37,7 +37,7 @@ def _type(self, dtype=None, non_blocking=False, **kwargs):
 
     if isinstance(dtype, str):
         dtype = _import_dotted_name(dtype)
-    if dtype == type(self):
+    if dtype is type(self):
         return self
     if self.is_sparse:
         if not dtype.is_sparse:
diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py
index c8024c5b4c58..37af49f1701b 100644
--- a/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py
+++ b/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py
@@ -80,7 +80,7 @@ class ConvReLU1d(nnq.Conv1d):
 
     @classmethod
     def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
-        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
+        if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
             assert mod.bn.running_var is not None and mod.bn.running_mean is not None
             mod.weight, mod.bias = fuse_conv_bn_weights(
                 mod.weight,
@@ -161,7 +161,7 @@ class ConvReLU2d(nnq.Conv2d):
 
     @classmethod
     def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
-        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
+        if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
             assert mod.bn.running_var is not None and mod.bn.running_mean is not None
             mod.weight, mod.bias = fuse_conv_bn_weights(
                 mod.weight,
@@ -244,7 +244,7 @@ class ConvReLU3d(nnq.Conv3d):
 
     @classmethod
     def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
-        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
+        if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
             assert mod.bn.running_var is not None and mod.bn.running_mean is not None
             mod.weight, mod.bias = fuse_conv_bn_weights(
                 mod.weight,
diff --git a/torch/ao/nn/quantized/dynamic/modules/linear.py b/torch/ao/nn/quantized/dynamic/modules/linear.py
index 6fa5ee65f3b5..2ea3fc972046 100644
--- a/torch/ao/nn/quantized/dynamic/modules/linear.py
+++ b/torch/ao/nn/quantized/dynamic/modules/linear.py
@@ -117,7 +117,7 @@ class Linear(nnq.Linear):
             + str([float_mod.__name__ for float_mod in float_modules])
         )
         assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
-        if type(mod) == nni.LinearReLU:
+        if type(mod) is nni.LinearReLU:
             mod = mod[0]
         # pyrefly: ignore  # missing-attribute
         if mod.qconfig is not None and mod.qconfig.weight is not None:
diff --git a/torch/ao/nn/quantized/dynamic/modules/rnn.py b/torch/ao/nn/quantized/dynamic/modules/rnn.py
index cdfecc95d723..fb5371ea4a4f 100644
--- a/torch/ao/nn/quantized/dynamic/modules/rnn.py
+++ b/torch/ao/nn/quantized/dynamic/modules/rnn.py
@@ -1064,15 +1064,15 @@ class RNNCellBase(torch.nn.Module):
 
         qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]
 
-        if type(mod) == torch.nn.LSTMCell:
+        if type(mod) is torch.nn.LSTMCell:
             qRNNCellBase = LSTMCell(
                 mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
             )
-        elif type(mod) == torch.nn.GRUCell:
+        elif type(mod) is torch.nn.GRUCell:
             qRNNCellBase = GRUCell(
                 mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
             )
-        elif type(mod) == torch.nn.RNNCell:
+        elif type(mod) is torch.nn.RNNCell:
            qRNNCellBase = RNNCell(
                mod.input_size,
                mod.hidden_size,
diff --git a/torch/ao/nn/quantized/modules/batchnorm.py b/torch/ao/nn/quantized/modules/batchnorm.py
index bd426038657c..782bfdbda283 100644
--- a/torch/ao/nn/quantized/modules/batchnorm.py
+++ b/torch/ao/nn/quantized/modules/batchnorm.py
@@ -20,7 +20,7 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
     @staticmethod
     def from_float(cls, mod, use_precomputed_fake_quant=False):
         activation_post_process = mod.activation_post_process
-        if type(mod) == cls._NNI_BN_RELU_MODULE:
+        if type(mod) is cls._NNI_BN_RELU_MODULE:
             mod = mod[0]
         scale, zero_point = activation_post_process.calculate_qparams()
         new_mod = cls(mod.num_features, mod.eps)
diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py
index 1bec74975f8a..e64964ba8e8f 100644
--- a/torch/ao/nn/quantized/modules/conv.py
+++ b/torch/ao/nn/quantized/modules/conv.py
@@ -280,7 +280,7 @@ class _ConvNd(WeightedQuantizedModule):
         if hasattr(mod, "weight_fake_quant"):
             # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
             # ".from_float only works for " + cls.__QAT_MODULE.__name__
-            if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
+            if type(mod) is cls._NNIQAT_CONV_BN_MODULE:
                 mod.weight, mod.bias = fuse_conv_bn_weights(
                     mod.weight,
                     mod.bias,
diff --git a/torch/ao/nn/sparse/quantized/dynamic/linear.py b/torch/ao/nn/sparse/quantized/dynamic/linear.py
index 5ae9a9227dbb..835ee0f90631 100644
--- a/torch/ao/nn/sparse/quantized/dynamic/linear.py
+++ b/torch/ao/nn/sparse/quantized/dynamic/linear.py
@@ -149,7 +149,7 @@ class Linear(torch.nn.Module):
         # TODO: Need to add options to qconfig to avoid the calibration.
        # TODO: Add calibration for the sparsity
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
-        if type(mod) == nni.LinearReLU:
+        if type(mod) is nni.LinearReLU:
            mod = mod[0]
        # pyrefly: ignore  # missing-attribute
        if mod.qconfig is not None and mod.qconfig.weight is not None:
diff --git a/torch/ao/ns/fx/weight_utils.py b/torch/ao/ns/fx/weight_utils.py
index 6a2eebc6e0b7..1b665e616a1e 100644
--- a/torch/ao/ns/fx/weight_utils.py
+++ b/torch/ao/ns/fx/weight_utils.py
@@ -268,7 +268,7 @@ def extract_weight_from_node(
         mod = getattr_from_fqn(gm, node.target)
         module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
         for target_mod_type, weight_extraction_fn in module_mapping.items():
-            if type(mod) == target_mod_type:
+            if type(mod) is target_mod_type:
                 weight = weight_extraction_fn(mod)
                 return {
                     "type": res_type,
diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py
index e17009a48af3..06aa9db81e9c 100644
--- a/torch/distributed/_state_dict_utils.py
+++ b/torch/distributed/_state_dict_utils.py
@@ -792,7 +792,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
     for i in range(1, len(path)):
         prev_key = path[i - 1]
         key = path[i]
-        def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) == str else []
+        def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) is str else []
 
         if isinstance(cur_container, Mapping):
             cur_container = cast(
@@ -806,7 +806,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
         cur_container = cur_container[prev_key]
 
     key = path[-1]
-    if type(key) == int:
+    if type(key) is int:
         extend_list(cast(list[Any], cur_container), key)
 
     cur_container[key] = value
diff --git a/torch/distributed/checkpoint/_traverse.py b/torch/distributed/checkpoint/_traverse.py
index 9bde4f47c329..cfd605a2bfb4 100644
--- a/torch/distributed/checkpoint/_traverse.py
+++ b/torch/distributed/checkpoint/_traverse.py
@@ -121,7 +121,7 @@ def set_element(
     for i in range(1, len(path)):
         prev_key = path[i - 1]
         key = path[i]
-        def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else [])
+        def_val = cast(STATE_DICT_ITEM, {} if type(key) is str else [])
 
         if isinstance(cur_container, Mapping):
             cur_container = cast(
@@ -135,7 +135,7 @@ def set_element(
         cur_container = cur_container[prev_key]
 
     key = path[-1]
-    if type(key) == int:
+    if type(key) is int:
         extend_list(cast(list[STATE_DICT_ITEM], cur_container), key)
 
     cur_container[key] = value
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 213c11baeccb..ea194a6ebe9a 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -2203,7 +2203,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
     # alive until all works and hooks are done. The current implementation does the
     # latter. Therefore, we explicitly call _wait_for_pending_works() here to wait
     # for the pending hooks to finish.
-    if type(pg) == ProcessGroup and pg._has_hooks():
+    if type(pg) is ProcessGroup and pg._has_hooks():
         pg._wait_for_pending_works()
 
     if group is None or group == GroupMember.WORLD:
@@ -2783,7 +2783,7 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
         key = "group_dst" if op.op == isend else "group_src"
         return {key: op.group_peer}
 
-    if type(group) == ProcessGroup and group._get_backend(device).supports_coalescing:
+    if type(group) is ProcessGroup and group._get_backend(device).supports_coalescing:
         # NCCL style coalescing
         with _coalescing_manager(group, device, async_ops=True) as cm:
             for p2p_op in p2p_op_list:
diff --git a/torch/distributed/elastic/rendezvous/etcd_store.py b/torch/distributed/elastic/rendezvous/etcd_store.py
index 676303216f11..781a40e20e91 100644
--- a/torch/distributed/elastic/rendezvous/etcd_store.py
+++ b/torch/distributed/elastic/rendezvous/etcd_store.py
@@ -149,9 +149,9 @@ class EtcdStore(Store):
     # In case of `str`, utf-8 encoding is assumed.
     #
     def _encode(self, value) -> str:
-        if type(value) == bytes:
+        if type(value) is bytes:
             return b64encode(value).decode()
-        elif type(value) == str:
+        elif type(value) is str:
             return b64encode(value.encode()).decode()
 
         raise ValueError("Value must be of type str or bytes")
@@ -160,9 +160,9 @@ class EtcdStore(Store):
     # Return type is `bytes`, which is more convenient with the Store interface.
     #
     def _decode(self, value) -> bytes:
-        if type(value) == bytes:
+        if type(value) is bytes:
             return b64decode(value)
-        elif type(value) == str:
+        elif type(value) is str:
             return b64decode(value.encode())
 
         raise ValueError("Value must be of type str or bytes")
diff --git a/torch/distributed/tensor/examples/convnext_example.py b/torch/distributed/tensor/examples/convnext_example.py
index d81429a49dc8..c1bd542922af 100644
--- a/torch/distributed/tensor/examples/convnext_example.py
+++ b/torch/distributed/tensor/examples/convnext_example.py
@@ -110,7 +110,7 @@ class DownSampling(nn.Module):
 
 @torch.no_grad()
 def init_weights(m):
-    if type(m) == nn.Conv2d or type(m) == nn.Linear:
+    if type(m) is nn.Conv2d or type(m) is nn.Linear:
         nn.init.ones_(m.weight)
         if m.bias is not None:
             nn.init.zeros_(m.bias)
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index 19d7befe6afd..ca82802bcc85 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -280,7 +280,7 @@ def _kl_exponential_exponential(p, q):
 
 @register_kl(ExponentialFamily, ExponentialFamily)
 def _kl_expfamily_expfamily(p, q):
-    if not type(p) == type(q):
+    if type(p) is not type(q):
         raise NotImplementedError(
             "The cross KL-divergence between different exponential families cannot \
             be computed using Bregman divergences"
diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py
index 375d059d64cb..e362e8334241 100644
--- a/torch/export/dynamic_shapes.py
+++ b/torch/export/dynamic_shapes.py
@@ -872,7 +872,7 @@ class AdditionalInputs:
         ]
 
         def _mark_dynamism(v, *other_vs):
-            if not all(type(v) == type(other) for other in other_vs):
+            if not all(type(v) is type(other) for other in other_vs):
                 raise ValueError(
                     "The following inputs were found to have differing types, "
                     f"so they cannot be marked as dynamic: {(v,) + other_vs}."
diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py
index 6e67fa56d168..3905266e24cd 100644
--- a/torch/fx/_symbolic_trace.py
+++ b/torch/fx/_symbolic_trace.py
@@ -927,7 +927,7 @@ class Tracer(TracerBase):
                 return out
 
             # Union[int, bool] == bool in Python <= 3.6
-            if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor:
+            if type(x) is bool or type(x) in base_types and type(x) != torch.Tensor:
                 torch._assert(
                     out == x,
                     f"{name} has been specialized to have value {x} but got another value",
diff --git a/torch/fx/experimental/unification/more.py b/torch/fx/experimental/unification/more.py
index 8a00065dc0c3..f1df562a2dcd 100644
--- a/torch/fx/experimental/unification/more.py
+++ b/torch/fx/experimental/unification/more.py
@@ -112,7 +112,7 @@ def unify_object(u, v, s):
     >>> unify_object(f, g, {})
     {~x: 2}
     """
-    if type(u) != type(v):
+    if type(u) is not type(v):
         return False
     if hasattr(u, "__slots__"):
         return unify(
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index 4f709328a913..86b72d1d4656 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -311,7 +311,7 @@ class ScriptMeta(type):
                 original_init(self, *args, **kwargs)
                 added_methods_in_init = len(cls._methods) > num_methods
 
-                if type(self) == cls:
+                if type(self) is cls:
 
                     def make_stubs(module):
                         cls = type(module)
@@ -804,7 +804,7 @@ if _enabled:
 
         @property
         def original_name(self):
-            if type(self) == str(self._c._type().name()):
+            if type(self) is str(self._c._type().name()):
                 return ""
             return str(self._c._type().name())
 
diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py
index a26f4dc1bfb3..5b1713e77d36 100644
--- a/torch/jit/_trace.py
+++ b/torch/jit/_trace.py
@@ -651,7 +651,7 @@ def analyze_ts_result_with_export_result(export, trace):
         # mkldnn is not supported for torch.allclose
         if orig.layout == torch._mkldnn:  # type: ignore[attr-defined]
             return True
-        if type(orig) != type(loaded):
+        if type(orig) is not type(loaded):
             return False
 
         if isinstance(orig, torch._subclasses.FakeTensor):
diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py
index 11e804d97854..0b3fa9b858fe 100644
--- a/torch/masked/maskedtensor/core.py
+++ b/torch/masked/maskedtensor/core.py
@@ -198,7 +198,7 @@ class MaskedTensor(torch.Tensor):
     def _validate_members(self):
         data = self._masked_data
         mask = self.get_mask()
-        if type(data) != type(mask):
+        if type(data) is not type(mask):
             raise TypeError(
                 f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
             )
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 118d75f3b8d6..084e98217819 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -1049,7 +1049,7 @@ class Module:
             >>> @torch.no_grad()
             >>> def init_weights(m):
             >>>     print(m)
-            >>>     if type(m) == nn.Linear:
+            >>>     if type(m) is nn.Linear:
             >>>         m.weight.fill_(1.0)
             >>>         print(m.weight)
             >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
diff --git a/torch/package/analyze/is_from_package.py b/torch/package/analyze/is_from_package.py
index 82ff5896b6ff..800f87eb4867 100644
--- a/torch/package/analyze/is_from_package.py
+++ b/torch/package/analyze/is_from_package.py
@@ -10,7 +10,7 @@ def is_from_package(obj: Any) -> bool:
 
     Note: packaged objects from externed modules will return ``False``.
     """
-    if type(obj) == ModuleType:
+    if type(obj) is ModuleType:
         return is_mangled(obj.__name__)
     else:
         return is_mangled(type(obj).__module__)
diff --git a/torch/storage.py b/torch/storage.py
index 5fc60055cd71..e6971f52cf22 100644
--- a/torch/storage.py
+++ b/torch/storage.py
@@ -891,7 +891,7 @@ class TypedStorage:
     def _new_wrapped_storage(self, untyped_storage) -> Self:
         assert type(untyped_storage) == torch.UntypedStorage
 
-        if type(self) == TypedStorage:
+        if type(self) is TypedStorage:
             return cast(
                 Self,
                 TypedStorage(
@@ -913,7 +913,7 @@ class TypedStorage:
                 return 0
 
         else:
-            if type(idx) != int:
+            if type(idx) is not int:
                 raise TypeError(f"can't index a {type(self)} with {type(idx)}")
             if is_stop:
                 if (idx > self._size()) or (idx < -self._size()):
@@ -1513,7 +1513,7 @@ class _LegacyStorageMeta(type):
     dtype: torch.dtype
 
     def __instancecheck__(cls, instance):
-        if type(instance) == TypedStorage:
+        if type(instance) is TypedStorage:
             cls_device = _get_device_from_module(cls.__module__)
             return (cls_device == instance.device.type) and (
                 cls.dtype == instance.dtype
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 3ae1b81ef16e..17a317463cb5 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -1432,7 +1432,7 @@ class MultiThreadedTestCase(TestCase):
                 logger.error("Caught exception: \n%s exiting thread %s", msg, rank)
                 error_msg += f"Thread {rank} exited with exception:\n{msg}\n"
             elif isinstance(exc, SystemExit):
-                if type(exc.code) == int and skip_code < 0:
+                if type(exc.code) is int and skip_code < 0:
                     skip_code = exc.code
 
         # check exceptions
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index f0087854534c..fde4f396b2b9 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -1247,7 +1247,7 @@ class QuantizationTestCase(TestCase):
        }
        """
        # TODO: make img_data a single example instead of a list
-        if type(inputs) == list:
+        if type(inputs) is list:
            inputs = inputs[0]
 
        if quant_type == QuantType.QAT:
diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py
index c5d08073803b..4ff16b343715 100644
--- a/torch/testing/_internal/opinfo/definitions/_masked.py
+++ b/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -365,7 +365,7 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs)
         for mask in _generate_masked_op_mask(
             sample_input.input.shape, device, **kwargs
         ):
-            if type(mask) != torch.Tensor:
+            if type(mask) is not torch.Tensor:
                 continue
             sample_input_args, sample_input_kwargs = (
                 sample_input.args,
diff --git a/torch/testing/_internal/subclasses.py b/torch/testing/_internal/subclasses.py
index 0898c288d926..228f98139fea 100644
--- a/torch/testing/_internal/subclasses.py
+++ b/torch/testing/_internal/subclasses.py
@@ -70,7 +70,7 @@ class WrapperSubclass(torch.Tensor):
     def __coerce_same_metadata_as_tangent__(
         self, expected_metadata: Any, expected_type: Optional[type] = None
     ):
-        if expected_type == type(self.a):
+        if expected_type is type(self.a):
             return self.a
         elif expected_type is TwoTensor:
             return TwoTensor(self.a, self.a.clone())
diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py
index 1577516b3f6d..31ae14919315 100644
--- a/torch/utils/tensorboard/_pytorch_graph.py
+++ b/torch/utils/tensorboard/_pytorch_graph.py
@@ -187,7 +187,7 @@ class GraphPy:
             )
 
         for key, node in self.nodes_io.items():
-            if type(node) == NodeBase:
+            if type(node) is NodeBase:
                 # pyrefly: ignore  # unsupported-operation
                 self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
             if hasattr(node, "input_or_output"):
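
For context, a minimal standalone sketch (illustrative only; the `SurprisingMeta` and `Surprising` names below are invented and nothing here is part of the patch) of why `type(x) is T` is the safer spelling for an exact type check than `type(x) == T`: `==` goes through `__eq__` dispatch, which a metaclass can override, while `is` compares object identity and is the form pycodestyle's E721 asks for.

```python
# Illustrative sketch, not part of the patch above.
class SurprisingMeta(type):
    # A contrived metaclass whose equality check always succeeds.
    def __eq__(cls, other):
        return True

    def __hash__(cls):
        return id(cls)


class Surprising(metaclass=SurprisingMeta):
    pass


x = 3
print(type(x) == Surprising)  # True  -- __eq__ dispatch yields a false positive
print(type(x) is Surprising)  # False -- identity check is exact
print(type(x) is int)         # True
```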