mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[2/N] More ruff SIM fixes (#165031)"
This reverts commit 38095fbd1323ee4a9541fbcbb9b28bd20f2cd956. Reverted https://github.com/pytorch/pytorch/pull/165031 on behalf of https://github.com/albanD due to One of the changed line started to fail on trunk ([comment](https://github.com/pytorch/pytorch/pull/165031#issuecomment-3390190870))
This commit is contained in:
@ -390,8 +390,8 @@ class DeviceTypeTestBase(TestCase):
|
||||
return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol))
|
||||
|
||||
def _apply_precision_override_for_test(self, test, param_kwargs):
|
||||
dtype = param_kwargs.get("dtype")
|
||||
dtype = param_kwargs.get("dtypes", dtype)
|
||||
dtype = param_kwargs["dtype"] if "dtype" in param_kwargs else None
|
||||
dtype = param_kwargs["dtypes"] if "dtypes" in param_kwargs else dtype
|
||||
if dtype:
|
||||
self.precision = self._get_precision_override(test, dtype)
|
||||
self.precision, self.rel_tol = self._get_tolerance_override(test, dtype)
|
||||
|
@ -1915,7 +1915,7 @@ def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs):
|
||||
for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
|
||||
# The scalar we are passing to new_full must be the same dtype
|
||||
# as the one of the resulting tensor
|
||||
use_dtype = sample.kwargs.get('dtype', dtype)
|
||||
use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype
|
||||
yield SampleInput(
|
||||
sample.input, *sample.args, get_val(use_dtype), **sample.kwargs)
|
||||
|
||||
|
@ -725,7 +725,7 @@ class DistributedTest:
|
||||
lines = out.getvalue().splitlines()
|
||||
|
||||
def format_line(var):
|
||||
return f"env:{var}={os.environ.get(var, 'N/A')}"
|
||||
return f"env:{var}={os.environ[var] if var in os.environ else 'N/A'}"
|
||||
|
||||
# Check relevant env vars
|
||||
vars = [
|
||||
@ -6212,7 +6212,7 @@ class DistributedTest:
|
||||
)
|
||||
def test_ddp_logging_data_cpu(self):
|
||||
def parse_env(var):
|
||||
return os.environ.get(var, "N/A")
|
||||
return os.environ[var] if var in os.environ else "N/A"
|
||||
|
||||
dist.set_debug_level(dist.DebugLevel.INFO)
|
||||
_, group_id, _ = self._init_global_test()
|
||||
|
Reference in New Issue
Block a user