Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[BE][Easy] fix ruff rule needless-bool (SIM103) (#130206)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130206 Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: fa5f572748
Commit: 4d7bf72d93
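The change applies ruff's flake8-simplify rule SIM103 ("needless-bool"): an `if` statement whose branches only return True or False is collapsed into a single `return` of the condition itself (negated when the branches are inverted). A minimal sketch of the pattern, using a hypothetical function rather than code from this diff:

    # Before: flagged by SIM103 (needless-bool)
    def is_even(n):
        if n % 2 == 0:
            return True
        else:
            return False

    # After: return the condition directly
    def is_even(n):
        return n % 2 == 0

The hunks below apply the same rewrite across benchmarks, tests, and torch internals; where the original branches were inverted, the condition is negated (for example `return not cond`, or a comparison is flipped) so behavior is preserved.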
@@ -99,10 +99,7 @@ def check_loss(ref_loss, res_loss):
     assert len(ref_loss) == len(res_loss)
     length = len(ref_loss)
     x = min(length, 10)
-    if sum(res_loss[-x:]) / 10 <= sum(ref_loss[-x:]) / 10 + 1e-1:
-        return True
-    else:
-        return False
+    return sum(res_loss[-x:]) / 10 <= sum(ref_loss[-x:]) / 10 + 0.1


 def parse_args():
@@ -339,16 +339,12 @@ class BenchmarkRunner:
         return cmd_flag is None or test_flag == cmd_flag

     def _check_operator_first_char(self, test_flag, cmd_flag):
-        if cmd_flag is None or test_flag[:1].lower() in cmd_flag:
-            return True
-        return False
+        return cmd_flag is None or test_flag[:1].lower() in cmd_flag

     def _check_keep_list(self, test_flag, cmd_flag_list):
-        if cmd_flag_list is None or any(
+        return cmd_flag_list is None or any(
             test_flag == cmd_flag for cmd_flag in cmd_flag_list
-        ):
-            return True
-        return False
+        )

     def _keep_test(self, test_case):
         # TODO: consider regex matching for test filtering.
@@ -362,7 +358,7 @@ class BenchmarkRunner:
         )

         # Filter framework, operator, test_name, tag, forward_only
-        if (
+        return (
             self._check_keep(op_test_config.test_name, self.args.test_name)
             and self._check_keep_list(test_case.op_bench.module_name(), operators)
             and self._check_operator_first_char(
@@ -381,10 +377,7 @@
                 or "device" not in test_case.test_config.input_config
                 or self.args.device in op_test_config.test_name
             )
-        ):
-            return True
-
-        return False
+        )

     def _print_test_case_info(self, test_case):
         # Print out the test name and skip the real execution
@@ -43,10 +43,7 @@ def find(testcase, condition):


 def skipped_test(testcase):
     def condition(children):
-        tags = [child.tag for child in children]
-        if "skipped" in tags:
-            return True
-        return False
+        return "skipped" in {child.tag for child in children}

     return find(testcase, condition)

@@ -55,12 +52,8 @@ def passed_test(testcase):
     def condition(children):
         if len(children) == 0:
             return True
-        tags = [child.tag for child in children]
-        if "skipped" in tags:
-            return False
-        if "failed" in tags:
-            return False
-        return True
+        tags = {child.tag for child in children}
+        return "skipped" not in tags and "failed" not in tags

     return find(testcase, condition)

@@ -41,13 +41,7 @@ def should_exclude(key):
     if test_file == "UNKNOWN":
         return True
     # Policy: "pass rate" does not include inductor, export, or dynamo tests.
-    if test_file.startswith("inductor/"):
-        return True
-    if test_file.startswith("export/"):
-        return True
-    if test_file.startswith("dynamo/"):
-        return True
-    return False
+    return test_file.startswith(("inductor/", "export/", "dynamo/"))


 def compute_pass_rate(eager_dir, dynamo_dir):
@@ -39,9 +39,7 @@ requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

 def check_dynamic_shape_capture():
     # This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
-    if not config.assume_static_by_default:
-        return True
-    return False
+    return not config.assume_static_by_default


 def count_ops(gm, args, freq, op):
@@ -56,9 +56,7 @@ class StructuredTraceTestingFilter(logging.Filter):

 class StructuredTraceTestingFilter(logging.Filter):
     def filter(self, record):
-        if "str" in record.metadata:
-            return False
-        return True
+        return "str" not in record.metadata


 class StructuredTraceTestingFormatter(logging.Formatter):
@@ -844,9 +844,7 @@ class TestSubgraphRewriter(JitTestCase):
                     if input_idx == 1:
                         num_node = node
                     input_idx += 1
-            if not isinstance(match.nodes_map[num_node], (int, float)):
-                return False
-            return True
+            return isinstance(match.nodes_map[num_node], (int, float))

         def check_replacement_nodes(self, traced, matches):
             replacement_nodes_in_graph = [
@@ -114,11 +114,10 @@ class TestDataParallel(JitTestCase):

         def assert_share_data(t1, t2):
             # Only checks that they point to the same memory on the same device.
-            if t1.device != t2.device:
-                return False
-            if t1.storage().data_ptr() != t2.storage().data_ptr():
-                return False
-            return True
+            return (
+                t1.device == t2.device
+                and t1.storage().data_ptr() == t2.storage().data_ptr()
+            )

         for p1, p2 in zip(module.parameters(), replica[0].parameters()):
             self.assertTrue(assert_share_data(p1, p2))
@@ -104,9 +104,7 @@ def maybe_set_hip_visible_devies():


 def strtobool(s):
-    if s.lower() in ["", "0", "false", "off"]:
-        return False
-    return True
+    return s.lower() not in {"", "0", "false", "off"}


 class TestChoices(list):
@@ -167,7 +167,7 @@ def doAutodiffCheck(testname):
     # these tests are disabled because BailOut nodes
     # inserted by ProfilingExecutor interfere with
     # subgraph slicing of Differentiable Graphs
-    test_exceptions = [
+    test_exceptions = (
         # functional
         'test_nn_dropout',
         'test_nn_log_softmax',
@@ -195,11 +195,9 @@ def doAutodiffCheck(testname):
         'test_split_with_sizes_dim_neg0',
         'test_split_with_sizes_size_0',
         'test_nn_max_pool2d_with_indices',
-    ]
+    )

-    if testname in test_exceptions:
-        return False
-    return True
+    return testname not in test_exceptions


 # TODO: enable TE in PE when all tests are fixed
@@ -365,9 +365,7 @@ class TestModule(TestCase):
             elif isinstance(obj, dict):
                 return any(_can_be_noncontiguous(o) for o in obj.values())
             # scalar tensors can not be non-contiguous
-            if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
-                return False
-            return True
+            return isinstance(obj, torch.Tensor) and obj.dim() != 0

         for module_input in module_inputs:
             if module_input.forward_input is None:
@@ -113,11 +113,7 @@ def reduction_dtype_filter(op):
         or torch.int16 not in op.dtypes
     ):
         return False
-
-    argspec = inspect.getfullargspec(op.op)
-    if "dtype" not in argspec.kwonlyargs:
-        return False
-    return True
+    return "dtype" in inspect.getfullargspec(op.op).kwonlyargs


 # Create a list of operators that are a subset of _ref_test_ops but don't have a
@@ -149,13 +149,7 @@ class DiagonalTensor:
         return cls.handled_functions[func](*args, **kwargs)

     def __eq__(self, other):
-        if type(other) is type(self):
-            if self._N == other._N and self._i == other._i:
-                return True
-            else:
-                return False
-        else:
-            return False
+        return type(other) is type(self) and self._N == other._N and self._i == other._i

 @implements_diagonal(torch.mean)
 def mean(mat):
@@ -436,10 +436,7 @@ def is_forward_derivative_definition(
     all_arg_names: list[str], names: tuple[str, ...]
 ) -> bool:
     for name in names:
-        if name not in all_arg_names:
-            return True
-        else:
-            return False
+        return name not in all_arg_names
     raise RuntimeError("Expected `names` to be non-empty")

@@ -22,9 +22,7 @@ class GcovCoverageParser:
         This is repo-specific and only makes sense for the current state of
         ovrsource.
         """
-        if "third-party" in path:
-            return True
-        return False
+        return "third-party" in path

     def parse(self) -> list[CoverageRecord]:
         # The JSON format is described in the gcov source code
@@ -24,9 +24,7 @@ class LlvmCoverageParser:
         This is repo-specific and only makes sense for the current state of
         ovrsource.
         """
-        if "/third-party/" in path:
-            return True
-        return False
+        return "/third-party/" in path

     @staticmethod
     def _collect_coverage(
@@ -361,12 +361,11 @@ def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
         run_fwd_maybe_bwd(gm, clone_inputs_retaining_gradness(example_inputs))
         compiled_gm = compiler_fn(gm, example_inputs)
         run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs))
+        return False
     except Exception as e:
         new_failure = str(e)
         if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5:
             return True
-        return False
-    return False


 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
@@ -341,14 +341,12 @@ class VariableBuilder:
         return vt

     def _can_lift_attrs_to_inputs(self, vt):
-        if type(vt) in [
+        return type(vt) in {
             TensorVariable,
             TensorWithTFOverrideVariable,
             UserDefinedObjectVariable,
             NumpyNdarrayVariable,
-        ]:
-            return True
-        return False
+        }

     @staticmethod
     @functools.lru_cache(None)
@@ -338,8 +338,6 @@ def _tensors_definitely_do_not_overlap(x, y):
                 # without
                 if offset_delta_mod + y.size(1) <= x.stride(0):
                     return True
-                else:
-                    return False
     return False

@@ -1824,17 +1824,14 @@ class WrapperCodeGen(CodeGen):

     def can_reuse(self, input_buffer, output_buffer=None):
         name = input_buffer.get_name()
-        if (
+        return not (
             name in V.graph.removed_buffers
             or name in V.graph.graph_inputs
             or name in V.graph.constants
             or name in V.graph.torchbind_constants
             or name in V.graph.never_reuse_buffers
             or name in self.freed
-        ):
-            return False
-
-        return True
+        )

     def did_reuse(self, buffer, reused_buffer):
         # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
@@ -562,11 +562,7 @@ class BatchLinearLHSFusion(BatchFusion):


 def is_node_meta_valid(node: Optional[torch.fx.Node]):
-    if node is None:
-        return True
-    if "example_value" not in node.meta and "val" not in node.meta:
-        return False
-    return True
+    return node is None or "example_value" in node.meta or "val" in node.meta


 # Poor person's check for if a node in the graph mutates its input.
@@ -607,13 +607,7 @@ def _is_valid_quantized_op_binary_optimization_pattern(
             )
         )
         if (
-            len(
-                _get_remaining_users(
-                    extra_input_of_pattern,
-                    compute_node,
-                )
-            )
-            > 1
+            len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1
             or extra_input_of_pattern == compute_node.args[0]
         ):
             return False
@@ -310,11 +310,9 @@ class SizeVarAllocator:
         """
         Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
         """
-        if len(left) != len(right):
-            return False
-        if all(self.statically_known_equals(l, r) for l, r in zip(left, right)):
-            return True
-        return False
+        return len(left) == len(right) and all(
+            self.statically_known_equals(l, r) for l, r in zip(left, right)
+        )

     # See Note - [On Statically Known]
     def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool:
@@ -687,13 +687,7 @@ def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool:
     Validates that perm is a permutation of length rank.
     """

-    if not isinstance(perm, Sequence):
-        return False
-
-    if not (tuple(sorted(perm)) == tuple(range(0, rank))):
-        return False
-
-    return True
+    return isinstance(perm, Sequence) and sorted(perm) == list(range(rank))


 def is_same_shape(a: Sequence, b: Sequence) -> bool:
@@ -354,9 +354,7 @@ aten = torch._ops.ops.aten


 def is_noncontiguous_supported(device):
-    if device is not None and device.type == "hpu":
-        return False
-    return True
+    return device is None or device.type != "hpu"


 def handle_noncontiguous_outputs(input_tlist, output):
@@ -54,9 +54,7 @@ def ordered_set(*items):
 # This function indicates if the backend device
 # supports non-contiguous tensors
 def is_noncontiguous_supported(device):
-    if device.type == "hpu":
-        return False
-    return True
+    return device.type != "hpu"


 _like_tensor_constructors = ordered_set(
@@ -350,9 +350,7 @@ class Replicate(Placement):
     """

     def __eq__(self, other: object) -> bool:
-        if not isinstance(other, Replicate):
-            return False
-        return True
+        return isinstance(other, Replicate)

     def __hash__(self) -> int:
         # every replicate placement is the same
@@ -65,12 +65,10 @@ logger = logging.getLogger(__name__)


 def _should_unshard_params(fsdp_state: _FSDPState) -> bool:
-    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD and (
-        _is_composable(fsdp_state) or fsdp_state._use_orig_params
-    ):
-        return False
-    else:
-        return True
+    return not (
+        fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
+        and (_is_composable(fsdp_state) or fsdp_state._use_orig_params)
+    )


 def _convert_to_wrapped_module_name(module_name: str) -> str:
@@ -110,17 +110,11 @@ class _remote_device:
             raise RuntimeError("Invalid state!")

     def __eq__(self, other):
-        if not isinstance(other, _remote_device):
-            return False
-
-        if (
+        return isinstance(other, _remote_device) and (
             self._worker_name == other._worker_name
             and self._device == other._device
             and self._rank == other._rank
-        ):
-            return True
-
-        return False
+        )

     def __hash__(self):
         return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank)
@@ -137,9 +137,7 @@ def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable
     """

     def depends_on(a: Callable, b: Callable):
-        if a == that and b == this:
-            return False
-        return True
+        return a != that or b != this

     return depends_on

@@ -204,10 +204,7 @@ class OpSupports:
             submodules: t.Mapping[str, torch.nn.Module],
             node: torch.fx.Node,
         ) -> bool:
-            if node.name in disallow_set:
-                return False
-            else:
-                return True
+            return node.name not in disallow_set
         return create_op_support(_decline_if_node_in_names)

@@ -137,9 +137,7 @@ def this_before_that_pass_constraint(this: Callable, that: Callable):
     """

     def depends_on(a: Callable, b: Callable):
-        if a == that and b == this:
-            return False
-        return True
+        return a != that or b != this

     return depends_on

@@ -170,9 +168,7 @@ def these_before_those_pass_constraint(these: Callable, those: Callable):
     """

     def depends_on(a: Callable, b: Callable):
-        if unwrap(a) == those and unwrap(b) == these:
-            return False
-        return True
+        return unwrap(a) != those or unwrap(b) != these

     return depends_on

@@ -1733,9 +1733,7 @@ def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]:

 @_beartype.beartype
 def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
-    if any(use.user in nodes for use in value.uses()):
-        return True
-    return False
+    return any(use.user in nodes for use in value.uses())


 @_beartype.beartype
@@ -1435,10 +1435,7 @@ def _check_and_build_extension_h_precompiler_headers(
             # read all content of a file
             content = file.read()
             # check if string present in a file
-            if signature == content:
-                return True
-            else:
-                return False
+            return signature == content

     def _create_if_not_exist(path_dir):
         if not os.path.exists(path_dir):
@@ -39,13 +39,10 @@ def _get_all_graph_pipes_helper(


 def _is_sharding_datapipe(datapipe: DataPipe) -> bool:
-    if isinstance(datapipe, _ShardingIterDataPipe):
-        return True
-    if hasattr(datapipe, "apply_sharding") and inspect.ismethod(
-        datapipe.apply_sharding
-    ):
-        return True
-    return False
+    return isinstance(datapipe, _ShardingIterDataPipe) or (
+        hasattr(datapipe, "apply_sharding")
+        and inspect.ismethod(datapipe.apply_sharding)
+    )


 def apply_sharding(
@@ -89,13 +86,12 @@ def apply_sharding(


 def _is_shuffle_datapipe(datapipe: DataPipe) -> bool:
-    if not hasattr(datapipe, "set_shuffle") or not hasattr(datapipe, "set_seed"):
-        return False
-    if not inspect.ismethod(datapipe.set_shuffle) or not inspect.ismethod(
-        datapipe.set_seed
-    ):
-        return False
-    return True
+    return (
+        hasattr(datapipe, "set_shuffle")
+        and hasattr(datapipe, "set_seed")
+        and inspect.ismethod(datapipe.set_shuffle)
+        and inspect.ismethod(datapipe.set_seed)
+    )


 def apply_shuffle_settings(
@@ -143,9 +139,7 @@ def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:


 def _is_random_datapipe(datapipe: DataPipe) -> bool:
-    if hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed):
-        return True
-    return False
+    return hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed)


 def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe: