Revert "[2/N] More ruff SIM fixes (#165031)"

This reverts commit 38095fbd1323ee4a9541fbcbb9b28bd20f2cd956.

Reverted https://github.com/pytorch/pytorch/pull/165031 on behalf of https://github.com/albanD due to One of the changed line started to fail on trunk ([comment](https://github.com/pytorch/pytorch/pull/165031#issuecomment-3390190870))
This commit is contained in:
PyTorch MergeBot
2025-10-10 13:42:14 +00:00
parent 238dd5517d
commit b8be796a57
53 changed files with 141 additions and 100 deletions

View File

@ -390,8 +390,8 @@ class DeviceTypeTestBase(TestCase):
return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol))
def _apply_precision_override_for_test(self, test, param_kwargs):
dtype = param_kwargs.get("dtype")
dtype = param_kwargs.get("dtypes", dtype)
dtype = param_kwargs["dtype"] if "dtype" in param_kwargs else None
dtype = param_kwargs["dtypes"] if "dtypes" in param_kwargs else dtype
if dtype:
self.precision = self._get_precision_override(test, dtype)
self.precision, self.rel_tol = self._get_tolerance_override(test, dtype)

View File

@ -1915,7 +1915,7 @@ def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs):
for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
# The scalar we are passing to new_full must be the same dtype
# as the one of the resulting tensor
use_dtype = sample.kwargs.get('dtype', dtype)
use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype
yield SampleInput(
sample.input, *sample.args, get_val(use_dtype), **sample.kwargs)

View File

@ -725,7 +725,7 @@ class DistributedTest:
lines = out.getvalue().splitlines()
def format_line(var):
return f"env:{var}={os.environ.get(var, 'N/A')}"
return f"env:{var}={os.environ[var] if var in os.environ else 'N/A'}"
# Check relevant env vars
vars = [
@ -6212,7 +6212,7 @@ class DistributedTest:
)
def test_ddp_logging_data_cpu(self):
def parse_env(var):
return os.environ.get(var, "N/A")
return os.environ[var] if var in os.environ else "N/A"
dist.set_debug_level(dist.DebugLevel.INFO)
_, group_id, _ = self._init_global_test()