From 1f43d17ce672ff1fca2f5eab033cb03c27132385 Mon Sep 17 00:00:00 2001
From: Yuanyuan Chen
Date: Sat, 18 Oct 2025 18:51:49 +0000
Subject: [PATCH] Fix self assignment (#165816)

This PR removes assignments of the form `var=var`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165816
Approved by: https://github.com/jansel
---
 torch/_functorch/vmap.py                                 | 2 +-
 torch/_inductor/fx_passes/efficient_conv_bn_eval.py      | 8 ++------
 torch/_inductor/tiling_utils.py                          | 1 -
 torch/_inductor/utils.py                                 | 1 -
 torch/_numpy/_dtypes.py                                  | 2 --
 torch/_prims/__init__.py                                 | 6 +-----
 torch/nn/utils/stateless.py                              | 3 ---
 torch/testing/_internal/distributed/rpc/jit/rpc_test.py  | 2 +-
 8 files changed, 5 insertions(+), 20 deletions(-)

diff --git a/torch/_functorch/vmap.py b/torch/_functorch/vmap.py
index 25ffe9c525f3..465be67e41fa 100644
--- a/torch/_functorch/vmap.py
+++ b/torch/_functorch/vmap.py
@@ -293,7 +293,7 @@ def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
 
 
 def get_chunk_sizes(total_elems, chunk_size):
-    n_chunks = n_chunks = total_elems // chunk_size
+    n_chunks = total_elems // chunk_size
     chunk_sizes = [chunk_size] * n_chunks
     # remainder chunk
     remainder = total_elems % chunk_size
diff --git a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
index 78cd317284d2..b6db1367de6e 100644
--- a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
+++ b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py
@@ -108,14 +108,10 @@ def efficient_conv_bn_eval_decomposed(
     else:
         bias_on_the_fly = torch.zeros_like(bn_running_var)
 
-    if bn_weight is not None:
-        bn_weight = bn_weight
-    else:
+    if bn_weight is None:
         bn_weight = torch.ones_like(bn_running_var)
 
-    if bn_bias is not None:
-        bn_bias = bn_bias
-    else:
+    if bn_bias is None:
         bn_bias = torch.zeros_like(bn_running_var)
 
     # shape of [C_out, 1, 1, 1] in Conv2d
diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py
index 3142f97f8c40..30efae2293c8 100644
--- a/torch/_inductor/tiling_utils.py
+++ b/torch/_inductor/tiling_utils.py
@@ -477,7 +477,6 @@ def extract_normalized_read_writes(
     (norm_pw_vars, norm_red_vars), ranges = index_vars_no_squeeze(
         pw_splits, red_splits, prefix="n"
     )
-    node = node
 
     for n in list(node.get_nodes()):
         if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index f1c7f23cf719..b7c347fd7acc 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -760,7 +760,6 @@ def get_fused_kernel_name(
         ]
     else:
         raise NotImplementedError
-    sources = sources
     return "_".join(["fused"] + sources)
 
 
diff --git a/torch/_numpy/_dtypes.py b/torch/_numpy/_dtypes.py
index e955a47060ff..a429d28f30cc 100644
--- a/torch/_numpy/_dtypes.py
+++ b/torch/_numpy/_dtypes.py
@@ -408,8 +408,6 @@ def set_default_dtype(fp_dtype="numpy", int_dtype="numpy"):
 
     if int_dtype in ["numpy", "pytorch"]:
         int_dtype = torch.int64
-    else:
-        int_dtype = int_dtype
 
     new_defaults = _dtypes_impl.DefaultDTypes(
         float_dtype=float_dtype, complex_dtype=complex_dtype, int_dtype=int_dtype
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index f3fd27e59139..7827aa244a2e 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -447,9 +447,7 @@ def _prim_elementwise_meta(
     # (but getting it wrong will cause too many casts to be inserted in traces!)
     if device is not None:
         assert dtype is not None
-        if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT:
-            dtype = dtype
-        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
+        if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
             dtype = torch.bool
         elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
             if utils.is_integer_dtype(dtype) or utils.is_boolean_dtype(dtype):
@@ -457,8 +455,6 @@ def _prim_elementwise_meta(
         elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
             if utils.is_complex_dtype(dtype):
                 dtype = utils.corresponding_real_dtype(dtype)
-            else:
-                dtype = dtype
 
         assert shape is not None
         return torch.empty_permuted(shape, l2p_perm, device=device, dtype=dtype)  # type: ignore[return-value]
diff --git a/torch/nn/utils/stateless.py b/torch/nn/utils/stateless.py
index ce55641faab4..148052740922 100644
--- a/torch/nn/utils/stateless.py
+++ b/torch/nn/utils/stateless.py
@@ -103,9 +103,6 @@ def _reparametrize_module(
     strict: bool = False,
    stack_weights: bool = False,
 ):
-    parameters_and_buffers = parameters_and_buffers
-    stack_weights = stack_weights
-
     if tie_weights:
         untied_parameters_and_buffers = _untie_named_tensors_map(
             module, parameters_and_buffers
diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
index ec2f2b949907..76c089f45800 100644
--- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
@@ -85,7 +85,7 @@ class RRefAPITest:
         ):
             rref_local_value(rref)
 
-        ret = ret = rpc.rpc_sync(dst_worker_name, rref_local_value, (rref,))
+        ret = rpc.rpc_sync(dst_worker_name, rref_local_value, (rref,))
         self.assertEqual(ret, torch.add(torch.ones(2, 2), 1))
 
     @dist_init