From 8de85896e05df8f992e09a302eac5cca9b2038a9 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Mon, 13 Oct 2025 01:48:55 +0000 Subject: [PATCH] Enable ruff rule E721 (#165162) `E721` checks for object type comparisons using == and other comparison operators. This is useful because it is recommended to use `is` for type comparisons. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165162 Approved by: https://github.com/Skylion007 --- .../torchaudio_models.py | 2 +- benchmarks/instruction_counts/core/api.py | 2 +- .../operator_benchmark/benchmark_pytorch.py | 2 +- ...i_operator_benchmark_eager_float32_cpu.csv | 2 +- benchmarks/operator_benchmark/pt/cat_test.py | 2 +- .../operator_benchmark/pt/stack_test.py | 2 +- pyproject.toml | 1 - .../ao/sparsity/test_activation_sparsifier.py | 2 +- test/ao/sparsity/test_data_sparsifier.py | 30 ++++++++++--------- test/ao/sparsity/test_sparsifier.py | 4 +-- .../ao/sparsity/test_structured_sparsifier.py | 2 +- .../checkpoint/test_state_dict_stager.py | 6 ++-- test/distributed/fsdp/test_fsdp_apply.py | 4 +-- test/distributed/fsdp/test_fsdp_misc.py | 2 +- .../distributed/fsdp/test_fsdp_optim_state.py | 2 +- test/distributions/test_distributions.py | 2 +- test/dynamo/test_misc.py | 4 +-- test/dynamo/test_sources.py | 2 +- test/dynamo/test_subclasses.py | 2 +- test/export/opinfo_schema.py | 2 +- test/export/test_nativert.py | 4 +-- test/export/test_serialize.py | 2 +- test/functorch/test_aotdispatch.py | 2 +- test/functorch/test_control_flow.py | 2 +- test/fx/test_fx_split.py | 2 +- test/fx/test_subgraph_rewriter.py | 4 +-- test/inductor/test_binary_folding.py | 8 ++--- test/inductor/test_cache.py | 10 +++---- test/inductor/test_cutlass_backend.py | 2 +- test/inductor/test_efficient_conv_bn_eval.py | 6 ++-- test/inductor/test_torchinductor.py | 4 +-- test/inductor/test_utils.py | 2 +- test/jit/test_freezing.py | 28 ++++++++--------- test/jit/test_typing.py | 2 +- test/nn/test_convolution.py | 4 +-- test/nn/test_load_state_dict.py | 4 +-- test/quantization/core/test_quantized_op.py | 2 +- .../quantization/core/test_workflow_module.py | 4 +-- test/quantization/core/test_workflow_ops.py | 6 ++-- .../eager/test_quantize_eager_qat.py | 6 ++-- test/quantization/fx/test_model_report_fx.py | 2 +- test/quantization/fx/test_quantize_fx.py | 4 +-- .../quantization/fx/test_subgraph_rewriter.py | 4 +-- .../pt2e/test_x86inductor_quantizer.py | 2 +- test/test_binary_ufuncs.py | 8 ++--- test/test_datapipe.py | 6 ++-- test/test_decomp.py | 4 +-- test/test_jit.py | 12 ++++---- test/test_multiprocessing.py | 4 +-- test/test_numpy_interop.py | 2 +- test/test_reductions.py | 2 +- test/test_type_promotion.py | 4 +-- .../torch_np/numpy_tests/core/test_numeric.py | 2 +- .../numpy_tests/core/test_scalarmath.py | 8 ++--- .../numpy_tests/linalg/test_linalg.py | 8 ++--- test/torch_np/test_ndarray_methods.py | 5 ++-- test/torch_np/test_nep50_examples.py | 2 +- tools/experimental/torchfuzz/tensor_fuzzer.py | 2 +- torch/_decomp/decompositions.py | 2 +- torch/_dynamo/codegen.py | 2 +- torch/_dynamo/guards.py | 2 +- torch/_dynamo/variables/tensor.py | 2 +- torch/_export/serde/schema_check.py | 6 ++-- torch/_higher_order_ops/partitioner.py | 2 +- torch/_inductor/codegen/cpp.py | 2 +- torch/_inductor/fuzzer.py | 10 +++---- torch/_logging/_internal.py | 2 +- torch/_numpy/_reductions_impl.py | 2 +- torch/_refs/__init__.py | 2 +- torch/_utils.py | 2 +- torch/ao/ns/fx/utils.py | 2 +- .../fx/_lower_to_native_backend.py | 10 +++---- .../_model_report/model_report_visualizer.py | 2 +- torch/ao/quantization/fx/utils.py | 6 ++-- .../fsdp/fully_sharded_data_parallel.py | 4 +-- .../experimental/graph_gradual_typechecker.py | 4 +-- torch/fx/passes/reinplace.py | 2 +- torch/utils/data/datapipes/_typing.py | 2 +- 78 files changed, 166 insertions(+), 164 deletions(-) diff --git a/benchmarks/functional_autograd_benchmark/torchaudio_models.py b/benchmarks/functional_autograd_benchmark/torchaudio_models.py index 19fa23e55413..5a26616cb507 100644 --- a/benchmarks/functional_autograd_benchmark/torchaudio_models.py +++ b/benchmarks/functional_autograd_benchmark/torchaudio_models.py @@ -367,7 +367,7 @@ class DeepSpeech(nn.Module): """ seq_len = input_length for m in self.conv.modules(): - if type(m) == nn.modules.conv.Conv2d: + if type(m) is nn.modules.conv.Conv2d: seq_len = ( seq_len + 2 * m.padding[1] diff --git a/benchmarks/instruction_counts/core/api.py b/benchmarks/instruction_counts/core/api.py index 7d0b1a0f72ea..d22fc5a66fab 100644 --- a/benchmarks/instruction_counts/core/api.py +++ b/benchmarks/instruction_counts/core/api.py @@ -66,7 +66,7 @@ class GroupedSetup: def __post_init__(self) -> None: for field in dataclasses.fields(self): - assert field.type == str + assert field.type is str value: str = getattr(self, field.name) object.__setattr__(self, field.name, textwrap.dedent(value)) diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py index cfed9ebac04b..fa022417da45 100644 --- a/benchmarks/operator_benchmark/benchmark_pytorch.py +++ b/benchmarks/operator_benchmark/benchmark_pytorch.py @@ -113,7 +113,7 @@ class TorchBenchmarkBase(torch.nn.Module): value = kargs[key] test_name_str.append( ("" if key in skip_key_list else key) - + str(value if type(value) != bool else int(value)) + + str(value if type(value) is not bool else int(value)) ) name = (self.module_name() + "_" + "_".join(test_name_str)).replace(" ", "") return name diff --git a/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv b/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv index 9a7b6797e982..3c5a090376ed 100644 --- a/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv +++ b/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv @@ -1158,7 +1158,7 @@ PyTorch,q_argsort,q_argsort_M512_N512_dtypetorch.quint8,short,FALSE,446.4263 PyTorch,q_clone,q_clone_M512_N512_dtypetorch.quint8,short,FALSE,10.9374 PyTorch,q_mean,q_mean_M512_N512_dtypetorch.quint8,short,FALSE,10.2288 PyTorch,q_relu,q_relu_M512_N512_dtypetorch.quint8,short,FALSE,10.3366 -PyTorch,q_relu_,q_relu__M512_N512_dtypetorch.quint8,short,FALSE,25.3594 +PyTorch,q_relu_,q_relu__M512_N512_dtypetorch.quint8,short,FALSE,7.9869 PyTorch,q_sort,q_sort_M512_N512_dtypetorch.quint8,short,FALSE,447.1303 PyTorch,qtopk,qtopk_M512_N512_k5_dtypetorch.quint8,short,FALSE,64.856 PyTorch,abs,abs_M512_N512_cpu,short,FALSE,12.3046 diff --git a/benchmarks/operator_benchmark/pt/cat_test.py b/benchmarks/operator_benchmark/pt/cat_test.py index c0dc08593a9c..cf0369a43345 100644 --- a/benchmarks/operator_benchmark/pt/cat_test.py +++ b/benchmarks/operator_benchmark/pt/cat_test.py @@ -125,7 +125,7 @@ class CatBenchmark(op_bench.TorchBenchmarkBase): random.seed(42) inputs = [] gen_sizes = [] - if type(sizes) == list and N == -1: + if type(sizes) is list and N == -1: gen_sizes = sizes else: for i in range(N): diff --git a/benchmarks/operator_benchmark/pt/stack_test.py b/benchmarks/operator_benchmark/pt/stack_test.py index 9e1e25be1f4e..5dea1d9ca1ef 100644 --- a/benchmarks/operator_benchmark/pt/stack_test.py +++ b/benchmarks/operator_benchmark/pt/stack_test.py @@ -61,7 +61,7 @@ class StackBenchmark(op_bench.TorchBenchmarkBase): random.seed(42) inputs = [] gen_sizes = [] - if type(sizes) == list and N == -1: + if type(sizes) is list and N == -1: gen_sizes = sizes else: for i in range(N): diff --git a/pyproject.toml b/pyproject.toml index 8a2823258916..f75261ba6ffb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,7 +155,6 @@ ignore = [ "E402", "C408", # C408 ignored because we like the dict keyword argument syntax "E501", # E501 is not flexible enough, we're using B950 instead - "E721", "E741", "EXE001", "F405", diff --git a/test/ao/sparsity/test_activation_sparsifier.py b/test/ao/sparsity/test_activation_sparsifier.py index 923ffa16fa02..122c368368e6 100644 --- a/test/ao/sparsity/test_activation_sparsifier.py +++ b/test/ao/sparsity/test_activation_sparsifier.py @@ -243,7 +243,7 @@ class TestActivationSparsifier(TestCase): if mask1 is None: assert mask2 is None else: - assert type(mask1) == type(mask2) + assert type(mask1) is type(mask2) if isinstance(mask1, list): assert len(mask1) == len(mask2) for idx in range(len(mask1)): diff --git a/test/ao/sparsity/test_data_sparsifier.py b/test/ao/sparsity/test_data_sparsifier.py index c333138769a4..dce04292763f 100644 --- a/test/ao/sparsity/test_data_sparsifier.py +++ b/test/ao/sparsity/test_data_sparsifier.py @@ -710,15 +710,15 @@ class TestQuantizationUtils(TestCase): **sparse_config, ) - assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding + assert type(model.emb1) is torch.ao.nn.quantized.modules.embedding_ops.Embedding assert ( type(model.embbag1) - == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag + is torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag ) - assert type(model.emb_seq[0] == nn.Embedding) - assert type(model.emb_seq[1] == nn.EmbeddingBag) - assert type(model.linear1) == nn.Linear - assert type(model.linear2) == nn.Linear + assert type(model.emb_seq[0] is nn.Embedding) + assert type(model.emb_seq[1] is nn.EmbeddingBag) + assert type(model.linear1) is nn.Linear + assert type(model.linear2) is nn.Linear dequant_emb1 = torch.dequantize(model.emb1.weight()) dequant_embbag1 = torch.dequantize(model.embbag1.weight()) @@ -749,19 +749,21 @@ class TestQuantizationUtils(TestCase): model, DataNormSparsifier, sparsify_first=False, **sparse_config ) - assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding + assert type(model.emb1) is torch.ao.nn.quantized.modules.embedding_ops.Embedding assert ( type(model.embbag1) - == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag + is torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag ) - assert type( - model.emb_seq[0] == torch.ao.nn.quantized.modules.embedding_ops.Embedding + assert ( + type(model.emb_seq[0]) + is torch.ao.nn.quantized.modules.embedding_ops.Embedding ) - assert type( - model.emb_seq[1] == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag + assert ( + type(model.emb_seq[1]) + is torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag ) - assert type(model.linear1) == nn.Linear # not quantized - assert type(model.linear2) == nn.Linear # not quantized + assert type(model.linear1) is nn.Linear # not quantized + assert type(model.linear2) is nn.Linear # not quantized dequant_emb1 = torch.dequantize(model.emb1.weight()) dequant_embbag1 = torch.dequantize(model.embbag1.weight()) diff --git a/test/ao/sparsity/test_sparsifier.py b/test/ao/sparsity/test_sparsifier.py index 86e26e5ca11e..d5010b7abccd 100644 --- a/test/ao/sparsity/test_sparsifier.py +++ b/test/ao/sparsity/test_sparsifier.py @@ -291,7 +291,7 @@ class TestWeightNormSparsifier(TestCase): assert hasattr(module.parametrizations["weight"][0], "mask") # Check parametrization exists and is correct assert is_parametrized(module, "weight") - assert type(module.parametrizations.weight[0]) == FakeSparsity + assert type(module.parametrizations.weight[0]) is FakeSparsity def test_mask_squash(self): model = SimpleLinear() @@ -415,7 +415,7 @@ class TestNearlyDiagonalSparsifier(TestCase): assert hasattr(module.parametrizations["weight"][0], "mask") # Check parametrization exists and is correct assert is_parametrized(module, "weight") - assert type(module.parametrizations.weight[0]) == FakeSparsity + assert type(module.parametrizations.weight[0]) is FakeSparsity def test_mask_squash(self): model = SimpleLinear() diff --git a/test/ao/sparsity/test_structured_sparsifier.py b/test/ao/sparsity/test_structured_sparsifier.py index 812490452767..4ed9bea7d0f7 100644 --- a/test/ao/sparsity/test_structured_sparsifier.py +++ b/test/ao/sparsity/test_structured_sparsifier.py @@ -158,7 +158,7 @@ class TestBaseStructuredSparsifier(TestCase): assert parametrize.is_parametrized(module) assert hasattr(module, "parametrizations") # Assume that this is the 1st/only parametrization - assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity + assert type(module.parametrizations.weight[0]) is FakeStructuredSparsity def _check_pruner_valid_before_step(self, model, pruner, device): for config in pruner.groups: diff --git a/test/distributed/checkpoint/test_state_dict_stager.py b/test/distributed/checkpoint/test_state_dict_stager.py index a08a8f5eec90..22cb2f32cf4a 100644 --- a/test/distributed/checkpoint/test_state_dict_stager.py +++ b/test/distributed/checkpoint/test_state_dict_stager.py @@ -134,7 +134,7 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8): False, f"Collection length mismatch at {path}: {len(gpu_obj)} vs {len(cpu_obj)}", ) - if type(gpu_obj) != type(cpu_obj): + if type(gpu_obj) is not type(cpu_obj): return ( False, f"Collection type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}", @@ -149,7 +149,7 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8): # If objects are custom classes, compare their attributes elif hasattr(gpu_obj, "__dict__") and hasattr(cpu_obj, "__dict__"): - if type(gpu_obj) != type(cpu_obj): + if type(gpu_obj) is not type(cpu_obj): return ( False, f"Object type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}", @@ -165,7 +165,7 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8): # For other types, use direct equality comparison else: - if type(gpu_obj) != type(cpu_obj): + if type(gpu_obj) is not type(cpu_obj): return ( False, f"Type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}", diff --git a/test/distributed/fsdp/test_fsdp_apply.py b/test/distributed/fsdp/test_fsdp_apply.py index d56ac09ebe5a..c0f1a791c534 100644 --- a/test/distributed/fsdp/test_fsdp_apply.py +++ b/test/distributed/fsdp/test_fsdp_apply.py @@ -44,14 +44,14 @@ class TestApply(FSDPTest): @torch.no_grad() def _init_linear_weights(self, m): - if type(m) == nn.Linear: + if type(m) is nn.Linear: m.weight.fill_(1.0) m.bias.fill_(1.0) def check_weights(self, fsdp, expected_tensor_fn, check): with FSDP.summon_full_params(fsdp, recurse=True): linear_modules = [ - module for module in fsdp.modules() if type(module) == nn.Linear + module for module in fsdp.modules() if type(module) is nn.Linear ] for module in linear_modules: for param in module.parameters(): diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index 45c1668dfb2e..2ae986af785b 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -1021,7 +1021,7 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread): ) for warning in w: self.assertTrue( - warning.category != UserWarning + warning.category is not UserWarning or not str(warning.message).startswith(warning_prefix) ) diff --git a/test/distributed/fsdp/test_fsdp_optim_state.py b/test/distributed/fsdp/test_fsdp_optim_state.py index 4db192ed7c34..99e5db33d67d 100644 --- a/test/distributed/fsdp/test_fsdp_optim_state.py +++ b/test/distributed/fsdp/test_fsdp_optim_state.py @@ -421,7 +421,7 @@ class TestFSDPOptimState(FSDPTest): return False for state_name, value1 in state1.items(): value2 = state2[state_name] - if type(value1) != type(value2): + if type(value1) is not type(value2): return False if torch.is_tensor(value1): # tensor state assert torch.is_tensor(value2) diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index aaae775f191c..b588589d81ba 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -5887,7 +5887,7 @@ class TestKL(DistributionsTestCase): def test_kl_exponential_family(self): for (p, _), (_, q) in self.finite_examples: - if type(p) == type(q) and issubclass(type(p), ExponentialFamily): + if type(p) is type(q) and issubclass(type(p), ExponentialFamily): actual = kl_divergence(p, q) expected = _kl_expfamily_expfamily(p, q) self.assertEqual( diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index c625db6bf2d6..a41d5851a8ed 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -3370,9 +3370,9 @@ utils_device.CURRENT_DEVICE == None""".split("\n"): # Test on non autocast state and autocast cache states. self.assertIn("autocast_state", json_guards) for key, value in json_guards.items(): - if type(value) == int: + if type(value) is int: variant = value + 1 - elif type(value) == bool: + elif type(value) is bool: variant = not value elif isinstance(value, dict) and key == "autocast_state": variant = value.copy() diff --git a/test/dynamo/test_sources.py b/test/dynamo/test_sources.py index 5b16e00270b0..a2f91afc93b7 100644 --- a/test/dynamo/test_sources.py +++ b/test/dynamo/test_sources.py @@ -59,7 +59,7 @@ class SourceTests(torch._dynamo.test_case.TestCase): def forward(self): if ( torch.utils._pytree.SUPPORTED_NODES[CausalLMOutputWithPast].type - == int + is int ): x = torch.sin(self.x) else: diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index ec67ef5eb8c3..0242badeb99e 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -662,7 +662,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase): "comparison", [ subtest(isinstance, "isinstance"), - subtest(lambda instance, type_: type(instance) == type_, "equality"), + subtest(lambda instance, type_: type(instance) is type_, "equality"), subtest(lambda instance, type_: type(instance) is type_, "identity"), ], ) diff --git a/test/export/opinfo_schema.py b/test/export/opinfo_schema.py index 837213659847..292d06fc04d8 100644 --- a/test/export/opinfo_schema.py +++ b/test/export/opinfo_schema.py @@ -38,7 +38,7 @@ class PreDispatchSchemaCheckMode(SchemaCheckMode): def _may_alias_or_mutate(self, func, types, args, kwargs): def unwrap(e): - if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor: + if isinstance(e, torch.Tensor) and type(e) is not torch.Tensor: try: return e.elem except AttributeError: diff --git a/test/export/test_nativert.py b/test/export/test_nativert.py index 20c5d1ca562c..20f61ad03fff 100644 --- a/test/export/test_nativert.py +++ b/test/export/test_nativert.py @@ -128,7 +128,7 @@ def run_with_nativert(ep): flat_results = pytree.tree_leaves(results) assert len(flat_results) == len(flat_expected) for result, expected in zip(flat_results, flat_expected): - assert type(result) == type(expected) + assert type(result) is type(expected) if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor): assert result.shape == expected.shape assert result.dtype == expected.dtype @@ -323,7 +323,7 @@ class TestNativeRT(TestCase): flat_results = pytree.tree_leaves(results) assert len(flat_results) == len(flat_expected) for result, expected in zip(flat_results, flat_expected): - assert type(result) == type(expected) + assert type(result) is type(expected) if isinstance(result, torch.Tensor) and isinstance( expected, torch.Tensor ): diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 275e699cb6b3..0e1eb0140bbb 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -82,7 +82,7 @@ class TestSerialize(TestCase): return 0 def __eq__(self, other): - return type(other) == type(self) + return type(other) is type(self) def __call__(self, *args, **kwargs): return torch.ops.aten.add.Tensor(*args, **kwargs) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 41b37a687fae..404279b5c4dd 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -6332,7 +6332,7 @@ def forward(self, tangents_1, tangents_2): self.assertEqual(out_ref[0].b, out_test[0].b) self.assertEqual(out_ref[1], out_test[1]) - # We compiled our graph assuming type(grad_out[1]) == torch.Tensor, + # We compiled our graph assuming type(grad_out[1]) is torch.Tensor, # but we were wrong: in the below tests, it is a subclass. # This will eventually require a repartition + recompile with self.assertRaisesRegex( diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 310f7f4c79de..47e4481ef6af 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -3671,7 +3671,7 @@ class AssociativeScanModels: # Check if val is a list and if it has the same length as combine_fn # If so, then use the individual elements. # If not, duplicate the first element. - if type(val) == list and len(val) == chain_len: + if type(val) is list and len(val) == chain_len: kwargs_el[key] = val[ind] else: kwargs_el[key] = val diff --git a/test/fx/test_fx_split.py b/test/fx/test_fx_split.py index 7338dd0314a1..8d2b120e534a 100644 --- a/test/fx/test_fx_split.py +++ b/test/fx/test_fx_split.py @@ -296,7 +296,7 @@ class TestSplitOutputType(TestCase): gm_output = module(inputs) split_gm_output = split_gm(inputs) - self.assertTrue(type(gm_output) == type(split_gm_output)) + self.assertTrue(type(gm_output) is type(split_gm_output)) self.assertTrue(torch.equal(gm_output, split_gm_output)) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 3f5455f0748a..0ee60f978127 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -514,8 +514,8 @@ class TestSubgraphRewriter(JitTestCase): symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) for n, m in zip(symbolic_traced.graph.nodes, graph.nodes): if n.op == "placeholder": - assert n.type == int - assert m.type == int + assert n.type is int + assert m.type is int def test_subgraph_rewriter_replace_consecutive_submodules(self): def f(x): diff --git a/test/inductor/test_binary_folding.py b/test/inductor/test_binary_folding.py index cac7586e8d35..746a2808c901 100644 --- a/test/inductor/test_binary_folding.py +++ b/test/inductor/test_binary_folding.py @@ -81,9 +81,9 @@ class BinaryFoldingTemplate(TestCase): out_optimized = torch.compile(mod_eager) inps = [4, 3, 4] - if module == nn.Conv2d: + if module is nn.Conv2d: inps.append(inps[-1]) - if module == nn.Conv3d: + if module is nn.Conv3d: inps.append(inps[-1]) inps.append(inps[-1]) @@ -195,9 +195,9 @@ class BinaryFoldingTemplate(TestCase): ) inps = [4, 3, 4] - if module[0] == nn.Conv2d: + if module[0] is nn.Conv2d: inps.append(inps[-1]) - if module[0] == nn.Conv3d: + if module[0] is nn.Conv3d: inps.append(inps[-1]) inps.append(inps[-1]) diff --git a/test/inductor/test_cache.py b/test/inductor/test_cache.py index 3ff7d3593506..d7ac4df3bf07 100644 --- a/test/inductor/test_cache.py +++ b/test/inductor/test_cache.py @@ -106,9 +106,9 @@ class TestMixin: return keys def key(self: Self, key_type: type[icache.Key]) -> icache.Key: - if key_type == str: + if key_type is str: return f"s{randint(0, 2**32)}" - elif key_type == int: + elif key_type is int: return randint(0, 2**32) elif key_type == tuple[Any, ...]: return (self.key(str), self.key(int)) @@ -125,13 +125,13 @@ class TestMixin: return values def value(self: Self, value_type: type[icache.Value]) -> icache.Value: - if value_type == str: + if value_type is str: return f"s{randint(0, 2**32)}" - elif value_type == int: + elif value_type is int: return randint(0, 2**32) elif value_type == tuple[Any, ...]: return (self.value(str), self.value(int)) - elif value_type == bytes: + elif value_type is bytes: return self.value(str).encode() elif value_type == dict[Any, Any]: return { diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index 97b1ee2f1bc0..55f8dd5d24eb 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -88,7 +88,7 @@ def _check_if_instances_equal(op1, op2) -> bool: if isinstance(op1, (list | tuple)): return tuple(op1) == tuple(op2) - if type(op1) != type(op2): + if type(op1) is not type(op2): return False # some classes have __eq__ defined but they may be insufficient diff --git a/test/inductor/test_efficient_conv_bn_eval.py b/test/inductor/test_efficient_conv_bn_eval.py index 2bcd333cbf2a..86b6b6ac8a0d 100644 --- a/test/inductor/test_efficient_conv_bn_eval.py +++ b/test/inductor/test_efficient_conv_bn_eval.py @@ -127,11 +127,11 @@ class EfficientConvBNEvalTemplate(TestCase): spatial_d = ( 4 if issubclass(module[0], nn.modules.conv._ConvTransposeNd) else 96 ) - if module[0] == nn.Conv1d or module[0] == nn.ConvTranspose1d: + if module[0] is nn.Conv1d or module[0] is nn.ConvTranspose1d: inps += [spatial_d] * 1 - if module[0] == nn.Conv2d or module[0] == nn.ConvTranspose2d: + if module[0] is nn.Conv2d or module[0] is nn.ConvTranspose2d: inps += [spatial_d] * 2 - if module[0] == nn.Conv3d or module[0] == nn.ConvTranspose3d: + if module[0] is nn.Conv3d or module[0] is nn.ConvTranspose3d: inps += [spatial_d] * 3 inp = torch.rand(inps).to(self.device) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index e3c551213277..2b742d92ee4c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -514,11 +514,11 @@ def check_model( # print("Graph", graph) if check_has_compiled: assert called, "Ran graph without calling compile_fx" - assert type(actual) == type(correct) + assert type(actual) is type(correct) if isinstance(actual, (tuple, list)): assert len(actual) == len(correct) assert all( - type(actual_item) == type(correct_item) + type(actual_item) is type(correct_item) for actual_item, correct_item in zip(actual, correct) ) diff --git a/test/inductor/test_utils.py b/test/inductor/test_utils.py index 349160a1e6c6..7d23457732a1 100644 --- a/test/inductor/test_utils.py +++ b/test/inductor/test_utils.py @@ -198,7 +198,7 @@ class TestUtils(TestCase): @dtypes(torch.float16, torch.bfloat16, torch.float32) def test_get_device_tflops(self, dtype): ret = get_device_tflops(dtype) - self.assertTrue(type(ret) == float) + self.assertTrue(type(ret) is float) instantiate_device_type_tests(TestUtils, globals()) diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index 8258124680b4..ca1172a2ce7e 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -2083,9 +2083,9 @@ class TestFrozenOptimizations(JitTestCase): mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval() inps = [4, 3, 4] - if modules[0] == nn.Conv2d: + if modules[0] is nn.Conv2d: inps.append(inps[-1]) - if modules[0] == nn.Conv3d: + if modules[0] is nn.Conv3d: inps.append(inps[-1]) inps.append(inps[-1]) @@ -2224,9 +2224,9 @@ class TestFrozenOptimizations(JitTestCase): mod_eager = ConvOp(3, 32, kernel_size=3, stride=2).eval() inps = [4, 3, 4] - if module == nn.Conv2d: + if module is nn.Conv2d: inps.append(inps[-1]) - if module == nn.Conv3d: + if module is nn.Conv3d: inps.append(inps[-1]) inps.append(inps[-1]) @@ -2366,10 +2366,10 @@ class TestFrozenOptimizations(JitTestCase): mod_eager = LinearBN(32, 32).eval() inps = [3, 32] - if modules[1] == nn.BatchNorm2d: + if modules[1] is nn.BatchNorm2d: inps.append(inps[-1]) inps.append(inps[-1]) - if modules[1] == nn.BatchNorm3d: + if modules[1] is nn.BatchNorm3d: inps.append(inps[-1]) inps.append(inps[-1]) inps.append(inps[-1]) @@ -2429,14 +2429,14 @@ class TestFrozenOptimizations(JitTestCase): N, C = 3, bn_in input_shape = [N, C] - if modules[1] == nn.BatchNorm1d: + if modules[1] is nn.BatchNorm1d: H = linear_in input_shape.append(H) - elif modules[1] == nn.BatchNorm2d: + elif modules[1] is nn.BatchNorm2d: H, W = 4, linear_in input_shape.append(H) input_shape.append(W) - elif modules[1] == nn.BatchNorm3d: + elif modules[1] is nn.BatchNorm3d: D, H, W = 4, 4, linear_in input_shape.append(D) input_shape.append(H) @@ -2504,10 +2504,10 @@ class TestFrozenOptimizations(JitTestCase): mod_eager = LinearBN(32, 32).cuda().eval() inps = [3, 32] - if modules[1] == nn.BatchNorm2d: + if modules[1] is nn.BatchNorm2d: inps.append(inps[-1]) inps.append(inps[-1]) - if modules[1] == nn.BatchNorm3d: + if modules[1] is nn.BatchNorm3d: inps.append(inps[-1]) inps.append(inps[-1]) inps.append(inps[-1]) @@ -2757,9 +2757,9 @@ class TestFrozenOptimizations(JitTestCase): for module, trace in product([nn.Conv2d, nn.Conv3d], [False, True]): mod = module(3, 32, kernel_size=3, stride=2).eval() inps = [4, 3, 4] - if module == nn.Conv2d: + if module is nn.Conv2d: inps.append(inps[-1]) - if module == nn.Conv3d: + if module is nn.Conv3d: inps.append(inps[-1]) inps.append(inps[-1]) @@ -2997,7 +2997,7 @@ class TestFrozenOptimizations(JitTestCase): mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda() inps = [5, 3, 4, 4] - if conv == nn.Conv3d: + if conv is nn.Conv3d: inps.append(inps[-1]) inp = torch.rand(inps).cuda() diff --git a/test/jit/test_typing.py b/test/jit/test_typing.py index 8f34a1c75b6d..c1a010dcfb94 100644 --- a/test/jit/test_typing.py +++ b/test/jit/test_typing.py @@ -210,7 +210,7 @@ class TestTyping(JitTestCase): li_1, li_2, li_3 = stuff4([True]) li_3 = li_3[0] for li in [li_1, li_2, li_3]: - self.assertTrue(type(li[0]) == bool) + self.assertTrue(type(li[0]) is bool) def test_nested_list(self): def foo(z): diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 25211db3fe49..fe93775f0830 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -3839,9 +3839,9 @@ class TestConvolutionNNDeviceType(NNTestCase): # This is because we have N111 weight that cannot handle # the ambiguous memory_format if w_f == torch.channels_last: - if layer == nn.Conv2d and filter_size * c != 1: + if layer is nn.Conv2d and filter_size * c != 1: output_format = torch.channels_last - if layer == nn.ConvTranspose2d and filter_size * k != 1: + if layer is nn.ConvTranspose2d and filter_size * k != 1: output_format = torch.channels_last self._run_conv( layer, diff --git a/test/nn/test_load_state_dict.py b/test/nn/test_load_state_dict.py index 8ce1f03c0a84..074ac6273689 100644 --- a/test/nn/test_load_state_dict.py +++ b/test/nn/test_load_state_dict.py @@ -474,8 +474,8 @@ def load_torch_function_handler(cls, func, types, args=(), kwargs=None): f"Expected isinstance(src, {cls}) but got {type(src)}" ) assert ( - type(dest) == torch.Tensor - or type(dest) == torch.nn.Parameter + type(dest) is torch.Tensor + or type(dest) is torch.nn.Parameter or issubclass(cls, type(dest)) ) if assign: diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index f2e12d2f64e6..0840eeb1be42 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -3053,7 +3053,7 @@ class TestQuantizedOps(TestCase): lstm_quantized = torch.ao.quantization.convert( lstm_prepared, convert_custom_config_dict=custom_config_dict ) - assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM + assert type(lstm_quantized[0]) is torch.ao.nn.quantized.LSTM qy = lstm_quantized(qx) snr = _snr(y, qy) diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py index d20a2a708ec1..73ed76989591 100644 --- a/test/quantization/core/test_workflow_module.py +++ b/test/quantization/core/test_workflow_module.py @@ -138,7 +138,7 @@ class TestObserver(QuantizationTestCase): # Calculate Qparams should return with a warning for observers with no data qparams = myobs.calculate_qparams() input_scale = 2**16 if qdtype is torch.qint32 else 1 - if type(myobs) == MinMaxObserver: + if type(myobs) is MinMaxObserver: x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) * input_scale y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0]) * input_scale else: @@ -201,7 +201,7 @@ class TestObserver(QuantizationTestCase): [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]], ] ) - if type(myobs) == MovingAveragePerChannelMinMaxObserver: + if type(myobs) is MovingAveragePerChannelMinMaxObserver: # Scaling the input tensor to model change in min/max values # across batches result = myobs(0.5 * x) diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py index d4ae27677dd7..6b5fc67dcc9d 100644 --- a/test/quantization/core/test_workflow_ops.py +++ b/test/quantization/core/test_workflow_ops.py @@ -599,7 +599,7 @@ class TestFakeQuantizeOps(TestCase): # Output of fake quant is not identical to input Y = fq_module(X) self.assertNotEqual(Y, X) - if type(fq_module) == _LearnableFakeQuantize: + if type(fq_module) is _LearnableFakeQuantize: fq_module.toggle_fake_quant(False) else: torch.ao.quantization.disable_fake_quant(fq_module) @@ -613,7 +613,7 @@ class TestFakeQuantizeOps(TestCase): scale = fq_module.scale.detach().clone() zero_point = fq_module.zero_point.detach().clone() - if type(fq_module) == _LearnableFakeQuantize: + if type(fq_module) is _LearnableFakeQuantize: fq_module.toggle_observer_update(False) fq_module.toggle_fake_quant(True) else: @@ -625,7 +625,7 @@ class TestFakeQuantizeOps(TestCase): # Observer is disabled, scale and zero-point do not change self.assertEqual(fq_module.scale, scale) self.assertEqual(fq_module.zero_point, zero_point) - if type(fq_module) == _LearnableFakeQuantize: + if type(fq_module) is _LearnableFakeQuantize: fq_module.toggle_observer_update(True) else: torch.ao.quantization.enable_observer(fq_module) diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py index c5ce0659f55f..da67f19488a4 100644 --- a/test/quantization/eager/test_quantize_eager_qat.py +++ b/test/quantization/eager/test_quantize_eager_qat.py @@ -241,7 +241,7 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd): Args: `mod` a float module, either produced by torch.ao.quantization utilities or directly from user """ - assert type(mod) == cls._FLOAT_MODULE, ( + assert type(mod) is cls._FLOAT_MODULE, ( "qat." + cls.__name__ + ".from_float only works for " @@ -1264,8 +1264,8 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase): mp = prepare_qat(m) mp(data) mq = convert(mp) - self.assertTrue(type(mq[1]) == nnq.Linear) - self.assertTrue(type(mq[2]) == nn.Identity) + self.assertTrue(type(mq[1]) is nnq.Linear) + self.assertTrue(type(mq[2]) is nn.Identity) @skipIfNoXNNPACK @override_qengines diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py index 80ab0f1e8618..51bce95e30ab 100644 --- a/test/quantization/fx/test_model_report_fx.py +++ b/test/quantization/fx/test_model_report_fx.py @@ -1823,7 +1823,7 @@ class TestFxModelReportVisualizer(QuantizationTestCase): plottable_set = set() for feature_name in b_1_linear_features: - if type(b_1_linear_features[feature_name]) == torch.Tensor: + if type(b_1_linear_features[feature_name]) is torch.Tensor: plottable_set.add(feature_name) returned_plottable_feats = mod_rep_visualizer.get_all_unique_feature_names() diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index e38c56da2a71..f6f1128e422c 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -826,7 +826,7 @@ class TestFuseFx(QuantizationTestCase): # check conv module has two inputs named_modules = dict(m.named_modules()) for node in m.graph.nodes: - if node.op == "call_module" and type(named_modules[node.target]) == torch.nn.Conv2d: + if node.op == "call_module" and type(named_modules[node.target]) is torch.nn.Conv2d: self.assertTrue(len(node.args) == 2, msg="Expecting the fused op to have two arguments") def test_fusion_pattern_with_matchallnode(self): @@ -917,7 +917,7 @@ class TestQuantizeFx(QuantizationTestCase): m = torch.fx.symbolic_trace(M()) modules = dict(m.named_modules()) for n in m.graph.nodes: - if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU: + if n.op == 'call_module' and type(modules[n.target]) is nn.ReLU: self.assertTrue(_is_match(modules, n, pattern)) def test_pattern_match_constant(self): diff --git a/test/quantization/fx/test_subgraph_rewriter.py b/test/quantization/fx/test_subgraph_rewriter.py index 41c085b34a04..e410f93803d6 100644 --- a/test/quantization/fx/test_subgraph_rewriter.py +++ b/test/quantization/fx/test_subgraph_rewriter.py @@ -454,8 +454,8 @@ class TestSubgraphRewriter(JitTestCase): symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) for n, m in zip(symbolic_traced.graph.nodes, graph.nodes): if n.op == 'placeholder': - assert n.type == int - assert m.type == int + assert n.type is int + assert m.type is int def test_subgraph_writer_replace_consecutive_submodules(self): diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 6c83ab1a869e..9e2e690c21d7 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -332,7 +332,7 @@ class TestHelperModules: ) -> None: super().__init__() self.linear = nn.Linear(4, 4, bias=use_bias) - if postop == nn.GELU: + if postop is nn.GELU: self.postop = postop(approximate=post_op_algo) else: self.postop = postop(inplace=inplace_postop) diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index fbbcd831397a..406242964d1c 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -4162,7 +4162,7 @@ class TestBinaryUfuncs(TestCase): for i in complex_exponents if exp_dtype.is_complex else exponents: out_dtype_scalar_exp = ( torch.complex128 - if base_dtype.is_complex or type(i) == complex + if base_dtype.is_complex or type(i) is complex else torch.float64 ) expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i)) @@ -4190,7 +4190,7 @@ class TestBinaryUfuncs(TestCase): for i in complex_exponents if base_dtype.is_complex else exponents: out_dtype_scalar_base = ( torch.complex128 - if exp_dtype.is_complex or type(i) == complex + if exp_dtype.is_complex or type(i) is complex else torch.float64 ) expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp))) @@ -4205,9 +4205,9 @@ class TestBinaryUfuncs(TestCase): def test_float_power_exceptions(self, device): def _promo_helper(x, y): for i in (x, y): - if type(i) == complex: + if type(i) is complex: return torch.complex128 - elif type(i) == torch.Tensor and i.is_complex(): + elif type(i) is torch.Tensor and i.is_complex(): return torch.complex128 return torch.double diff --git a/test/test_datapipe.py b/test/test_datapipe.py index cb8dd252ec4b..e92fa2b0615d 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -2478,7 +2478,7 @@ class TestTyping(TestCase): else: self.assertFalse(issubinstance(d, S)) for t in basic_type: - if type(d) == t: + if type(d) is t: self.assertTrue(issubinstance(d, t)) else: self.assertFalse(issubinstance(d, t)) @@ -2577,7 +2577,7 @@ class TestTyping(TestCase): self.assertTrue(issubclass(DP4, IterDataPipe)) dp4 = DP4() - self.assertTrue(dp4.type.param == tuple) + self.assertTrue(dp4.type.param is tuple) class DP5(IterDataPipe): r"""DataPipe without type annotation""" @@ -2601,7 +2601,7 @@ class TestTyping(TestCase): self.assertTrue(issubclass(DP6, IterDataPipe)) dp6 = DP6() - self.assertTrue(dp6.type.param == int) + self.assertTrue(dp6.type.param is int) class DP7(IterDataPipe[Awaitable[T_co]]): r"""DataPipe with abstract base class""" diff --git a/test/test_decomp.py b/test/test_decomp.py index a534b643997b..e7e86dda6b8e 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -878,7 +878,7 @@ def forward(self, scores_1, mask_1, value_1): zip(real_out, decomp_out, real_out_double) ): if not isinstance(orig, torch.Tensor): - assert type(orig) == type(decomp) + assert type(orig) is type(decomp) assert orig == decomp continue op_assert_ref( @@ -895,7 +895,7 @@ def forward(self, scores_1, mask_1, value_1): else: for orig, decomp in zip(real_out, decomp_out): if not isinstance(orig, torch.Tensor): - assert type(orig) == type(decomp) + assert type(orig) is type(decomp) assert orig == decomp continue op_assert_equal( diff --git a/test/test_jit.py b/test/test_jit.py index 83407e25d0b5..fb7088a2875f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -2887,9 +2887,9 @@ graph(%Ra, %Rb): self.assertTrue(hasattr(input, 'type')) self.assertTrue(input.type() is not None) self.assertTrue(hasattr(block, 'returnNode')) - self.assertTrue(type(block.returnNode()) == torch._C.Node) + self.assertTrue(type(block.returnNode()) is torch._C.Node) self.assertTrue(hasattr(block, 'paramNode')) - self.assertTrue(type(block.paramNode()) == torch._C.Node) + self.assertTrue(type(block.paramNode()) is torch._C.Node) self.assertTrue(tested_blocks) def test_export_opnames(self): @@ -6510,7 +6510,7 @@ a") if isinstance(res_python, Exception): continue - if type(res_python) == type(res_script): + if type(res_python) is type(res_script): if isinstance(res_python, tuple) and (math.isnan(res_python[0]) == math.isnan(res_script[0])): continue if isinstance(res_python, float) and math.isnan(res_python) and math.isnan(res_script): @@ -8646,7 +8646,7 @@ dedent """ args = args + [1, 1.5] def isBool(arg): - return type(arg) == bool or (type(arg) == str and "torch.bool" in arg) + return type(arg) is bool or (type(arg) is str and "torch.bool" in arg) for op in ops: for first_arg in args: @@ -8655,7 +8655,7 @@ dedent """ if (op == 'sub' or op == 'div') and (isBool(first_arg) or isBool(second_arg)): continue # div is not implemented correctly for mixed-type or int params - if (op == 'div' and (type(first_arg) != type(second_arg) or + if (op == 'div' and (type(first_arg) is not type(second_arg) or isinstance(first_arg, int) or (isinstance(first_arg, str) and 'int' in first_arg))): continue @@ -8671,7 +8671,7 @@ dedent """ graph = cu.func.graph torch._C._jit_pass_complete_shape_analysis(graph, (), False) # use dim=-1 to represent a python/jit scalar. - dim = -1 if type(first_arg) != str and type(second_arg) != str else non_jit_result.dim() + dim = -1 if type(first_arg) is not str and type(second_arg) is not str else non_jit_result.dim() dtype = non_jit_result.dtype # jit only supports int/float scalars. if dim < 0: diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index 85c3b4d2cb3c..08feece4f712 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -211,9 +211,9 @@ def autograd_sharing(queue, ready, master_modified, device, is_parameter): is_ok &= var.grad is None is_ok &= not var._backward_hooks if is_parameter: - is_ok &= type(var) == Parameter + is_ok &= type(var) is Parameter else: - is_ok &= type(var) == torch.Tensor + is_ok &= type(var) is torch.Tensor var._grad = torch.ones(5, 5, device=device) queue.put(is_ok) diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index 20502eaafa61..ca7e65fc6247 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -596,7 +596,7 @@ class TestNumPyInterop(TestCase): if ( dtype == torch.complex64 and torch.is_tensor(t) - and type(a) == np.complex64 + and type(a) is np.complex64 ): # TODO: Imaginary part is dropped in this case. Need fix. # https://github.com/pytorch/pytorch/issues/43579 diff --git a/test/test_reductions.py b/test/test_reductions.py index 0e47e9b60a6e..7aabe08abef2 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -3327,7 +3327,7 @@ class TestReductions(TestCase): """ def _test_histogramdd_numpy(self, t, bins, bin_range, weights, density): def to_np(t): - if type(t) == list: + if type(t) is list: return list(map(to_np, t)) if not torch.is_tensor(t): return t diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 59d856ec4fc9..5a641fb3206a 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -968,7 +968,7 @@ class TestTypePromotion(TestCase): except Exception as e: expected = e - same_result = (type(expected) == type(actual)) and expected == actual + same_result = (type(expected) is type(actual)) and expected == actual # Note: An "undesired failure," as opposed to an "expected failure" # is both expected (we know the test will fail) and @@ -1128,7 +1128,7 @@ class TestTypePromotion(TestCase): maxs = (max_t, max_t[0], max_t[0].item()) inp = make_tensor((S,), dtype0) for min_v, max_v in itertools.product(mins, maxs): - if type(max_v) != type(min_v): + if type(max_v) is not type(min_v): continue if isinstance(min_v, torch.Tensor) and min_v.ndim == 0 and max_v.ndim == 0: continue # 0d tensors go to scalar overload, and it's tested separately diff --git a/test/torch_np/numpy_tests/core/test_numeric.py b/test/torch_np/numpy_tests/core/test_numeric.py index 75bf5c0fc628..c6b2d14aef6d 100644 --- a/test/torch_np/numpy_tests/core/test_numeric.py +++ b/test/torch_np/numpy_tests/core/test_numeric.py @@ -2384,7 +2384,7 @@ class TestLikeFuncs(TestCase): b = a[:, ::2] # Ensure b is not contiguous. kwargs = {"fill_value": ""} if likefunc == np.full_like else {} result = likefunc(b, dtype=dtype, **kwargs) - if dtype == str: + if dtype is str: assert result.strides == (16, 4) else: # dtype is bytes diff --git a/test/torch_np/numpy_tests/core/test_scalarmath.py b/test/torch_np/numpy_tests/core/test_scalarmath.py index 84b1e99cb931..ea7621e97546 100644 --- a/test/torch_np/numpy_tests/core/test_scalarmath.py +++ b/test/torch_np/numpy_tests/core/test_scalarmath.py @@ -925,7 +925,7 @@ class TestScalarSubclassingMisc(TestCase): # inheritance has to override, or this is correctly lost: res = op(myf_simple1(1), myf_simple2(2)) - assert type(res) == sctype or type(res) == np.bool_ + assert type(res) is sctype or type(res) is np.bool_ assert op(myf_simple1(1), myf_simple2(2)) == op(1, 2) # inherited # Two independent subclasses do not really define an order. This could @@ -955,7 +955,7 @@ class TestScalarSubclassingMisc(TestCase): assert op(myt(1), np.float64(2)) == __op__ assert op(np.float64(1), myt(2)) == __rop__ - if op in {operator.mod, operator.floordiv} and subtype == complex: + if op in {operator.mod, operator.floordiv} and subtype is complex: return # module is not support for complex. Do not test. if __rop__ == __op__: @@ -968,11 +968,11 @@ class TestScalarSubclassingMisc(TestCase): res = op(myt(1), np.float16(2)) expected = op(subtype(1), np.float16(2)) assert res == expected - assert type(res) == type(expected) + assert type(res) is type(expected) res = op(np.float32(2), myt(1)) expected = op(np.float32(2), subtype(1)) assert res == expected - assert type(res) == type(expected) + assert type(res) is type(expected) if __name__ == "__main__": diff --git a/test/torch_np/numpy_tests/linalg/test_linalg.py b/test/torch_np/numpy_tests/linalg/test_linalg.py index f8fa81bca63e..f3e42294a149 100644 --- a/test/torch_np/numpy_tests/linalg/test_linalg.py +++ b/test/torch_np/numpy_tests/linalg/test_linalg.py @@ -937,7 +937,7 @@ class DetCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): @instantiate_parametrized_tests class TestDet(DetCases, TestCase): def test_zero(self): - # NB: comment out tests of type(det) == double : we return zero-dim arrays + # NB: comment out tests of type(det) is double : we return zero-dim arrays assert_equal(linalg.det([[0.0]]), 0.0) # assert_equal(type(linalg.det([[0.0]])), double) assert_equal(linalg.det([[0.0j]]), 0.0) @@ -1103,7 +1103,7 @@ class TestMatrixPower(TestCase): for mat in self.rshft_all: tz(mat.astype(dt)) - if dt != object: + if dt is not object: tz(self.stacked.astype(dt)) @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"]) @@ -1115,7 +1115,7 @@ class TestMatrixPower(TestCase): for mat in self.rshft_all: tz(mat.astype(dt)) - if dt != object: + if dt is not object: tz(self.stacked.astype(dt)) @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"]) @@ -1128,7 +1128,7 @@ class TestMatrixPower(TestCase): for mat in self.rshft_all: tz(mat.astype(dt)) - if dt != object: + if dt is not object: tz(self.stacked.astype(dt)) @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"]) diff --git a/test/torch_np/test_ndarray_methods.py b/test/torch_np/test_ndarray_methods.py index e32720d986eb..f94b03f1f6e5 100644 --- a/test/torch_np/test_ndarray_methods.py +++ b/test/torch_np/test_ndarray_methods.py @@ -661,7 +661,7 @@ class TestIter(TestCase): # numpy generates array scalars, we do 0D arrays a = np.arange(5) lst = list(a) - assert all(type(x) == np.ndarray for x in lst), f"{[type(x) for x in lst]}" + assert all(type(x) is np.ndarray for x in lst), f"{[type(x) for x in lst]}" assert all(x.ndim == 0 for x in lst) def test_iter_2d(self): @@ -669,7 +669,8 @@ class TestIter(TestCase): a = np.arange(5)[None, :] lst = list(a) assert len(lst) == 1 - assert type(lst[0]) == np.ndarray + # FIXME: "is" cannot be used here because dynamo fails + assert type(lst[0]) == np.ndarray # noqa: E721 assert_equal(lst[0], np.arange(5)) diff --git a/test/torch_np/test_nep50_examples.py b/test/torch_np/test_nep50_examples.py index 1c27d8702875..d89a7a390e34 100644 --- a/test/torch_np/test_nep50_examples.py +++ b/test/torch_np/test_nep50_examples.py @@ -94,7 +94,7 @@ class TestNEP50Table(TestCase): def test_nep50_exceptions(self, example): old, new = examples[example] - if new == Exception: + if new is Exception: with assert_raises(OverflowError): eval(example) diff --git a/tools/experimental/torchfuzz/tensor_fuzzer.py b/tools/experimental/torchfuzz/tensor_fuzzer.py index 4519e2e90b13..0357d6cbca18 100644 --- a/tools/experimental/torchfuzz/tensor_fuzzer.py +++ b/tools/experimental/torchfuzz/tensor_fuzzer.py @@ -554,7 +554,7 @@ def fuzz_scalar(spec, seed: Optional[int] = None) -> Union[float, int, bool, com def specs_compatible(spec1: Spec, spec2: Spec) -> bool: """Check if two specifications are compatible (one can be used where the other is expected).""" - if type(spec1) != type(spec2): + if type(spec1) is not type(spec2): return False if isinstance(spec1, ScalarSpec): diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 597c28ad0029..506f1b408ae7 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -2842,7 +2842,7 @@ def _index_add( if alpha != 1: python_type = utils.dtype_to_type(x.dtype) torch._check( - python_type == bool + python_type is bool or utils.is_weakly_lesser_type(type(alpha), python_type), lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", ) diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index fb27d7db399c..4ac9fa00f1ad 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -295,7 +295,7 @@ class PyCodegen: output.extend(create_call_function(2, False)) elif ( isinstance(value, SymNodeVariable) - and value.python_type() == float + and value.python_type() is float and not self.tx.export ): # This is a little unusual; force the output convention to be a diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index b58af46d0ef1..401fa6bf27e4 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -4182,7 +4182,7 @@ def make_torch_function_mode_stack_guard( return False for ty, mode in zip(types, cur_stack): - if ty != type(mode): + if ty is not type(mode): return False return True diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index a4f2d9b8d2c7..d331f1238b3c 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1361,7 +1361,7 @@ class TensorVariable(VariableTracker): if (len(args) == 1 and isinstance(args[0], SizeVariable)) or ( len(args) >= 1 and all( - isinstance(a, ConstantVariable) and a.python_type() == int for a in args + isinstance(a, ConstantVariable) and a.python_type() is int for a in args ) ): from ..symbolic_convert import InstructionTranslator diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index 416619cee029..cc33c7e3aba9 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -64,14 +64,14 @@ def _staged_schema(): ) elif o := typing.get_origin(t): # Lemme know if there's a better way to do this. - if o == list: + if o is list: yaml_head, cpp_head, thrift_head, thrift_tail = ( "List", "std::vector", "list<", ">", ) - elif o == dict: + elif o is dict: yaml_head, cpp_head, thrift_head, thrift_tail = ( "Dict", "std::unordered_map", @@ -81,7 +81,7 @@ def _staged_schema(): elif o == Union: assert level == 0, "Optional is only supported at the top level." args = typing.get_args(t) - assert len(args) == 2 and args[1] == type(None) + assert len(args) == 2 and args[1] is type(None) yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1) return ( f"Optional[{yaml_type}]", diff --git a/torch/_higher_order_ops/partitioner.py b/torch/_higher_order_ops/partitioner.py index 81ad53b37339..2a21601aa9d9 100644 --- a/torch/_higher_order_ops/partitioner.py +++ b/torch/_higher_order_ops/partitioner.py @@ -83,7 +83,7 @@ class HopPartitionedGraph: val1: Union[torch.SymInt, torch.Tensor], val2: Union[torch.SymInt, torch.Tensor], ) -> bool: - if type(val1) != type(val2): + if type(val1) is not type(val2): return False if isinstance(val1, torch.SymInt) and isinstance(val2, torch.SymInt): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index d7f69a73b336..64e0fa196d6e 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1211,7 +1211,7 @@ class CppVecOverrides(CppOverrides): return wrapper for name, method in vars(CppVecOverrides).items(): - if getattr(method, "__class__", None) == staticmethod and name not in [ + if getattr(method, "__class__", None) is staticmethod and name not in [ "masked", "index_expr", ]: diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 69216c8f5c5e..403e1c2eca9e 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -220,15 +220,15 @@ class SamplingMethod(Enum): if field_name in TYPE_OVERRIDES: return random.choice(TYPE_OVERRIDES[field_name]) - if type_hint == bool: + if type_hint is bool: return random.choice([True, False]) if random_sample else not default - elif type_hint == int: + elif type_hint is int: # NOTE initially tried to use negation of the value, but it doesn't work because most types are ints # when they should be natural numbers + zero. Python types to cover these values aren't super convenient. return random.randint(0, 1000) - elif type_hint == float: + elif type_hint is float: return random.uniform(0, 1000) - elif type_hint == str: + elif type_hint is str: characters = string.ascii_letters + string.digits + string.punctuation return "".join( random.choice(characters) for _ in range(random.randint(1, 20)) @@ -306,7 +306,7 @@ class SamplingMethod(Enum): new_type = random.choice(type_hint.__args__) else: new_type = random.choice( - [t for t in type_hint.__args__ if t != type(default)] + [t for t in type_hint.__args__ if t is not type(default)] ) try: new_default = new_type() diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 26f7c0abd528..87fe5836b147 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1208,7 +1208,7 @@ def safe_grad_filter(message, category, filename, lineno, file=None, line=None) def user_warning_filter( message, category, filename, lineno, file=None, line=None ) -> bool: - return category != UserWarning + return category is not UserWarning @contextlib.contextmanager diff --git a/torch/_numpy/_reductions_impl.py b/torch/_numpy/_reductions_impl.py index 4afc217ebd4b..a4ebc094a728 100644 --- a/torch/_numpy/_reductions_impl.py +++ b/torch/_numpy/_reductions_impl.py @@ -428,7 +428,7 @@ def percentile( interpolation: NotImplementedType = None, ): # np.percentile(float_tensor, 30) : q.dtype is int64 => q / 100.0 is float32 - if _dtypes_impl.python_type_for_torch(q.dtype) == int: + if _dtypes_impl.python_type_for_torch(q.dtype) is int: q = q.to(_dtypes_impl.default_dtypes().float_dtype) qq = q / 100.0 diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index c5a845208ac6..13d6efd4ac67 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -1179,7 +1179,7 @@ def add( if alpha is not None: dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr] python_type = utils.dtype_to_type(dtype) - if python_type != bool and not utils.is_weakly_lesser_type( + if python_type is not bool and not utils.is_weakly_lesser_type( type(alpha), python_type ): msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" diff --git a/torch/_utils.py b/torch/_utils.py index c7b63525073a..87d17c374de3 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -755,7 +755,7 @@ class ExceptionWrapper: # Format a message such as: "Caught ValueError in DataLoader worker # process 2. Original Traceback:", followed by the traceback. msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore # missing-attribute - if self.exc_type == KeyError: + if self.exc_type is KeyError: # KeyError calls repr() on its argument (usually a dict key). This # makes stack traces unreadable. It will not be changed in Python # (https://bugs.python.org/issue2651), so we work around it. diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py index b6d93c164aa5..168f07ee33a0 100644 --- a/torch/ao/ns/fx/utils.py +++ b/torch/ao/ns/fx/utils.py @@ -317,7 +317,7 @@ def get_arg_indices_of_inputs_to_log(node: Node) -> list[int]: node.target in (torch.add, torch.ops.quantized.add, operator.add) or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul) ): - result = [i for i in range(2) if type(node.args[i]) == Node] + result = [i for i in range(2) if type(node.args[i]) is Node] return result return [0] diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index 739673a0997e..fa8e7d53e6b0 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -589,7 +589,7 @@ def _match_static_pattern( # Handle cases where the node is wrapped in a ReLU if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or ( - ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU + ref_node.op == "call_module" and type(_get_module(ref_node, modules)) is nn.ReLU ): relu_node = ref_node ref_node = relu_node.args[0] @@ -724,7 +724,7 @@ def _lower_static_weighted_ref_module( # If so, we replace the entire fused module with the corresponding quantized module if ref_class in STATIC_LOWER_FUSED_MODULE_MAP: inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class] - if type(ref_module[0]) != inner_ref_class: # type: ignore[index] + if type(ref_module[0]) is not inner_ref_class: # type: ignore[index] continue else: q_class = STATIC_LOWER_MODULE_MAP[ref_class] @@ -786,7 +786,7 @@ def _lower_static_weighted_ref_module_with_two_inputs( inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP[ ref_class ] - if type(ref_module[0]) != inner_ref_class: # type: ignore[index] + if type(ref_module[0]) is not inner_ref_class: # type: ignore[index] continue else: continue @@ -846,7 +846,7 @@ def _lower_dynamic_weighted_ref_module(model: GraphModule): ref_class = type(ref_module) if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP: inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class] - if type(ref_module[0]) != inner_ref_class: + if type(ref_module[0]) is not inner_ref_class: continue else: q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class) # type: ignore[assignment] @@ -1008,7 +1008,7 @@ def _lower_dynamic_weighted_ref_functional( func_node.op == "call_function" and func_node.target == F.relu or func_node.op == "call_module" - and type(modules[str(func_node.target)]) == torch.nn.ReLU + and type(modules[str(func_node.target)]) is torch.nn.ReLU ): relu_node = func_node func_node = relu_node.args[0] diff --git a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py index 1f127f8062aa..656206d161c9 100644 --- a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py +++ b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py @@ -132,7 +132,7 @@ class ModelReportVisualizer: # if we need plottable, ensure type of val is tensor if ( not plottable_features_only - or type(feature_dict[feature_name]) == torch.Tensor + or type(feature_dict[feature_name]) is torch.Tensor ): unique_feature_names.add(feature_name) diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 7cbca8a212ab..dc488d068cab 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -704,7 +704,7 @@ def _maybe_get_custom_module_lstm_from_node_arg( return a.op == "call_function" and a.target == operator.getitem def match_tuple(a): - return a.op == "call_function" and a.target == tuple + return a.op == "call_function" and a.target is tuple def _match_pattern(match_pattern: list[Callable]) -> Optional[Node]: """ @@ -797,7 +797,7 @@ def _reroute_tuple_getitem_pattern(graph: Graph): # Iterate through users of this node to find tuple/getitem nodes to match for user in node.users: - if user.op == "call_function" and user.target == tuple: + if user.op == "call_function" and user.target is tuple: for i, user_arg in enumerate(user.args[0]): # type: ignore[arg-type] if user_arg == node: index_stack.append(i) @@ -826,7 +826,7 @@ def _reroute_tuple_getitem_pattern(graph: Graph): for pattern in matched_patterns: first_tuple = pattern[0] last_getitem = pattern[-1] - assert first_tuple.op == "call_function" and first_tuple.target == tuple + assert first_tuple.op == "call_function" and first_tuple.target is tuple assert ( last_getitem.op == "call_function" and last_getitem.target == operator.getitem diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 6c78062ba399..73375d4ee144 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -699,12 +699,12 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): state_dict_config = state_dict_config_type() if optim_state_dict_config is None: optim_state_dict_config = optim_state_dict_config_type() - if state_dict_config_type != type(state_dict_config): + if state_dict_config_type is not type(state_dict_config): raise RuntimeError( f"Expected state_dict_config of type {state_dict_config_type} " f"but got {type(state_dict_config)}" ) - if optim_state_dict_config_type != type(optim_state_dict_config): + if optim_state_dict_config_type is not type(optim_state_dict_config): raise RuntimeError( f"Expected optim_state_dict_config of type {optim_state_dict_config_type} " f"but got {type(optim_state_dict_config)}" diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index 759d54cb8d37..b5ddeb3fffe3 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -180,12 +180,12 @@ def add_inference_rule(n: Node): t2 = n.args[1].type # handle scalar addition - if t1 == int and isinstance(t2, TensorType): + if t1 is int and isinstance(t2, TensorType): n.type = t2 return n.type # handle scalar addition - elif t2 == int and isinstance(t1, TensorType): + elif t2 is int and isinstance(t1, TensorType): n.type = t1 return n.type diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index 6027c603ec1f..41e831327b41 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -542,7 +542,7 @@ def reinplace(gm, *sample_args): continue if len(node.target._schema.arguments) < 1: continue - if type(node.target._schema.arguments[0].type) != torch.TensorType: + if type(node.target._schema.arguments[0].type) is not torch.TensorType: continue # Step 1a: Check that the self argument we're attempting to reinplace diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py index 528750157398..c8972b005dd9 100644 --- a/torch/utils/data/datapipes/_typing.py +++ b/torch/utils/data/datapipes/_typing.py @@ -78,7 +78,7 @@ def issubtype(left, right, recursive=True): if getattr(right, "__origin__", None) is Generic: return True - if right == type(None): + if right is type(None): return False # Right-side type