mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[1/N] Simplify "in" operation for containers of a single item (#164224)
These issues are detected by ruff [FURB171](https://docs.astral.sh/ruff/rules/single-item-membership-test/#single-item-membership-test-furb171). Pull Request resolved: https://github.com/pytorch/pytorch/pull/164224 Approved by: https://github.com/rec, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
ffc645c870
commit
e30f01b5b5
@ -83,7 +83,7 @@ def _quantize_weight(float_wt, observer):
|
||||
torch.qint8,
|
||||
)
|
||||
qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
|
||||
elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
|
||||
elif observer.qscheme == torch.per_channel_affine_float_qparams:
|
||||
qweight = torch.quantize_per_channel(
|
||||
float_wt,
|
||||
wt_scale.to(torch.float),
|
||||
|
@ -64,9 +64,7 @@ def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool:
|
||||
|
||||
|
||||
def _is_float_qparams(qscheme: "torch.qscheme") -> bool:
|
||||
return qscheme in [
|
||||
torch.per_channel_affine_float_qparams,
|
||||
]
|
||||
return qscheme == torch.per_channel_affine_float_qparams
|
||||
|
||||
|
||||
class FakeQuantizeBase(ABC, Module):
|
||||
|
@ -227,7 +227,7 @@ def is_getattr_tensor_metadata_node(node):
|
||||
return (
|
||||
node.op == "call_function"
|
||||
and node.target == getattr
|
||||
and node.args[1] in ["shape"]
|
||||
and node.args[1] == "shape"
|
||||
)
|
||||
|
||||
|
||||
|
@ -388,7 +388,7 @@ class UniformQuantizationObserverBase(ObserverBase):
|
||||
)
|
||||
else:
|
||||
zero_point = zero_point.new_full(zero_point.size(), 128)
|
||||
elif self.dtype in [torch.uint16]:
|
||||
elif self.dtype == torch.uint16:
|
||||
zero_point = zero_point.new_full(zero_point.size(), 2**15)
|
||||
elif self.qscheme == torch.per_channel_affine_float_qparams:
|
||||
scale = (max_val - min_val) / float(quant_max - quant_min)
|
||||
|
@ -237,7 +237,7 @@ def _add_observer_(
|
||||
|
||||
for name, child in module.named_children():
|
||||
# TODO remove Dropout special after codebase stable
|
||||
if type_before_parametrizations(child) in [nn.Dropout]:
|
||||
if type_before_parametrizations(child) is nn.Dropout:
|
||||
continue
|
||||
elif issubclass(
|
||||
type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)
|
||||
|
@ -598,7 +598,7 @@ class X86InductorQuantizer(Quantizer):
|
||||
_annotate_nodes_not_quantize(linear_node)
|
||||
return
|
||||
input_qspec_map = {}
|
||||
assert linear_node.target in (torch.ops.aten.linear.default,)
|
||||
assert linear_node.target == torch.ops.aten.linear.default
|
||||
has_bias = len(linear_node.args) == 3
|
||||
input_index = 0
|
||||
weight_index = 1
|
||||
@ -1436,8 +1436,9 @@ class X86InductorQuantizer(Quantizer):
|
||||
"Linear partition cannot have more than one output node"
|
||||
)
|
||||
linear_node = partition.output_nodes[0]
|
||||
if linear_node.op != "call_function" or linear_node.target not in (
|
||||
torch.ops.aten.linear.default,
|
||||
if (
|
||||
linear_node.op != "call_function"
|
||||
or linear_node.target != torch.ops.aten.linear.default
|
||||
):
|
||||
raise ValueError(f"{linear_node} is not an aten linear operator")
|
||||
# skip annotation if it is already annotated
|
||||
@ -1467,8 +1468,9 @@ class X86InductorQuantizer(Quantizer):
|
||||
linear_node, unary_node = self._get_output_nodes_of_partitions(
|
||||
[linear_partition, unary_partition]
|
||||
)
|
||||
if linear_node.op != "call_function" or linear_node.target not in (
|
||||
torch.ops.aten.linear.default,
|
||||
if (
|
||||
linear_node.op != "call_function"
|
||||
or linear_node.target != torch.ops.aten.linear.default
|
||||
):
|
||||
continue
|
||||
if _skip_annotate([unary_node, linear_node], filter_fn):
|
||||
|
@ -501,9 +501,9 @@ def calculate_qmin_qmax(
|
||||
quant_min, quant_max = 0, 255
|
||||
elif dtype in [torch.qint32, torch.int32]:
|
||||
quant_min, quant_max = -1 * (2**31), (2**31) - 1
|
||||
elif dtype in [torch.uint16]:
|
||||
elif dtype == torch.uint16:
|
||||
quant_min, quant_max = 0, 2**16 - 1
|
||||
elif dtype in [torch.int16]:
|
||||
elif dtype == torch.int16:
|
||||
quant_min, quant_max = -(2**15), 2**15 - 1
|
||||
else:
|
||||
quant_min, quant_max = 0, 15
|
||||
|
@ -420,7 +420,7 @@ class _CudaKernel:
|
||||
# navi, CDNA1-CDNA3 allows a max of 64KB shared memory
|
||||
# CDNA4 allows a max of 160KB shared memory
|
||||
max_shared_mem = (
|
||||
65536 if device_props.gcnArchName not in ["gfx950"] else 160 * 1024
|
||||
65536 if device_props.gcnArchName != "gfx950" else 160 * 1024
|
||||
)
|
||||
else:
|
||||
max_shared_mem = getattr(
|
||||
|
@ -34,7 +34,7 @@ class LayerNorm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
||||
self.eps = eps
|
||||
self.data_format = data_format
|
||||
if self.data_format not in [torch.contiguous_format]:
|
||||
if self.data_format != torch.contiguous_format:
|
||||
raise NotImplementedError
|
||||
self.normalized_shape = (normalized_shape,)
|
||||
|
||||
|
@ -427,7 +427,7 @@ def _reduction_identity(op_name: str, input: Tensor, *args):
|
||||
return torch.tensor(-torch.inf, dtype=dtype, device=device)
|
||||
elif torch.is_signed(input) or dtype == torch.uint8:
|
||||
return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
|
||||
elif op_name in {"logsumexp"}:
|
||||
elif op_name == "logsumexp":
|
||||
if torch.is_floating_point(input):
|
||||
return torch.tensor(-torch.inf, dtype=dtype, device=device)
|
||||
elif torch.is_complex(input):
|
||||
|
@ -765,7 +765,7 @@ class QuantizationTestCase(TestCase):
|
||||
and not isinstance(module, _FusedModule)
|
||||
):
|
||||
for child in module.children():
|
||||
if type(child) in [nn.Dropout]:
|
||||
if type(child) is nn.Dropout:
|
||||
continue
|
||||
self.checkObservers(
|
||||
child, propagate_qconfig_list, prepare_custom_config_dict
|
||||
|
@ -192,7 +192,7 @@ class Trainer:
|
||||
self.hybrid_module = HybridModel(
|
||||
self.remote_em_rref,
|
||||
self.remote_net_rref,
|
||||
self.trainer_group if ddp_mode in (DdpMode.INSIDE,) else None,
|
||||
self.trainer_group if ddp_mode == DdpMode.INSIDE else None,
|
||||
)
|
||||
self.ddp_params, self.non_ddp_params = (
|
||||
self.hybrid_module.ddp_params,
|
||||
|
@ -707,7 +707,7 @@ class DistributedTest:
|
||||
self.assertNotEqual(args.get("dtype", ""), "")
|
||||
|
||||
per_coll_meta[collname].append(args)
|
||||
if collname in {"wait"}:
|
||||
if collname == "wait":
|
||||
continue
|
||||
|
||||
self.assertEqual(args["Process Group Description"], "default_pg")
|
||||
@ -7029,7 +7029,7 @@ class DistributedTest:
|
||||
self.assertNotEqual(attrs.get("dtype", ""), "")
|
||||
|
||||
per_coll_meta[collname].append(attrs)
|
||||
if collname in {"wait"}:
|
||||
if collname == "wait":
|
||||
continue
|
||||
|
||||
self.assertEqual(attrs["pg_name"], "0") # yes this is a string
|
||||
|
@ -125,7 +125,7 @@ class DebugMode(TorchDispatchMode):
|
||||
_get_current_dispatch_mode(), FakeTensorMode
|
||||
):
|
||||
if self.record_faketensor:
|
||||
if func not in {torch.ops.prim.device.default}:
|
||||
if func != torch.ops.prim.device.default:
|
||||
self.operators.append((func, args, kwargs, self.call_depth + 1))
|
||||
elif len(types) == 0:
|
||||
if self.record_realtensor:
|
||||
|
@ -103,7 +103,7 @@ class Capture:
|
||||
def __getattr__(self, attrname):
|
||||
if attrname == "kwarg" or attrname == "kwargs":
|
||||
raise RuntimeError("no kwargs!")
|
||||
if attrname in ["__deepcopy__"]:
|
||||
if attrname == "__deepcopy__":
|
||||
raise AttributeError
|
||||
result = CaptureGetAttr(self, attrname, ctx=self.ctx)
|
||||
return result
|
||||
|
@ -783,7 +783,7 @@ class _FlopCounterMode(TorchDispatchMode):
|
||||
return result, flop_counts
|
||||
|
||||
def _handle_higher_order_ops(self, func, types, args, kwargs):
|
||||
if func not in {torch.ops.higher_order.cond, }:
|
||||
if func is not torch.ops.higher_order.cond:
|
||||
return NotImplemented
|
||||
|
||||
# The flop counter for cond counts the upper bound of flops.
|
||||
|
Reference in New Issue
Block a user