[1/N] Simplify "in" operation for containers of a single item (#164224)

These issues are detected by ruff [FURB171](https://docs.astral.sh/ruff/rules/single-item-membership-test/#single-item-membership-test-furb171).
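
As a minimal illustration of the pattern (the `is_linear_op_*` helpers below are hypothetical, not code from this commit), FURB171 rewrites a membership test against a one-element container into a direct comparison:

```python
# Before: membership test against a single-item container; this is what FURB171 flags.
def is_linear_op_before(op_name: str) -> bool:
    return op_name in ("aten.linear.default",)

# After: a plain equality check; same result, less indirection.
def is_linear_op_after(op_name: str) -> bool:
    return op_name == "aten.linear.default"

# Both forms agree on every input.
assert is_linear_op_before("aten.linear.default") and is_linear_op_after("aten.linear.default")
assert not is_linear_op_before("aten.conv2d.default") and not is_linear_op_after("aten.conv2d.default")
```

Where the single item is an object compared by identity rather than value (e.g. a class or an operator overload), the diff uses `is` / `is not` instead of `==`, as in `type(child) is nn.Dropout` and `func is not torch.ops.higher_order.cond` below.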

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164224
Approved by: https://github.com/rec, https://github.com/Skylion007
Author: Yuanyuan Chen
Date: 2025-09-30 19:59:39 +00:00
Committed by: PyTorch MergeBot
Parent: ffc645c870
Commit: e30f01b5b5
16 changed files with 24 additions and 24 deletions


@@ -83,7 +83,7 @@ def _quantize_weight(float_wt, observer):
             torch.qint8,
         )
         qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
-    elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
+    elif observer.qscheme == torch.per_channel_affine_float_qparams:
         qweight = torch.quantize_per_channel(
             float_wt,
             wt_scale.to(torch.float),


@@ -64,9 +64,7 @@ def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool:
 
 
 def _is_float_qparams(qscheme: "torch.qscheme") -> bool:
-    return qscheme in [
-        torch.per_channel_affine_float_qparams,
-    ]
+    return qscheme == torch.per_channel_affine_float_qparams
 
 
 class FakeQuantizeBase(ABC, Module):


@@ -227,7 +227,7 @@ def is_getattr_tensor_metadata_node(node):
     return (
         node.op == "call_function"
         and node.target == getattr
-        and node.args[1] in ["shape"]
+        and node.args[1] == "shape"
     )
 
 


@@ -388,7 +388,7 @@ class UniformQuantizationObserverBase(ObserverBase):
                     )
                 else:
                     zero_point = zero_point.new_full(zero_point.size(), 128)
-            elif self.dtype in [torch.uint16]:
+            elif self.dtype == torch.uint16:
                 zero_point = zero_point.new_full(zero_point.size(), 2**15)
         elif self.qscheme == torch.per_channel_affine_float_qparams:
             scale = (max_val - min_val) / float(quant_max - quant_min)


@@ -237,7 +237,7 @@ def _add_observer_(
     for name, child in module.named_children():
         # TODO remove Dropout special after codebase stable
-        if type_before_parametrizations(child) in [nn.Dropout]:
+        if type_before_parametrizations(child) is nn.Dropout:
             continue
         elif issubclass(
             type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)
         ):


@@ -598,7 +598,7 @@ class X86InductorQuantizer(Quantizer):
             _annotate_nodes_not_quantize(linear_node)
             return
         input_qspec_map = {}
-        assert linear_node.target in (torch.ops.aten.linear.default,)
+        assert linear_node.target == torch.ops.aten.linear.default
         has_bias = len(linear_node.args) == 3
         input_index = 0
         weight_index = 1
@@ -1436,8 +1436,9 @@ class X86InductorQuantizer(Quantizer):
             "Linear partition cannot have more than one output node"
         )
         linear_node = partition.output_nodes[0]
-        if linear_node.op != "call_function" or linear_node.target not in (
-            torch.ops.aten.linear.default,
+        if (
+            linear_node.op != "call_function"
+            or linear_node.target != torch.ops.aten.linear.default
         ):
             raise ValueError(f"{linear_node} is not an aten linear operator")
         # skip annotation if it is already annotated
@@ -1467,8 +1468,9 @@ class X86InductorQuantizer(Quantizer):
         linear_node, unary_node = self._get_output_nodes_of_partitions(
             [linear_partition, unary_partition]
         )
-        if linear_node.op != "call_function" or linear_node.target not in (
-            torch.ops.aten.linear.default,
+        if (
+            linear_node.op != "call_function"
+            or linear_node.target != torch.ops.aten.linear.default
         ):
             continue
         if _skip_annotate([unary_node, linear_node], filter_fn):


@@ -501,9 +501,9 @@ def calculate_qmin_qmax(
         quant_min, quant_max = 0, 255
     elif dtype in [torch.qint32, torch.int32]:
         quant_min, quant_max = -1 * (2**31), (2**31) - 1
-    elif dtype in [torch.uint16]:
+    elif dtype == torch.uint16:
         quant_min, quant_max = 0, 2**16 - 1
-    elif dtype in [torch.int16]:
+    elif dtype == torch.int16:
         quant_min, quant_max = -(2**15), 2**15 - 1
     else:
         quant_min, quant_max = 0, 15


@@ -420,7 +420,7 @@ class _CudaKernel:
         # navi, CDNA1-CDNA3 allows a max of 64KB shared memory
         # CDNA4 allows a max of 160KB shared memory
         max_shared_mem = (
-            65536 if device_props.gcnArchName not in ["gfx950"] else 160 * 1024
+            65536 if device_props.gcnArchName != "gfx950" else 160 * 1024
         )
     else:
         max_shared_mem = getattr(


@@ -34,7 +34,7 @@ class LayerNorm(nn.Module):
         self.bias = nn.Parameter(torch.zeros(normalized_shape))
         self.eps = eps
         self.data_format = data_format
-        if self.data_format not in [torch.contiguous_format]:
+        if self.data_format != torch.contiguous_format:
             raise NotImplementedError
         self.normalized_shape = (normalized_shape,)


@@ -427,7 +427,7 @@ def _reduction_identity(op_name: str, input: Tensor, *args):
         return torch.tensor(-torch.inf, dtype=dtype, device=device)
     elif torch.is_signed(input) or dtype == torch.uint8:
         return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
-    elif op_name in {"logsumexp"}:
+    elif op_name == "logsumexp":
         if torch.is_floating_point(input):
             return torch.tensor(-torch.inf, dtype=dtype, device=device)
         elif torch.is_complex(input):


@@ -765,7 +765,7 @@ class QuantizationTestCase(TestCase):
             and not isinstance(module, _FusedModule)
         ):
             for child in module.children():
-                if type(child) in [nn.Dropout]:
+                if type(child) is nn.Dropout:
                     continue
                 self.checkObservers(
                     child, propagate_qconfig_list, prepare_custom_config_dict


@@ -192,7 +192,7 @@ class Trainer:
         self.hybrid_module = HybridModel(
             self.remote_em_rref,
             self.remote_net_rref,
-            self.trainer_group if ddp_mode in (DdpMode.INSIDE,) else None,
+            self.trainer_group if ddp_mode == DdpMode.INSIDE else None,
         )
         self.ddp_params, self.non_ddp_params = (
             self.hybrid_module.ddp_params,


@@ -707,7 +707,7 @@ class DistributedTest:
             self.assertNotEqual(args.get("dtype", ""), "")
             per_coll_meta[collname].append(args)
-            if collname in {"wait"}:
+            if collname == "wait":
                 continue
             self.assertEqual(args["Process Group Description"], "default_pg")
@@ -7029,7 +7029,7 @@ class DistributedTest:
             self.assertNotEqual(attrs.get("dtype", ""), "")
             per_coll_meta[collname].append(attrs)
-            if collname in {"wait"}:
+            if collname == "wait":
                 continue
             self.assertEqual(attrs["pg_name"], "0")  # yes this is a string


@@ -125,7 +125,7 @@ class DebugMode(TorchDispatchMode):
                 _get_current_dispatch_mode(), FakeTensorMode
             ):
                 if self.record_faketensor:
-                    if func not in {torch.ops.prim.device.default}:
+                    if func != torch.ops.prim.device.default:
                         self.operators.append((func, args, kwargs, self.call_depth + 1))
             elif len(types) == 0:
                 if self.record_realtensor:


@@ -103,7 +103,7 @@ class Capture:
     def __getattr__(self, attrname):
         if attrname == "kwarg" or attrname == "kwargs":
             raise RuntimeError("no kwargs!")
-        if attrname in ["__deepcopy__"]:
+        if attrname == "__deepcopy__":
             raise AttributeError
         result = CaptureGetAttr(self, attrname, ctx=self.ctx)
         return result


@@ -783,7 +783,7 @@ class _FlopCounterMode(TorchDispatchMode):
         return result, flop_counts
 
     def _handle_higher_order_ops(self, func, types, args, kwargs):
-        if func not in {torch.ops.higher_order.cond, }:
+        if func is not torch.ops.higher_order.cond:
             return NotImplemented
 
         # The flop counter for cond counts the upper bound of flops.