[BE][PYFMT] migrate PYFMT for test/[i-z]*/ to ruff format (#144556)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144556
Approved by: https://github.com/ezyang
This commit is contained in:
Xuehai Pan
2025-07-24 17:56:13 +08:00
committed by PyTorch MergeBot
parent 19ce1beb05
commit 775788f93b
38 changed files with 234 additions and 321 deletions

View File

@ -802,7 +802,8 @@ class AddedAttributesTest(JitBackendTestCase):
# Attach bundled inputs which adds several attributes and functions to the model
self.lowered_module = (
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
lowered_module, input # noqa: F821
lowered_module, # noqa: F821
input,
)
)
post_bundled = self.lowered_module(

View File

@ -279,23 +279,23 @@ class TestModuleContainers(JitTestCase):
self.moduledict = CustomModuleDict({"submod": self.submod})
def forward(self, inputs):
assert (
self.modulelist[0] is self.submod
), "__getitem__ failing for ModuleList"
assert self.modulelist[0] is self.submod, (
"__getitem__ failing for ModuleList"
)
assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
for module in self.modulelist:
assert module is self.submod, "__iter__ failing for ModuleList"
assert (
self.sequential[0] is self.submod
), "__getitem__ failing for Sequential"
assert self.sequential[0] is self.submod, (
"__getitem__ failing for Sequential"
)
assert len(self.sequential) == 1, "__len__ failing for Sequential"
for module in self.sequential:
assert module is self.submod, "__iter__ failing for Sequential"
assert (
self.moduledict["submod"] is self.submod
), "__getitem__ failing for ModuleDict"
assert self.moduledict["submod"] is self.submod, (
"__getitem__ failing for ModuleDict"
)
assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"
# note: unable to index moduledict with a string variable currently
@ -439,9 +439,9 @@ class TestModuleContainers(JitTestCase):
self.moduledict = CustomModuleDict()
def forward(self, inputs):
assert (
"submod" not in self.moduledict
), "__contains__ fails for ModuleDict"
assert "submod" not in self.moduledict, (
"__contains__ fails for ModuleDict"
)
return inputs
m = MyModule()

View File

@ -405,9 +405,7 @@ class TestPythonBuiltinOP(JitTestCase):
def f():
x = torch.ones(10, 9, 8, 7, 6)
return x{indices}.shape
""".format(
indices=indices
)
""".format(indices=indices)
)
test_str = test_str.replace(r"'", r"")
scope = {}

View File

@ -139,9 +139,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
"doesn't support instance-level annotations on empty non-base types",
):
torch.jit.script(M())
@ -160,9 +158,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
"doesn't support instance-level annotations on empty non-base types",
):
torch.jit.script(M())
@ -181,9 +177,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
"doesn't support instance-level annotations on empty non-base types",
):
torch.jit.script(M())
@ -202,9 +196,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
"doesn't support instance-level annotations on empty non-base types",
):
torch.jit.script(M())
@ -223,9 +215,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
"doesn't support instance-level annotations on empty non-base types",
):
torch.jit.script(M())
@ -244,9 +234,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
"doesn't support instance-level annotations on empty non-base types",
):
torch.jit.script(M())
@ -265,9 +253,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
"doesn't support instance-level annotations on empty non-base types",
):
torch.jit.script(M())
@ -286,9 +272,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
"doesn't support instance-level annotations on empty non-base types",
):
torch.jit.script(M())
@ -307,9 +291,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
"doesn't support instance-level annotations on empty non-base types",
):
torch.jit.script(M())
@ -328,9 +310,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
"doesn't support instance-level annotations on empty non-base types",
):
torch.jit.script(M())
@ -351,9 +331,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
"doesn't support instance-level annotations on empty non-base types",
):
torch.jit.script(M())

View File

@ -960,8 +960,9 @@ class TestTracer(JitTestCase):
V = Variable
a, b = V(torch.rand(1)), V(torch.rand(1))
ge = torch.jit.trace(foo, (a, b))
a, b = V(torch.rand(1), requires_grad=True), V(
torch.rand(1), requires_grad=True
a, b = (
V(torch.rand(1), requires_grad=True),
V(torch.rand(1), requires_grad=True),
)
(r,) = ge(a, b)
da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)

View File

@ -396,9 +396,7 @@ class TestUnion(JitTestCase):
with self.assertRaisesRegex(
RuntimeError,
"only int, float, "
"complex, Tensor, device and string keys "
"are supported",
"only int, float, complex, Tensor, device and string keys are supported",
):
torch.jit.script(fn)
@ -602,9 +600,7 @@ class TestUnion(JitTestCase):
with self.assertRaisesRegex(
RuntimeError,
"y is set to type str"
" in the true branch and type int "
"in the false branch",
"y is set to type str in the true branch and type int in the false branch",
):
torch.jit.script(fn)
@ -622,9 +618,7 @@ class TestUnion(JitTestCase):
with self.assertRaisesRegex(
RuntimeError,
"previously had type "
"str but is now being assigned to a"
" value of type int",
"previously had type str but is now being assigned to a value of type int",
):
torch.jit.script(fn)
@ -729,8 +723,7 @@ class TestUnion(JitTestCase):
template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_empty"],
"there are multiple possible List type "
"candidates in the Union annotation",
"there are multiple possible List type candidates in the Union annotation",
)
self._assert_passes(
@ -902,8 +895,7 @@ class TestUnion(JitTestCase):
template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_mixed"],
"none of those dict types can hold the "
"types of the given keys and values",
"none of those dict types can hold the types of the given keys and values",
)
# TODO: String frontend does not support tuple unpacking

View File

@ -406,9 +406,7 @@ class TestUnion(JitTestCase):
with self.assertRaisesRegex(
RuntimeError,
"only int, float, "
"complex, Tensor, device and string keys "
"are supported",
"only int, float, complex, Tensor, device and string keys are supported",
):
torch.jit.script(fn)
@ -612,9 +610,7 @@ class TestUnion(JitTestCase):
with self.assertRaisesRegex(
RuntimeError,
"y is set to type str"
" in the true branch and type int "
"in the false branch",
"y is set to type str in the true branch and type int in the false branch",
):
torch.jit.script(fn)
@ -632,9 +628,7 @@ class TestUnion(JitTestCase):
with self.assertRaisesRegex(
RuntimeError,
"previously had type "
"str but is now being assigned to a"
" value of type int",
"previously had type str but is now being assigned to a value of type int",
):
torch.jit.script(fn)
@ -739,8 +733,7 @@ class TestUnion(JitTestCase):
template,
"List[str] | List[torch.Tensor]",
lhs["list_literal_empty"],
"there are multiple possible List type "
"candidates in the Union annotation",
"there are multiple possible List type candidates in the Union annotation",
)
self._assert_passes(
@ -906,8 +899,7 @@ class TestUnion(JitTestCase):
template,
"Dict[str, torch.Tensor] | Dict[str, int]",
lhs["dict_literal_of_mixed"],
"none of those dict types can hold the "
"types of the given keys and values",
"none of those dict types can hold the types of the given keys and values",
)
# TODO: String frontend does not support tuple unpacking

View File

@ -135,12 +135,14 @@ class TestWarn(JitTestCase):
bar()
FileCheck().check_count(
str="UserWarning: I am warning you from foo", count=1, exactly=True
str="UserWarning: I am warning you from foo",
count=1,
exactly=True,
).check_count(
str="UserWarning: I am warning you from bar", count=1, exactly=True
).run(
f.getvalue()
)
str="UserWarning: I am warning you from bar",
count=1,
exactly=True,
).run(f.getvalue())
if __name__ == "__main__":

View File

@ -42,12 +42,12 @@ class LazyGeneratorTest(TestCase):
torch._lazy.mark_step()
assert torch.allclose(
cpu_t1, lazy_t1.to("cpu")
), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
assert torch.allclose(
cpu_t2, lazy_t2.to("cpu")
), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
assert torch.allclose(cpu_t1, lazy_t1.to("cpu")), (
f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
)
assert torch.allclose(cpu_t2, lazy_t2.to("cpu")), (
f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
)
@skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
def test_generator_causes_multiple_compiles(self):
@ -69,29 +69,29 @@ class LazyGeneratorTest(TestCase):
torch._lazy.mark_step()
uncached_compile = metrics.counter_value("UncachedCompile")
assert (
uncached_compile == 1
), f"Expected 1 uncached compiles, got {uncached_compile}"
assert uncached_compile == 1, (
f"Expected 1 uncached compiles, got {uncached_compile}"
)
t = generate_tensor(2)
torch._lazy.mark_step()
uncached_compile = metrics.counter_value("UncachedCompile")
assert (
uncached_compile == 2
), f"Expected 2 uncached compiles, got {uncached_compile}"
assert uncached_compile == 2, (
f"Expected 2 uncached compiles, got {uncached_compile}"
)
t = generate_tensor(1) # noqa: F841
torch._lazy.mark_step()
uncached_compile = metrics.counter_value("UncachedCompile")
assert (
uncached_compile == 2
), f"Expected 2 uncached compiles, got {uncached_compile}"
assert uncached_compile == 2, (
f"Expected 2 uncached compiles, got {uncached_compile}"
)
cached_compile = metrics.counter_value("CachedCompile")
assert (
cached_compile == 1
), f"Expected 1 cached compile, got {cached_compile}"
assert cached_compile == 1, (
f"Expected 1 cached compile, got {cached_compile}"
)
metrics.reset()

View File

@ -486,17 +486,9 @@ class TestLiteScriptModule(TestCase):
"Traceback of TorchScript"
).check("self.b.forwardError").check_next(
"~~~~~~~~~~~~~~~~~~~ <--- HERE"
).check(
"return self.call"
).check_next(
"~~~~~~~~~ <--- HERE"
).check(
).check("return self.call").check_next("~~~~~~~~~ <--- HERE").check(
"return torch.ones"
).check_next(
"~~~~~~~~~~ <--- HERE"
).run(
str(exp)
)
).check_next("~~~~~~~~~~ <--- HERE").run(str(exp))
class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):

View File

@ -25,7 +25,7 @@ class TestNnModule(torch.nn.Module):
torch.nn.ReLU(True),
# state size. (ngf) x 32 x 32
torch.nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
torch.nn.Tanh()
torch.nn.Tanh(),
# state size. (nc) x 64 x 64
)

View File

@ -721,9 +721,9 @@ class TestExecutionTrace(TestCase):
with tempfile.TemporaryDirectory() as temp_dir:
fp_name = os.path.join(temp_dir, "test.et.json")
os.environ[
"ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"
] = "aten::gather"
os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"] = (
"aten::gather"
)
et = ExecutionTraceObserver()
et.register_callback(fp_name)
et.set_extra_resource_collection(True)

View File

@ -1468,7 +1468,7 @@ class TestProfiler(TestCase):
cats = {e.get("cat", None) for e in j["traceEvents"]}
self.assertTrue(
"cuda_sync" in cats,
"Expected to find cuda_sync event" f" found = {cats}",
f"Expected to find cuda_sync event found = {cats}",
)
print("Testing enable_cuda_sync_events in _ExperimentalConfig")

View File

@ -47,6 +47,6 @@ class AOMigrationTestCase(TestCase):
new_dict = getattr(new_location, dict_name)
assert old_dict == new_dict, f"Dicts don't match: {dict_name}"
for key in new_dict.keys():
assert (
old_dict[key] == new_dict[key]
), f"Dicts don't match: {dict_name} for key {key}"
assert old_dict[key] == new_dict[key], (
f"Dicts don't match: {dict_name} for key {key}"
)

View File

@ -426,7 +426,6 @@ instantiate_device_type_tests(TestFloat4Dtype, globals())
class TestFloat8DtypeCPUOnly(TestCase):
"""
Test of mul implementation

View File

@ -248,9 +248,9 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
+ cls._FLOAT_MODULE.__name__
)
if not qconfig:
assert hasattr(
mod, "qconfig"
), "Input float module must have qconfig defined"
assert hasattr(mod, "qconfig"), (
"Input float module must have qconfig defined"
)
assert mod.qconfig, "Input float module must have a valid qconfig"
qconfig = mod.qconfig
conv, bn = mod[0], mod[1]

View File

@ -99,17 +99,17 @@ class OnDevicePTQUtils:
):
raise ValueError("Quantized weight must be produced.")
fp_weight = weight.inputsAt(0).node()
assert (
fp_weight.kind() == "prim::GetAttr"
), "Weight must be an attribute of the module."
assert fp_weight.kind() == "prim::GetAttr", (
"Weight must be an attribute of the module."
)
fp_weight_name = fp_weight.s("name")
return fp_weight_name
@staticmethod
def is_per_channel_quantized_packed_param(node):
assert (
node.kind() == "quantized::linear_prepack"
), "Node must corresponds to linear_prepack."
assert node.kind() == "quantized::linear_prepack", (
"Node must corresponds to linear_prepack."
)
weight = node.inputsAt(0).node()
assert (
weight.kind() != "aten::quantize_per_tensor"

View File

@ -124,11 +124,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
"aten::dequantize"
).check_not("aten::quantize_per_channel").check("aten::dequantize").check_next(
"aten::conv2d"
).check_next(
"aten::quantize_per_tensor"
).check_next(
"aten::dequantize"
).run(
).check_next("aten::quantize_per_tensor").check_next("aten::dequantize").run(
freezed.graph
)
@ -670,9 +666,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
}
assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype"
assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype"
assert next(iter(activation_dtypes)) != next(
iter(weight_dtypes)
), "Expected activation dtype to "
assert next(iter(activation_dtypes)) != next(iter(weight_dtypes)), (
"Expected activation dtype to "
)
" be different from wegiht dtype"
def test_insert_observers_for_reused_weight(self):
@ -706,9 +702,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
conv2_observers = attrs_with_prefix(m.conv2, "_observer_")
assert len(conv1_observers) == 1, "Expected to have 1 observer submodules"
assert len(conv2_observers) == 1, "Expected to have 1 observer submodules"
assert (
conv1_observers == conv2_observers
), "Expect conv1 and conv2 to have same observers since the class type is shared"
assert conv1_observers == conv2_observers, (
"Expect conv1 and conv2 to have same observers since the class type is shared"
)
def test_insert_observers_for_general_ops(self):
"""Make sure we skip observers for ops that doesn't require
@ -734,13 +730,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
'prim::GetAttr[name="conv"]'
).check("prim::CallMethod").check(
'Observer = prim::GetAttr[name="_observer_'
).check(
"aten::flatten"
).check_not(
).check("aten::flatten").check_not(
'Observer = prim::GetAttr[name="_observer_'
).run(
m.graph
)
).run(m.graph)
# TODO: this is too long, split this to test_insert_observers.py and remove
# insrt_observers prefix
@ -770,17 +762,11 @@ class TestQuantizeJitPasses(QuantizationTestCase):
'prim::GetAttr[name="conv1"]'
).check("prim::CallMethod").check(
'Observer = prim::GetAttr[name="_observer_'
).check(
"aten::flatten"
).check_not(
).check("aten::flatten").check_not(
'Observer = prim::GetAttr[name="_observer_'
).check(
'prim::GetAttr[name="conv2"]'
).check(
).check('prim::GetAttr[name="conv2"]').check(
'Observer = prim::GetAttr[name="_observer_'
).run(
m.graph
)
).run(m.graph)
def test_insert_observers_propagate_observed_in_submodule(self):
"""Make sure we propagate observed property through general ops"""
@ -809,17 +795,11 @@ class TestQuantizeJitPasses(QuantizationTestCase):
'prim::GetAttr[name="conv1"]'
).check("prim::CallMethod").check(
'Observer = prim::GetAttr[name="_observer_'
).check(
"prim::CallMethod"
).check_not(
).check("prim::CallMethod").check_not(
'Observer = prim::GetAttr[name="_observer_'
).check(
'prim::GetAttr[name="conv2"]'
).check(
).check('prim::GetAttr[name="conv2"]').check(
'Observer = prim::GetAttr[name="_observer_'
).run(
m.graph
)
).run(m.graph)
def test_insert_observers_propagate_observed_for_function(self):
def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
@ -1055,9 +1035,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
m(data)
m = convert_jit(m, debug=True)
assert (
len(m._modules._c.items()) == 1
), "Expected to have single submodule of conv"
assert len(m._modules._c.items()) == 1, (
"Expected to have single submodule of conv"
)
# make sure the quantized model is executable
m(data)
quant_func = (
@ -1088,17 +1068,17 @@ class TestQuantizeJitPasses(QuantizationTestCase):
qconfig_dict = {"": qconfig}
m = prepare_jit(m, qconfig_dict)
# observers for input, output and value between conv1/conv2
assert (
len(attrs_with_prefix(m, "_observer_")) == 3
), "Expected to have 3 obervers"
assert len(attrs_with_prefix(m, "_observer_")) == 3, (
"Expected to have 3 obervers"
)
# observer for weight
assert (
len(attrs_with_prefix(m.conv1, "_observer_")) == 1
), "Expected to have 1 obervers"
assert len(attrs_with_prefix(m.conv1, "_observer_")) == 1, (
"Expected to have 1 obervers"
)
# observer for weight
assert (
len(attrs_with_prefix(m.conv2, "_observer_")) == 1
), "Expected to have 1 obervers"
assert len(attrs_with_prefix(m.conv2, "_observer_")) == 1, (
"Expected to have 1 obervers"
)
data = torch.randn(1, 3, 10, 10, dtype=torch.float)
m(data)
@ -1107,15 +1087,15 @@ class TestQuantizeJitPasses(QuantizationTestCase):
assert m.conv1._c._type() == m.conv2._c._type()
# check all observers have been removed
assert (
len(attrs_with_prefix(m, "_observer_")) == 0
), "Expected to have 0 obervers"
assert (
len(attrs_with_prefix(m.conv1, "_observer_")) == 0
), "Expected to have 0 obervers"
assert (
len(attrs_with_prefix(m.conv2, "_observer_")) == 0
), "Expected to have 0 obervers"
assert len(attrs_with_prefix(m, "_observer_")) == 0, (
"Expected to have 0 obervers"
)
assert len(attrs_with_prefix(m.conv1, "_observer_")) == 0, (
"Expected to have 0 obervers"
)
assert len(attrs_with_prefix(m.conv2, "_observer_")) == 0, (
"Expected to have 0 obervers"
)
quant_func = (
"aten::quantize_per_channel"
@ -1334,11 +1314,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
"aten::avg_pool2d"
).check("aten::q_scale").check_next("aten::q_zero_point").check_next(
"prim::dtype"
).check_next(
"aten::quantize_per_tensor"
).check(
"aten::dequantize"
).run(
).check_next("aten::quantize_per_tensor").check("aten::dequantize").run(
model.graph
)
@ -1757,9 +1733,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
"aten::relu"
).check_not(f"quantized::conv{dim}d(").check_not(
"quantized::relu("
).run(
m.graph
)
).run(m.graph)
@skipIfNoFBGEMM
def test_quantized_add_alpha(self):
@ -1910,9 +1884,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
"aten::relu("
).check_not("aten::relu_(").check_not("quantized::add(").check_not(
"quantized::relu("
).run(
m.graph
)
).run(m.graph)
@skipIfNoFBGEMM
def test_quantized_add(self):
@ -2119,9 +2091,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
"aten::relu("
).check_not("aten::relu_(").check_not("quantized::add(").check_not(
"quantized::relu("
).run(
m.graph
)
).run(m.graph)
@skipIfNoFBGEMM
def test_quantized_add_scalar_relu(self):
@ -2205,11 +2175,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
"aten::relu("
).check_not("aten::relu_(").check_not(
"quantized::add_scalar("
).check_not(
"quantized::relu("
).run(
m.graph
)
).check_not("quantized::relu(").run(m.graph)
@skipIfNoFBGEMM
def test_quantized_cat(self):
@ -2544,9 +2510,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
"aten::relu("
).check_not("aten::relu_(").check_not("quantized::mul(").check_not(
"quantized::relu("
).run(
m.graph
)
).run(m.graph)
@skipIfNoFBGEMM
def test_quantized_mul_scalar_relu(self):
@ -2629,11 +2593,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
"aten::relu("
).check_not("aten::relu_(").check_not(
"quantized::mul_scalar("
).check_not(
"quantized::relu("
).run(
m.graph
)
).check_not("quantized::relu(").run(m.graph)
@override_qengines
def test_hardswish(self):
@ -3103,9 +3063,7 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
'Observer = prim::GetAttr[name="_observer_'
).check("prim::CallMethod").check_not(
'Observer = prim::GetAttr[name="_observer_'
).run(
m.graph
)
).run(m.graph)
def test_insert_quant_dequant_linear_dynamic(self):
class M(torch.nn.Module):
@ -3126,9 +3084,9 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
else default_dynamic_qconfig
)
m = quantize_dynamic_jit(m, {"": qconfig}, debug=True)
assert (
len(m._modules._c.items()) == 2
), "Expected to have two submodule of linear"
assert len(m._modules._c.items()) == 2, (
"Expected to have two submodule of linear"
)
wt_quant_func = (
"aten::quantize_per_channel"
@ -3141,21 +3099,11 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
act_quant_func
).check_next("aten::dequantize").check(
"aten::_choose_qparams_per_tensor"
).check_next(
act_quant_func
).check_next(
"aten::dequantize"
).check(
).check_next(act_quant_func).check_next("aten::dequantize").check(
wt_quant_func
).check_next(
"aten::dequantize"
).check_not(
wt_quant_func
).check(
).check_next("aten::dequantize").check_not(wt_quant_func).check(
"return"
).run(
m.graph
)
).run(m.graph)
@override_qengines
def test_dynamic_multi_op(self):

View File

@ -254,16 +254,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
maxpool_node = node
input_act = maxpool_node.args[0]
assert isinstance(input_act, Node)
maxpool_node.meta[
"quantization_annotation"
] = QuantizationAnnotation(
input_qspec_map={
input_act: act_qspec,
},
output_qspec=SharedQuantizationSpec(
(input_act, maxpool_node)
),
_annotated=True,
maxpool_node.meta["quantization_annotation"] = (
QuantizationAnnotation(
input_qspec_map={
input_act: act_qspec,
},
output_qspec=SharedQuantizationSpec(
(input_act, maxpool_node)
),
_annotated=True,
)
)
def validate(self, model: torch.fx.GraphModule) -> None:
@ -339,9 +339,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def derive_qparams_fn(
obs_or_fqs: list[ObserverOrFakeQuantize],
) -> tuple[Tensor, Tensor]:
assert (
len(obs_or_fqs) == 2
), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
assert len(obs_or_fqs) == 2, (
f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
)
act_obs_or_fq = obs_or_fqs[0]
weight_obs_or_fq = obs_or_fqs[1]
act_scale, act_zp = act_obs_or_fq.calculate_qparams()
@ -442,9 +442,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def derive_qparams_fn(
obs_or_fqs: list[ObserverOrFakeQuantize],
) -> tuple[Tensor, Tensor]:
assert (
len(obs_or_fqs) == 1
), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
assert len(obs_or_fqs) == 1, (
f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
)
weight_obs_or_fq = obs_or_fqs[0]
(
weight_scale,
@ -748,16 +748,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
(first_input_node, cat_node)
)
for input_node in input_nodes[1:]:
input_qspec_map[
input_node
] = share_qparams_with_input_act0_qspec
input_qspec_map[input_node] = (
share_qparams_with_input_act0_qspec
)
cat_node.meta[
"quantization_annotation"
] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act0_qspec,
_annotated=True,
cat_node.meta["quantization_annotation"] = (
QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act0_qspec,
_annotated=True,
)
)
def validate(self, model: torch.fx.GraphModule) -> None:
@ -783,9 +783,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
obs_ins0 = getattr(m, input0.target)
obs_ins1 = getattr(m, input1.target)
assert obs_ins0 == obs_ins1
assert (
len(conv_output_obs) == 2
), "expecting two observer that follows conv2d ops"
assert len(conv_output_obs) == 2, (
"expecting two observer that follows conv2d ops"
)
# checking that the output observers for the two convs are shared as well
assert conv_output_obs[0] == conv_output_obs[1]
@ -850,9 +850,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
obs_ins2 = getattr(m, output_obs.target)
assert obs_ins0 == obs_ins2, "input observer does not match output"
assert (
len(conv_output_obs) == 2
), "expecting two observer that follows conv2d ops"
assert len(conv_output_obs) == 2, (
"expecting two observer that follows conv2d ops"
)
# checking that the output observers for the two convs are shared as well
assert conv_output_obs[0] == conv_output_obs[1]
@ -967,16 +967,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
(first_input_node, cat_node)
)
for input_node in input_nodes[1:]:
input_qspec_map[
input_node
] = share_qparams_with_input_act0_qspec
input_qspec_map[input_node] = (
share_qparams_with_input_act0_qspec
)
cat_node.meta[
"quantization_annotation"
] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act0_qspec,
_annotated=True,
cat_node.meta["quantization_annotation"] = (
QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act0_qspec,
_annotated=True,
)
)
def validate(self, model: torch.fx.GraphModule) -> None:
@ -1063,16 +1063,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
(second_input_node, cat_node)
)
input_qspec_map[
first_input_node
] = share_qparams_with_input_act1_qspec
input_qspec_map[first_input_node] = (
share_qparams_with_input_act1_qspec
)
cat_node.meta[
"quantization_annotation"
] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act1_qspec,
_annotated=True,
cat_node.meta["quantization_annotation"] = (
QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act1_qspec,
_annotated=True,
)
)
def validate(self, model: torch.fx.GraphModule) -> None:
@ -1121,17 +1121,17 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
(second_input_node, add_node)
)
input_qspec_map[
first_input_node
] = share_qparams_with_input_act1_qspec
input_qspec_map[first_input_node] = (
share_qparams_with_input_act1_qspec
)
add_node.meta[
"quantization_annotation"
] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act1_qspec,
allow_implicit_sharing=False,
_annotated=True,
add_node.meta["quantization_annotation"] = (
QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act1_qspec,
allow_implicit_sharing=False,
_annotated=True,
)
)
def validate(self, model: torch.fx.GraphModule) -> None:

View File

@ -277,9 +277,9 @@ class PT2EQATTestCase(QuantizationTestCase):
# Verify: conv literal args
if expected_conv_literal_args is not None:
assert (
len(expected_conv_literal_args) == 6
), "wrong num conv args, bad test setup"
assert len(expected_conv_literal_args) == 6, (
"wrong num conv args, bad test setup"
)
for i in range(6):
if i + 3 < len(conv_node.args):
self.assertEqual(

View File

@ -157,9 +157,9 @@ async def run1(coroutine_id):
gpuid = coroutine_id % GPUS
else:
gpu_assignments = args.gpus.split(":")
assert args.nproc == len(
gpu_assignments
), "Please specify GPU assignment for each process, separated by :"
assert args.nproc == len(gpu_assignments), (
"Please specify GPU assignment for each process, separated by :"
)
gpuid = gpu_assignments[coroutine_id]
while progress < len(ALL_TESTS):

View File

@ -1,8 +1,7 @@
# Owner(s): ["module: dynamo"]
""" Test functions for limits module.
"""Test functions for limits module."""
"""
import functools
import warnings
from unittest import expectedFailure as xfail, skipIf

View File

@ -4104,6 +4104,7 @@ class TestIO(TestCase):
def test_decimal_period_separator():
pass
def test_decimal_comma_separator():
with CommaDecimalPointLocale():
pass
@ -6786,7 +6787,10 @@ class TestWritebackIfCopy(TestCase):
class TestArange(TestCase):
def test_infinite(self):
assert_raises(
(RuntimeError, ValueError), np.arange, 0, np.inf # "unsupported range",
(RuntimeError, ValueError),
np.arange,
0,
np.inf, # "unsupported range",
)
def test_nan_step(self):

View File

@ -2733,10 +2733,18 @@ class TestMoveaxis(TestCase):
assert_raises(np.AxisError, np.moveaxis, x, 3, 0) # 'source.*out of bounds',
assert_raises(np.AxisError, np.moveaxis, x, -4, 0) # 'source.*out of bounds',
assert_raises(
np.AxisError, np.moveaxis, x, 0, 5 # 'destination.*out of bounds',
np.AxisError,
np.moveaxis,
x,
0,
5, # 'destination.*out of bounds',
)
assert_raises(
ValueError, np.moveaxis, x, [0, 0], [0, 1] # 'repeated axis in `source`',
ValueError,
np.moveaxis,
x,
[0, 0],
[0, 1], # 'repeated axis in `source`',
)
assert_raises(
ValueError, # 'repeated axis in `destination`',

View File

@ -3,6 +3,7 @@
"""
Test the scalar constructors, which also do type-coercion
"""
import functools
from unittest import skipIf as skipif

View File

@ -3,6 +3,7 @@
"""
Test the scalar constructors, which also do type-coercion
"""
import fractions
import functools
import types

View File

@ -1,8 +1,7 @@
# Owner(s): ["module: dynamo"]
""" Test printing of scalar types.
"""Test printing of scalar types."""
"""
import functools
from unittest import skipIf as skipif

View File

@ -811,7 +811,10 @@ class TestBlock(TestCase):
assert_raises_regex(ValueError, msg, block, [[1], 2])
assert_raises_regex(ValueError, msg, block, [[], 2])
assert_raises_regex(
ValueError, msg, block, [[[1], [2]], [[3, 4]], [5]] # missing brackets
ValueError,
msg,
block,
[[[1], [2]], [[3, 4]], [5]], # missing brackets
)
def test_empty_lists(self, block):

View File

@ -5,6 +5,7 @@
Copied from fftpack.helper by Pearu Peterson, October 2005
"""
from torch.testing._internal.common_utils import (
run_tests,
TEST_WITH_TORCHDYNAMO,

View File

@ -372,7 +372,7 @@ class TestFFTThreadSafe(TestCase):
assert_allclose(
q.get(timeout=5),
expected,
atol=2e-14
atol=2e-14,
# msg="Function returned wrong value in multithreaded context",
)

View File

@ -1,8 +1,7 @@
# Owner(s): ["module: dynamo"]
"""Test functions for 1D array set operations.
"""Test functions for 1D array set operations."""
"""
from unittest import expectedFailure as xfail, skipIf
import numpy

View File

@ -2881,7 +2881,8 @@ class TestPercentile(TestCase):
np.testing.assert_equal(res.dtype, arr.dtype)
H_F_TYPE_CODES = [
(int_type, np.float64) for int_type in "Bbhil" # np.typecodes["AllInteger"]
(int_type, np.float64)
for int_type in "Bbhil" # np.typecodes["AllInteger"]
] + [
(np.float16, np.float16),
(np.float32, np.float32),

View File

@ -505,8 +505,7 @@ class TestHistogramOptimBinNums(TestCase):
assert_equal(
len(a),
numbins,
err_msg=f"For the {estimator} estimator "
f"with datasize of {testlen}",
err_msg=f"For the {estimator} estimator with datasize of {testlen}",
)
def test_small(self):
@ -552,8 +551,7 @@ class TestHistogramOptimBinNums(TestCase):
assert_equal(
len(a),
expbins,
err_msg=f"For the {estimator} estimator "
f"with datasize of {testlen}",
err_msg=f"For the {estimator} estimator with datasize of {testlen}",
)
def test_incorrect_methods(self):

View File

@ -1,8 +1,7 @@
# Owner(s): ["module: dynamo"]
"""Test functions for matrix module
"""Test functions for matrix module"""
"""
import functools
from unittest import expectedFailure as xfail, skipIf as skipif

View File

@ -1,7 +1,6 @@
# Owner(s): ["module: dynamo"]
""" Test functions for linalg module
"""Test functions for linalg module"""
"""
import functools
import itertools
import os

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
"""Light smoke test switching between numpy to pytorch random streams.
"""
"""Light smoke test switching between numpy to pytorch random streams."""
from contextlib import contextmanager
from functools import partial

View File

@ -9,6 +9,7 @@ The goal is to validate on numpy, and tests should work when replacing
by
>>> import torch._numpy as np
"""
import operator
from unittest import skipIf as skip, SkipTest

View File

@ -39,12 +39,9 @@ USE_BLACK_FILELIST = re.compile(
# test/**
# test/[a-h]*/**
# test/[i-j]*/**
"test/j*/**",
# test/[k-m]*/**
"test/[k-m]*/**",
# test/optim/**
# "test/[p-z]*/**",
"test/[p-z]*/**",
# test/[p-z]*/**,
# torch/**
# torch/_[a-c]*/**
# torch/_[e-h]*/**