[BE][Easy] fix ruff rule needless-bool (SIM103) (#130206)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130206
Approved by: https://github.com/malfet
This commit is contained in:
Xuehai Pan
2024-07-08 00:02:13 +08:00
committed by PyTorch MergeBot
parent fa5f572748
commit 4d7bf72d93
35 changed files with 64 additions and 177 deletions

View File

@ -99,10 +99,7 @@ def check_loss(ref_loss, res_loss):
assert len(ref_loss) == len(res_loss)
length = len(ref_loss)
x = min(length, 10)
if sum(res_loss[-x:]) / 10 <= sum(ref_loss[-x:]) / 10 + 1e-1:
return True
else:
return False
return sum(res_loss[-x:]) / 10 <= sum(ref_loss[-x:]) / 10 + 0.1
def parse_args():

View File

@ -339,16 +339,12 @@ class BenchmarkRunner:
return cmd_flag is None or test_flag == cmd_flag
def _check_operator_first_char(self, test_flag, cmd_flag):
if cmd_flag is None or test_flag[:1].lower() in cmd_flag:
return True
return False
return cmd_flag is None or test_flag[:1].lower() in cmd_flag
def _check_keep_list(self, test_flag, cmd_flag_list):
if cmd_flag_list is None or any(
return cmd_flag_list is None or any(
test_flag == cmd_flag for cmd_flag in cmd_flag_list
):
return True
return False
)
def _keep_test(self, test_case):
# TODO: consider regex matching for test filtering.
@ -362,7 +358,7 @@ class BenchmarkRunner:
)
# Filter framework, operator, test_name, tag, forward_only
if (
return (
self._check_keep(op_test_config.test_name, self.args.test_name)
and self._check_keep_list(test_case.op_bench.module_name(), operators)
and self._check_operator_first_char(
@ -381,10 +377,7 @@ class BenchmarkRunner:
or "device" not in test_case.test_config.input_config
or self.args.device in op_test_config.test_name
)
):
return True
return False
)
def _print_test_case_info(self, test_case):
# Print out the test name and skip the real execution

View File

@ -43,10 +43,7 @@ def find(testcase, condition):
def skipped_test(testcase):
def condition(children):
tags = [child.tag for child in children]
if "skipped" in tags:
return True
return False
return "skipped" in {child.tag for child in children}
return find(testcase, condition)
@ -55,12 +52,8 @@ def passed_test(testcase):
def condition(children):
if len(children) == 0:
return True
tags = [child.tag for child in children]
if "skipped" in tags:
return False
if "failed" in tags:
return False
return True
tags = {child.tag for child in children}
return "skipped" not in tags and "failed" not in tags
return find(testcase, condition)

View File

@ -41,13 +41,7 @@ def should_exclude(key):
if test_file == "UNKNOWN":
return True
# Policy: "pass rate" does not include inductor, export, or dynamo tests.
if test_file.startswith("inductor/"):
return True
if test_file.startswith("export/"):
return True
if test_file.startswith("dynamo/"):
return True
return False
return test_file.startswith(("inductor/", "export/", "dynamo/"))
def compute_pass_rate(eager_dir, dynamo_dir):

View File

@ -39,9 +39,7 @@ requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
def check_dynamic_shape_capture():
# This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
if not config.assume_static_by_default:
return True
return False
return not config.assume_static_by_default
def count_ops(gm, args, freq, op):

View File

@ -56,9 +56,7 @@ ARGS = (torch.ones(1000, 1000, requires_grad=True),)
class StructuredTraceTestingFilter(logging.Filter):
def filter(self, record):
if "str" in record.metadata:
return False
return True
return "str" not in record.metadata
class StructuredTraceTestingFormatter(logging.Formatter):

View File

@ -844,9 +844,7 @@ class TestSubgraphRewriter(JitTestCase):
if input_idx == 1:
num_node = node
input_idx += 1
if not isinstance(match.nodes_map[num_node], (int, float)):
return False
return True
return isinstance(match.nodes_map[num_node], (int, float))
def check_replacement_nodes(self, traced, matches):
replacement_nodes_in_graph = [

View File

@ -114,11 +114,10 @@ class TestDataParallel(JitTestCase):
def assert_share_data(t1, t2):
# Only checks that they point to the same memory on the same device.
if t1.device != t2.device:
return False
if t1.storage().data_ptr() != t2.storage().data_ptr():
return False
return True
return (
t1.device == t2.device
and t1.storage().data_ptr() == t2.storage().data_ptr()
)
for p1, p2 in zip(module.parameters(), replica[0].parameters()):
self.assertTrue(assert_share_data(p1, p2))

View File

@ -104,9 +104,7 @@ def maybe_set_hip_visible_devies():
def strtobool(s):
if s.lower() in ["", "0", "false", "off"]:
return False
return True
return s.lower() not in {"", "0", "false", "off"}
class TestChoices(list):

View File

@ -167,7 +167,7 @@ def doAutodiffCheck(testname):
# these tests are disabled because BailOut nodes
# inserted by ProfilingExecutor interfere with
# subgraph slicing of Differentiable Graphs
test_exceptions = [
test_exceptions = (
# functional
'test_nn_dropout',
'test_nn_log_softmax',
@ -195,11 +195,9 @@ def doAutodiffCheck(testname):
'test_split_with_sizes_dim_neg0',
'test_split_with_sizes_size_0',
'test_nn_max_pool2d_with_indices',
]
)
if testname in test_exceptions:
return False
return True
return testname not in test_exceptions
# TODO: enable TE in PE when all tests are fixed

View File

@ -365,9 +365,7 @@ class TestModule(TestCase):
elif isinstance(obj, dict):
return any(_can_be_noncontiguous(o) for o in obj.values())
# scalar tensors can not be non-contiguous
if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
return False
return True
return isinstance(obj, torch.Tensor) and obj.dim() != 0
for module_input in module_inputs:
if module_input.forward_input is None:

View File

@ -113,11 +113,7 @@ def reduction_dtype_filter(op):
or torch.int16 not in op.dtypes
):
return False
argspec = inspect.getfullargspec(op.op)
if "dtype" not in argspec.kwonlyargs:
return False
return True
return "dtype" in inspect.getfullargspec(op.op).kwonlyargs
# Create a list of operators that are a subset of _ref_test_ops but don't have a

View File

@ -149,13 +149,7 @@ class DiagonalTensor:
return cls.handled_functions[func](*args, **kwargs)
def __eq__(self, other):
if type(other) is type(self):
if self._N == other._N and self._i == other._i:
return True
else:
return False
else:
return False
return type(other) is type(self) and self._N == other._N and self._i == other._i
@implements_diagonal(torch.mean)
def mean(mat):

View File

@ -436,10 +436,7 @@ def is_forward_derivative_definition(
all_arg_names: list[str], names: tuple[str, ...]
) -> bool:
for name in names:
if name not in all_arg_names:
return True
else:
return False
return name not in all_arg_names
raise RuntimeError("Expected `names` to be non-empty")

View File

@ -22,9 +22,7 @@ class GcovCoverageParser:
This is repo-specific and only makes sense for the current state of
ovrsource.
"""
if "third-party" in path:
return True
return False
return "third-party" in path
def parse(self) -> list[CoverageRecord]:
# The JSON format is described in the gcov source code

View File

@ -24,9 +24,7 @@ class LlvmCoverageParser:
This is repo-specific and only makes sense for the current state of
ovrsource.
"""
if "/third-party/" in path:
return True
return False
return "/third-party/" in path
@staticmethod
def _collect_coverage(

View File

@ -361,12 +361,11 @@ def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
run_fwd_maybe_bwd(gm, clone_inputs_retaining_gradness(example_inputs))
compiled_gm = compiler_fn(gm, example_inputs)
run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs))
return False
except Exception as e:
new_failure = str(e)
if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5:
return True
return False
return False
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

View File

@ -341,14 +341,12 @@ class VariableBuilder:
return vt
def _can_lift_attrs_to_inputs(self, vt):
if type(vt) in [
return type(vt) in {
TensorVariable,
TensorWithTFOverrideVariable,
UserDefinedObjectVariable,
NumpyNdarrayVariable,
]:
return True
return False
}
@staticmethod
@functools.lru_cache(None)

View File

@ -338,8 +338,6 @@ def _tensors_definitely_do_not_overlap(x, y):
# without
if offset_delta_mod + y.size(1) <= x.stride(0):
return True
else:
return False
return False

View File

@ -1824,17 +1824,14 @@ class WrapperCodeGen(CodeGen):
def can_reuse(self, input_buffer, output_buffer=None):
name = input_buffer.get_name()
if (
return not (
name in V.graph.removed_buffers
or name in V.graph.graph_inputs
or name in V.graph.constants
or name in V.graph.torchbind_constants
or name in V.graph.never_reuse_buffers
or name in self.freed
):
return False
return True
)
def did_reuse(self, buffer, reused_buffer):
# Check whether a given buffer was reused by a possible reuser in the wrapper codegen

View File

@ -562,11 +562,7 @@ class BatchLinearLHSFusion(BatchFusion):
def is_node_meta_valid(node: Optional[torch.fx.Node]):
if node is None:
return True
if "example_value" not in node.meta and "val" not in node.meta:
return False
return True
return node is None or "example_value" in node.meta or "val" in node.meta
# Poor person's check for if a node in the graph mutates its input.

View File

@ -607,13 +607,7 @@ def _is_valid_quantized_op_binary_optimization_pattern(
)
)
if (
len(
_get_remaining_users(
extra_input_of_pattern,
compute_node,
)
)
> 1
len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1
or extra_input_of_pattern == compute_node.args[0]
):
return False

View File

@ -310,11 +310,9 @@ class SizeVarAllocator:
"""
Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
"""
if len(left) != len(right):
return False
if all(self.statically_known_equals(l, r) for l, r in zip(left, right)):
return True
return False
return len(left) == len(right) and all(
self.statically_known_equals(l, r) for l, r in zip(left, right)
)
# See Note - [On Statically Known]
def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool:

View File

@ -687,13 +687,7 @@ def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool:
Validates that perm is a permutation of length rank.
"""
if not isinstance(perm, Sequence):
return False
if not (tuple(sorted(perm)) == tuple(range(0, rank))):
return False
return True
return isinstance(perm, Sequence) and sorted(perm) == list(range(rank))
def is_same_shape(a: Sequence, b: Sequence) -> bool:

View File

@ -354,9 +354,7 @@ aten = torch._ops.ops.aten
def is_noncontiguous_supported(device):
if device is not None and device.type == "hpu":
return False
return True
return device is None or device.type != "hpu"
def handle_noncontiguous_outputs(input_tlist, output):

View File

@ -54,9 +54,7 @@ def ordered_set(*items):
# This function indicates if the backend device
# supports non-contiguous tensors
def is_noncontiguous_supported(device):
if device.type == "hpu":
return False
return True
return device.type != "hpu"
_like_tensor_constructors = ordered_set(

View File

@ -350,9 +350,7 @@ class Replicate(Placement):
"""
def __eq__(self, other: object) -> bool:
if not isinstance(other, Replicate):
return False
return True
return isinstance(other, Replicate)
def __hash__(self) -> int:
# every replicate placement is the same

View File

@ -65,12 +65,10 @@ logger = logging.getLogger(__name__)
def _should_unshard_params(fsdp_state: _FSDPState) -> bool:
if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD and (
_is_composable(fsdp_state) or fsdp_state._use_orig_params
):
return False
else:
return True
return not (
fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
and (_is_composable(fsdp_state) or fsdp_state._use_orig_params)
)
def _convert_to_wrapped_module_name(module_name: str) -> str:

View File

@ -110,17 +110,11 @@ class _remote_device:
raise RuntimeError("Invalid state!")
def __eq__(self, other):
if not isinstance(other, _remote_device):
return False
if (
return isinstance(other, _remote_device) and (
self._worker_name == other._worker_name
and self._device == other._device
and self._rank == other._rank
):
return True
return False
)
def __hash__(self):
return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank)

View File

@ -137,9 +137,7 @@ def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable
"""
def depends_on(a: Callable, b: Callable):
if a == that and b == this:
return False
return True
return a != that or b != this
return depends_on

View File

@ -204,10 +204,7 @@ class OpSupports:
submodules: t.Mapping[str, torch.nn.Module],
node: torch.fx.Node,
) -> bool:
if node.name in disallow_set:
return False
else:
return True
return node.name not in disallow_set
return create_op_support(_decline_if_node_in_names)

View File

@ -137,9 +137,7 @@ def this_before_that_pass_constraint(this: Callable, that: Callable):
"""
def depends_on(a: Callable, b: Callable):
if a == that and b == this:
return False
return True
return a != that or b != this
return depends_on
@ -170,9 +168,7 @@ def these_before_those_pass_constraint(these: Callable, those: Callable):
"""
def depends_on(a: Callable, b: Callable):
if unwrap(a) == those and unwrap(b) == these:
return False
return True
return unwrap(a) != those or unwrap(b) != these
return depends_on

View File

@ -1733,9 +1733,7 @@ def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]:
@_beartype.beartype
def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
if any(use.user in nodes for use in value.uses()):
return True
return False
return any(use.user in nodes for use in value.uses())
@_beartype.beartype

View File

@ -1435,10 +1435,7 @@ def _check_and_build_extension_h_precompiler_headers(
# read all content of a file
content = file.read()
# check if string present in a file
if signature == content:
return True
else:
return False
return signature == content
def _create_if_not_exist(path_dir):
if not os.path.exists(path_dir):

View File

@ -39,13 +39,10 @@ def _get_all_graph_pipes_helper(
def _is_sharding_datapipe(datapipe: DataPipe) -> bool:
if isinstance(datapipe, _ShardingIterDataPipe):
return True
if hasattr(datapipe, "apply_sharding") and inspect.ismethod(
datapipe.apply_sharding
):
return True
return False
return isinstance(datapipe, _ShardingIterDataPipe) or (
hasattr(datapipe, "apply_sharding")
and inspect.ismethod(datapipe.apply_sharding)
)
def apply_sharding(
@ -89,13 +86,12 @@ def apply_sharding(
def _is_shuffle_datapipe(datapipe: DataPipe) -> bool:
if not hasattr(datapipe, "set_shuffle") or not hasattr(datapipe, "set_seed"):
return False
if not inspect.ismethod(datapipe.set_shuffle) or not inspect.ismethod(
datapipe.set_seed
):
return False
return True
return (
hasattr(datapipe, "set_shuffle")
and hasattr(datapipe, "set_seed")
and inspect.ismethod(datapipe.set_shuffle)
and inspect.ismethod(datapipe.set_seed)
)
def apply_shuffle_settings(
@ -143,9 +139,7 @@ def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
def _is_random_datapipe(datapipe: DataPipe) -> bool:
if hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed):
return True
return False
return hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed)
def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe: