[BE][PYFMT] migrate PYFMT for torch/[e-n]*/ to ruff format (#144553)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144553
Approved by: https://github.com/ezyang
ghstack dependencies: #144551
This commit is contained in:
Xuehai Pan
2025-06-16 14:35:28 +08:00
committed by PyTorch MergeBot
parent 95cb42c45d
commit 2e0e08588e
48 changed files with 548 additions and 497 deletions

View File

@ -60,7 +60,6 @@ USE_BLACK_FILELIST = re.compile(
"torch/[b-c]*/**",
# torch/d*/**
# torch/[e-m]*/**
"torch/[e-m]*/**",
# torch/optim/**
# torch/[p-z]*/**
"torch/[p-z]*/**",

View File

@ -358,22 +358,24 @@ def save(
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
ep = torch.export.export(MyModule(), (torch.randn(5),))
# Save to file
torch.export.save(ep, 'exported_program.pt2')
torch.export.save(ep, "exported_program.pt2")
# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)
# Save with extra files
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
extra_files = {"foo.txt": b"bar".decode("utf-8")}
torch.export.save(ep, "exported_program.pt2", extra_files=extra_files)
"""
if not isinstance(ep, ExportedProgram):
@ -427,18 +429,18 @@ def load(
import io
# Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2')
ep = torch.export.load("exported_program.pt2")
# Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f:
with open("exported_program.pt2", "rb") as f:
buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)
# Load with extra files.
extra_files = {'foo.txt': ''} # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
print(extra_files['foo.txt'])
extra_files = {"foo.txt": ""} # values will be replaced with data
ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
print(extra_files["foo.txt"])
print(ep(torch.randn(5)))
"""
if isinstance(f, (str, os.PathLike)):
@ -572,24 +574,29 @@ def register_dataclass(
import torch
from dataclasses import dataclass
@dataclass
class InputDataClass:
feature: torch.Tensor
bias: int
@dataclass
class OutputDataClass:
res: torch.Tensor
torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)
class Mod(torch.nn.Module):
def forward(self, x: InputDataClass) -> OutputDataClass:
res = x.feature + x.bias
return OutputDataClass(res=res)
ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), ))
ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),))
print(ep)
"""

View File

@ -43,7 +43,7 @@ def prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[int, str])
continue
res += f"""
File {str_to_filename[frame['filename']]}, lineno {frame['line']}, in {frame['name']}""" # type: ignore[index]
File {str_to_filename[frame["filename"]]}, lineno {frame["line"]}, in {frame["name"]}""" # type: ignore[index]
res += f"\n {stack[-1]['loc']}"
return res
@ -327,12 +327,12 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
# We don't want to log all expression_created logs, only
# the ones that are relevant to the
# guards/propagate_real_tensor
self.expression_created_logs[
metadata[key]["result_id"]
] = ExpressionCreatedNode(
metadata[key]["result_id"],
metadata[key].get("argument_ids", []),
record,
self.expression_created_logs[metadata[key]["result_id"]] = (
ExpressionCreatedNode(
metadata[key]["result_id"],
metadata[key].get("argument_ids", []),
record,
)
)
return
@ -374,10 +374,13 @@ def draft_export(
capture_structured_log = CaptureStructuredTrace()
with torch._functorch.config.patch(
fake_tensor_propagate_real_tensors=True,
generate_fake_kernels_from_real_mismatches=True,
), capture_structured_log:
with (
torch._functorch.config.patch(
fake_tensor_propagate_real_tensors=True,
generate_fake_kernels_from_real_mismatches=True,
),
capture_structured_log,
):
try:
new_shapes = None
ep = _export(
@ -424,10 +427,10 @@ def draft_export(
continue
elif log_name == "propagate_real_tensors_provenance":
log_contents[
"occurrences"
] = capture_structured_log.log_record.get_log_count(
(log_name, log_contents)
log_contents["occurrences"] = (
capture_structured_log.log_record.get_log_count(
(log_name, log_contents)
)
)
failure_type = FailureType.DATA_DEPENDENT_ERROR

View File

@ -26,9 +26,9 @@ def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]:
if user.op == "output":
continue
assert (
user.op == "call_function" and user.target == operator.getitem
), f"Expected getitem node as user for {node}, instead got {user}"
assert user.op == "call_function" and user.target == operator.getitem, (
f"Expected getitem node as user for {node}, instead got {user}"
)
getitem_users.update(list(user.users.keys()))
return getitem_users
@ -63,9 +63,9 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
log.debug("Trying to remove pytrees for module call %s", curr_module_node)
curr_module_users = list(curr_module_node.users.keys())
assert (
len(curr_module_users) == 1
), f"Expected only one user for module node, instead got {list(curr_module_users)}"
assert len(curr_module_users) == 1, (
f"Expected only one user for module node, instead got {list(curr_module_users)}"
)
flatten_node = curr_module_users[0]
assert (
flatten_node.op == "call_function"

View File

@ -268,9 +268,9 @@ def _extract_fake_inputs(gm, args, kwargs):
if detected_fake_mode:
if detected_shape_env:
assert (
detected_shape_env is detected_fake_mode.shape_env
), "Detected shape env does not match fake mode's shape env"
assert detected_shape_env is detected_fake_mode.shape_env, (
"Detected shape env does not match fake mode's shape env"
)
fake_mode = detected_fake_mode
elif detected_shape_env:
fake_mode = FakeTensorMode(shape_env=detected_shape_env, export=True)
@ -864,13 +864,19 @@ def _export_to_aten_ir(
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
with torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context(), custom_triton_ops_decomposition_ctx(): # type: ignore[attr-defined]
with (
torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
),
grad_safe_guard,
_ignore_backend_decomps(),
_compiling_state_context(),
custom_triton_ops_decomposition_ctx(),
):
gm, graph_signature = transform(aot_export_module)(
mod,
fake_args,
@ -1229,9 +1235,9 @@ def _get_module_call_graph(
"""
gm: torch.fx.GraphModule = export_artifact.aten.gm
export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
module_call_specs: dict[
str, dict[str, TreeSpec]
] = export_artifact.module_call_specs
module_call_specs: dict[str, dict[str, TreeSpec]] = (
export_artifact.module_call_specs
)
in_spec: TreeSpec = export_artifact.in_spec
out_spec: TreeSpec = export_artifact.out_spec
@ -1365,7 +1371,8 @@ def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
).module()
elif isinstance(traced_callable, torch.ScriptMethod) and isinstance(
traced_callable.owner(), (torch._C.ScriptModule, torch.nn.Module) # type: ignore[operator]
traced_callable.owner(), # type: ignore[operator]
(torch._C.ScriptModule, torch.nn.Module),
):
with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator]
return _export(
@ -1430,9 +1437,9 @@ def _strict_export(
attr = getattr(gm_torch_level, node.target)
# Checks if it is not a HigherOrderOp branch or a module
if not isinstance(attr, torch.nn.Module):
assert (
dynamo_fake_mode is not None
), "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
assert dynamo_fake_mode is not None, (
"Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
)
node.meta["val"] = dynamo_fake_mode.from_tensor(
attr, static_shapes=True
)
@ -1749,13 +1756,17 @@ def _export_to_aten_ir_make_fx(
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
with torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
), _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined]
with (
torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
),
_ignore_backend_decomps(),
_compiling_state_context(),
):
gm, graph_signature = transform(_make_fx_helper)(
mod,
fake_args,
@ -1944,22 +1955,27 @@ def _non_strict_export(
# We also need to attach dynamo configs as these will be used in HOOs that
# use torch.compile, like cond
dynamo_config = dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
dynamo_config[
"do_not_emit_runtime_asserts"
] = False # We want to emit runtime asserts
dynamo_config["do_not_emit_runtime_asserts"] = (
False # We want to emit runtime asserts
)
with fake_mode, _NonStrictTorchFunctionHandler(), tracing(
tx
), torch._dynamo.config.patch(dynamo_config):
with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
patched_mod,
new_fake_args,
new_fake_kwargs,
new_fake_constant_attrs,
map_fake_to_real,
), _fakify_module_inputs(
fake_args, fake_kwargs, fake_mode
), _override_builtin_ops():
with (
fake_mode,
_NonStrictTorchFunctionHandler(),
tracing(tx),
torch._dynamo.config.patch(dynamo_config),
):
with (
_fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
patched_mod,
new_fake_args,
new_fake_kwargs,
new_fake_constant_attrs,
map_fake_to_real,
),
_fakify_module_inputs(fake_args, fake_kwargs, fake_mode),
_override_builtin_ops(),
):
aten_export_artifact = _to_aten_func( # type: ignore[operator]
patched_mod,
new_fake_args,

View File

@ -666,7 +666,7 @@ class ShapesCollection:
Example::
args = ({"x": tensor_x, "others": [tensor_y, tensor_z]})
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}
dim = torch.export.Dim(...)
dynamic_shapes = torch.export.ShapesCollection()
@ -682,7 +682,7 @@ class ShapesCollection:
Example::
args = ({"x": tensor_x, "others": [int_x, int_y]})
args = {"x": tensor_x, "others": [int_x, int_y]}
# Wrap all ints with _IntWrapper
mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args)
@ -700,18 +700,18 @@ class ShapesCollection:
self._shapes = {}
def __setitem__(self, t, shape):
assert isinstance(
t, (torch.Tensor, _IntWrapper)
), f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}"
assert isinstance(t, (torch.Tensor, _IntWrapper)), (
f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}"
)
# TODO(avik): check that shape is indeed a Shape
t_id = id(t)
if t_id in self._shapes:
_shape = self._shapes[t_id]
assert (
shape == _shape
), f"Shapes assigned to input do not match: expected {_shape}, got {shape}"
assert shape == _shape, (
f"Shapes assigned to input do not match: expected {_shape}, got {shape}"
)
else:
self._shapes[id(t)] = shape
@ -766,7 +766,7 @@ class AdditionalInputs:
Example::
args0, kwargs0 = ... # example inputs for export
args0, kwargs0 = ... # example inputs for export
# other representative inputs that the exported program will run on
dynamic_shapes = torch.export.AdditionalInputs()
@ -786,9 +786,9 @@ class AdditionalInputs:
"""
assert type(args) is tuple, f"Representative args {args} must be a tuple"
assert (
kwargs is None or type(kwargs) is dict
), f"Representative kwargs {kwargs} must be None or a dict"
assert kwargs is None or type(kwargs) is dict, (
f"Representative kwargs {kwargs} must be None or a dict"
)
self._examples.append((args, kwargs))
def dynamic_shapes(self, m, args, kwargs=None):
@ -1075,7 +1075,8 @@ def _process_dynamic_shapes(
i,
dim.__name__,
StrictMinMaxConstraint(
vr=ValueRanges(lower=dim.value, upper=dim.value), warn_only=False # type: ignore[attr-defined]
vr=ValueRanges(lower=dim.value, upper=dim.value), # type: ignore[attr-defined]
warn_only=False,
),
)
else:
@ -1085,7 +1086,8 @@ def _process_dynamic_shapes(
i,
dim.__name__,
StrictMinMaxConstraint(
vr=ValueRanges(lower=dim.min, upper=dim.max), warn_only=False # type: ignore[attr-defined]
vr=ValueRanges(lower=dim.min, upper=dim.max), # type: ignore[attr-defined]
warn_only=False,
),
)
return constraint
@ -1161,7 +1163,7 @@ def _process_dynamic_shapes(
def _get_dim_name_mapping(
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None]
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
):
name_to_dim = {}
for dim in tree_flatten(

View File

@ -137,16 +137,11 @@ class _ExportPackage:
"decoder": ExportMethod(
overloads={
"prefill": ExportedProgram(...),
"decode": ExportedProgram(...)
"decode": ExportedProgram(...),
},
fallbacks=[]
fallbacks=[],
),
"encoder": ExportMethod(
overloads={},
fallbacks=[
ExportedProgram(...)
]
)
"encoder": ExportMethod(overloads={}, fallbacks=[ExportedProgram(...)]),
},
)
```
@ -212,15 +207,18 @@ class _ExportPackage:
```
package = ExportPackage()
def prefill(x, xa, kv_cache):
assert x.shape[1] == 3
assert kv_cache == {}
def decode(x, xa, kv_cache):
assert x.shape[1] > 1
assert len(kv_cache) > 0
return {...} # dynamic shape specs here
exporter = (
package.exporter(decoder)
.define_overload("prefill", prefill)

View File

@ -272,7 +272,7 @@ def _override_composite_implicit_decomp(cia_ops_to_callable):
def _split_decomp_table_to_cia_and_python_decomp(
decomp_table: dict[torch._ops.OperatorBase, Callable]
decomp_table: dict[torch._ops.OperatorBase, Callable],
) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]:
all_preservable_cia_ops = set(_collect_all_valid_cia_ops())
cia_ops_to_callable = {}
@ -443,9 +443,14 @@ def _decompose_and_get_gm_with_new_signature_constants(
tx = TracingContext(fake_mode)
with fake_mode, _override_composite_implicit_decomp(
cia_to_decomp,
), _enable_graph_inputs_of_type_nn_module(ep.example_inputs), tracing(tx):
with (
fake_mode,
_override_composite_implicit_decomp(
cia_to_decomp,
),
_enable_graph_inputs_of_type_nn_module(ep.example_inputs),
tracing(tx),
):
retracing_args_unwrapped = pytree.tree_unflatten(
retracing_args, mod._in_spec
)
@ -573,9 +578,12 @@ def _decompose_and_get_gm_with_new_signature_constants(
if decompose_custom_triton_ops
else _disable_custom_triton_op_functional_decomposition
)
with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp(
cia_to_decomp
), custom_triton_ops_decomposition_ctx():
with (
_ignore_backend_decomps(),
fake_mode,
_override_composite_implicit_decomp(cia_to_decomp),
custom_triton_ops_decomposition_ctx(),
):
gm, graph_signature = aot_export_module(
ep.graph_module,
fake_args,
@ -1514,9 +1522,9 @@ class ExportedProgram:
if node.op != "placeholder":
break
assert i < len(
old_signature.input_specs
), "Number of inputs changed after transformation"
assert i < len(old_signature.input_specs), (
"Number of inputs changed after transformation"
)
old_input_spec = old_signature.input_specs[i]
arg = (
old_input_spec.arg
@ -1539,9 +1547,9 @@ class ExportedProgram:
new_output_specs = []
for i, node in enumerate(output_node.args[0]):
assert i < len(
old_signature.output_specs
), "Number of outputs changed after transformation"
assert i < len(old_signature.output_specs), (
"Number of outputs changed after transformation"
)
old_output_spec = old_signature.output_specs[i]
arg = (
old_output_spec.arg
@ -1599,9 +1607,9 @@ class ExportedProgram:
# TODO: remove this
@final
def _validate(self):
assert (
len(self.verifiers) > 0
), "ExportedProgram must have at least one verifier."
assert len(self.verifiers) > 0, (
"ExportedProgram must have at least one verifier."
)
for v in self.verifiers:
v().check(self)

View File

@ -95,9 +95,9 @@ class InputSpec:
def __post_init__(self):
if self.kind == InputKind.BUFFER:
assert (
self.persistent is not None
), "Failed to specify persistent flag on BUFFER."
assert self.persistent is not None, (
"Failed to specify persistent flag on BUFFER."
)
assert isinstance(
self.arg,
(
@ -187,15 +187,17 @@ class ExportGraphSignature:
self.my_parameter = nn.Parameter(torch.tensor(2.0))
# Define two buffers
self.register_buffer('my_buffer1', torch.tensor(3.0))
self.register_buffer('my_buffer2', torch.tensor(4.0))
self.register_buffer("my_buffer1", torch.tensor(3.0))
self.register_buffer("my_buffer2", torch.tensor(4.0))
def forward(self, x1, x2):
# Use the parameter, buffers, and both inputs in the forward method
output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2
output = (
x1 + self.my_parameter
) * self.my_buffer1 + x2 * self.my_buffer2
# Mutate one of the buffers (e.g., increment it by 1)
self.my_buffer2.add_(1.0) # In-place addition
self.my_buffer2.add_(1.0) # In-place addition
return output
@ -520,9 +522,9 @@ def _make_argument_spec(node, token_names) -> ArgumentSpec:
# For const outputs we just directly return this
return ConstantArgument(name="", value=node)
assert (
"val" in node.meta
), f"{node} is not a constant or a node with a 'val' metadata field"
assert "val" in node.meta, (
f"{node} is not a constant or a node with a 'val' metadata field"
)
val = node.meta["val"]
if node.name in token_names:
return TokenArgument(name=node.name)
@ -565,9 +567,21 @@ def _convert_to_export_graph_signature(
user_outputs = set(graph_signature.user_outputs)
buffer_mutations = graph_signature.buffers_to_mutate
user_input_mutations = graph_signature.user_inputs_to_mutate
grad_params = graph_signature.backward_signature.gradients_to_parameter if is_joint else {} # type: ignore[union-attr]
grad_user_inputs = graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {} # type: ignore[union-attr]
loss_output = graph_signature.backward_signature.loss_output if is_joint else None # type: ignore[union-attr]
grad_params = (
graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr]
if is_joint
else {}
)
grad_user_inputs = (
graph_signature.backward_signature.gradients_to_user_inputs # type: ignore[union-attr]
if is_joint
else {}
)
loss_output = (
graph_signature.backward_signature.loss_output # type: ignore[union-attr]
if is_joint
else None
)
input_tokens = graph_signature.input_tokens
output_tokens = graph_signature.output_tokens

View File

@ -155,9 +155,9 @@ class PT2ArchiveReader:
def __init__(self, archive_path_or_buffer: FileLike):
self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type]
assert (
self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE
), "Invalid archive format"
assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, (
"Invalid archive format"
)
def __enter__(self) -> "PT2ArchiveReader":
return self

View File

@ -104,9 +104,9 @@ def _assign_attr(
assert isinstance(from_obj, torch.Tensor)
to_module.register_buffer(field, from_obj, persistent=persistent)
elif attr_kind == _AttrKind.CONSTANT:
assert not isinstance(
from_obj, FakeScriptObject
), "FakeScriptObject should only exist during tracing."
assert not isinstance(from_obj, FakeScriptObject), (
"FakeScriptObject should only exist during tracing."
)
assert isinstance(
from_obj,
(
@ -461,9 +461,9 @@ class UnflattenedModule(torch.nn.Module):
# add constants that are aliased and don't appear in graph signature
for const_name, const in export_module.constants.items():
if const_name not in consts_targets:
assert (
id(const) in consts_map
), "Constants should be either aliased or appear in graph signature"
assert id(const) in consts_map, (
"Constants should be either aliased or appear in graph signature"
)
ph_name, _ = consts_map[id(const)][0]
add_to_consts_map(id(const), ph_name, const_name)
added_params_buffers.add(s.target)
@ -1041,9 +1041,9 @@ class _ModuleFrame:
if arg.name in self.seen_nodes:
flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
self.node_to_placeholder[
self.seen_nodes[arg.name]
] = flat_arg_node
self.node_to_placeholder[self.seen_nodes[arg.name]] = (
flat_arg_node
)
with self.parent.graph.inserting_before(self.parent_call_module):
input_nodes: list[Optional[torch.fx.Node]] = []
@ -1125,8 +1125,7 @@ class _ModuleFrame:
if x in self.node_to_placeholder:
return self.node_to_placeholder[x]
elif (
x.op == "placeholder"
or self.module_call_graph.get(self.fqn) is None
x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None
# allow placeholder creation if we are not preserving module call signature
):
self.add_placeholder(x)

View File

@ -82,9 +82,7 @@ Example:
>>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])
>>> torch.fft.fft(t)
tensor([12.+16.j, -8.+0.j, -4.-4.j, 0.-8.j])
""".format(
**common_args
),
""".format(**common_args),
)
ifft = _add_docstr(
@ -125,9 +123,7 @@ Example:
>>> t = torch.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
>>> torch.fft.ifft(t)
tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j])
""".format(
**common_args
),
""".format(**common_args),
)
fft2 = _add_docstr(
@ -188,9 +184,7 @@ Example:
>>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
>>> torch.testing.assert_close(fft2, two_ffts, check_stride=False)
""".format(
**common_args
),
""".format(**common_args),
)
ifft2 = _add_docstr(
@ -243,9 +237,7 @@ Example:
>>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
>>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False)
""".format(
**common_args
),
""".format(**common_args),
)
fftn = _add_docstr(
@ -305,9 +297,7 @@ Example:
>>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
>>> torch.testing.assert_close(fftn, two_ffts, check_stride=False)
""".format(
**common_args
),
""".format(**common_args),
)
ifftn = _add_docstr(
@ -359,9 +349,7 @@ Example:
>>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
>>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False)
""".format(
**common_args
),
""".format(**common_args),
)
rfft = _add_docstr(
@ -417,9 +405,7 @@ Example:
Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted.
At the Nyquist frequency ``T[-2] == T[2]`` is it's own symmetric pair,
and therefore must always be real-valued.
""".format(
**common_args
),
""".format(**common_args),
)
irfft = _add_docstr(
@ -496,9 +482,7 @@ Example:
>>> roundtrip = torch.fft.irfft(T, t.numel())
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
""".format(
**common_args
),
""".format(**common_args),
)
rfft2 = _add_docstr(
@ -565,9 +549,7 @@ Example:
>>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
>>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False)
""".format(
**common_args
),
""".format(**common_args),
)
irfft2 = _add_docstr(
@ -649,9 +631,7 @@ Example:
torch.Size([10, 9])
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
""".format(
**common_args
),
""".format(**common_args),
)
rfftn = _add_docstr(
@ -718,9 +698,7 @@ Example:
>>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
>>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False)
""".format(
**common_args
),
""".format(**common_args),
)
irfftn = _add_docstr(
@ -801,9 +779,7 @@ Example:
torch.Size([10, 9])
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
""".format(
**common_args
),
""".format(**common_args),
)
hfft = _add_docstr(
@ -894,9 +870,7 @@ Example:
>>> torch.fft.hfft(T[:3])
tensor([0.1250, 0.2809, 0.6250, 0.9691])
""".format(
**common_args
),
""".format(**common_args),
)
ihfft = _add_docstr(
@ -951,9 +925,7 @@ Example:
>>> torch.fft.ifft(t)
tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j,
-0.5000+0.6882j])
""".format(
**common_args
),
""".format(**common_args),
)
hfft2 = _add_docstr(
@ -1025,9 +997,7 @@ Example:
>>> torch.allclose(roundtrip, T)
True
""".format(
**common_args
),
""".format(**common_args),
)
ihfft2 = _add_docstr(
@ -1092,9 +1062,7 @@ Example:
>>> torch.allclose(t, two_ffts)
True
""".format(
**common_args
),
""".format(**common_args),
)
hfftn = _add_docstr(
@ -1187,9 +1155,7 @@ Example:
>>> torch.allclose(roundtrip, T)
True
""".format(
**common_args
),
""".format(**common_args),
)
ihfftn = _add_docstr(
@ -1259,9 +1225,7 @@ Example:
>>> torch.allclose(ihfftn, two_iffts)
True
""".format(
**common_args
),
""".format(**common_args),
)
fftfreq = _add_docstr(
@ -1310,9 +1274,7 @@ Example:
>>> torch.fft.fftfreq(4)
tensor([ 0.0000, 0.2500, -0.5000, -0.2500])
""".format(
**factory_common_args
),
""".format(**factory_common_args),
)
rfftfreq = _add_docstr(
@ -1361,9 +1323,7 @@ Example:
>>> torch.fft.fftfreq(4)
tensor([ 0.0000, 0.2500, -0.5000, -0.2500])
""".format(
**factory_common_args
),
""".format(**factory_common_args),
)
fftshift = _add_docstr(

View File

@ -271,9 +271,9 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
...
ValueError: foo
"""
assert isinstance(
result, Exception
), f"{result} is of type {type(result)}, not an Exception."
assert isinstance(result, Exception), (
f"{result} is of type {type(result)}, not an Exception."
)
def raise_error(fut_result):
raise fut_result

View File

@ -253,9 +253,9 @@ class _TensorPickleData:
for k in MetaTensorDesc._UNSERIALIZABLE:
if k in ("fake_mode", "view_func"):
continue
assert (
getattr(self.metadata, k) is None
), f"not None: {k}: {getattr(self.metadata, k)}"
assert getattr(self.metadata, k) is None, (
f"not None: {k}: {getattr(self.metadata, k)}"
)
def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor:
# TODO: make common w/ _output_from_cache_entry() in fake_tensor.py?

View File

@ -755,9 +755,9 @@ class Tracer(TracerBase):
self.root = root
assert hasattr(
type(root), self.traced_func_name
), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
assert hasattr(type(root), self.traced_func_name), (
f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
)
fn = getattr(type(root), self.traced_func_name)
self.root_module_name = root._get_name()
@ -1164,9 +1164,9 @@ def _maybe_revert_all_patches():
finally:
if current_patcher is not None:
patches_made = current_patcher.reapply_all_patches()
assert (
patches_made == patches_removed
), "CURRENT_PATCHER was changed during a revert_all_patches"
assert patches_made == patches_removed, (
"CURRENT_PATCHER was changed during a revert_all_patches"
)
def _patch_wrapped_functions(patcher: _Patcher):
@ -1248,9 +1248,9 @@ def wrap(fn_or_name: Union[str, Callable]):
assert not isinstance(fn_or_name, str) # to make mypy happy
fn_name = fn_or_name.__name__
else:
assert isinstance(
fn_or_name, str
), "fn_or_name must be a global function or string name"
assert isinstance(fn_or_name, str), (
"fn_or_name must be a global function or string name"
)
fn_name = fn_or_name
currentframe = inspect.currentframe()
@ -1308,7 +1308,9 @@ def symbolic_trace(
return out
f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
f = fx.symbolic_trace(
f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}
)
assert f({"a": 1, "b": 2, "c": 4}) == 7

View File

@ -450,9 +450,9 @@ class Partitioner:
device = find_device_based_on_size(node)
occupied_devices.append(device)
# Update partition and its left mem size
partition_to_left_mem_bytes[
partition
] = device.available_mem_bytes
partition_to_left_mem_bytes[partition] = (
device.available_mem_bytes
)
# Update available mem for the current partition
partition.logical_device_ids.append(device.logical_id)
else:
@ -475,9 +475,9 @@ class Partitioner:
total_size_of_input_nodes = get_extra_size_of(
node, partition.nodes
)
partition_to_left_mem_bytes[
partition
] = device.available_mem_bytes
partition_to_left_mem_bytes[partition] = (
device.available_mem_bytes
)
partition.logical_device_ids.append(device.logical_id)
partition.add_node(node)
partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes
@ -509,9 +509,9 @@ class Partitioner:
no_device_partitions,
) = get_device_partition_stats(self.partitions, self.devices)
assert (
len(no_device_partitions) == 0
), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}"
assert len(no_device_partitions) == 0, (
f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}"
)
# Devices that hold partitions
used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]

View File

@ -368,12 +368,12 @@ def optimize_for_inference(
supports_mkldnn = MklSupport.YES
sample_parameter = next(cur_module.parameters(), None)
if sample_parameter is not None:
assert (
sample_parameter.dtype == torch.float
), "this pass is only for torch.float modules"
assert sample_parameter.device == torch.device(
"cpu"
), "this pass is only for CPU modules"
assert sample_parameter.dtype == torch.float, (
"this pass is only for torch.float modules"
)
assert sample_parameter.device == torch.device("cpu"), (
"this pass is only for CPU modules"
)
elif node.op == "call_function":
if node.target in mkldnn_supported:
supports_mkldnn = MklSupport.YES

View File

@ -182,22 +182,19 @@ def is_sym_node(node: _HasMeta) -> bool:
@overload
def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None:
...
def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ...
@overload
def set_proxy_slot(
obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy
) -> None:
...
) -> None: ...
@overload
def set_proxy_slot(
obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType
) -> None:
...
) -> None: ...
def set_proxy_slot(
@ -256,8 +253,7 @@ _PySymProxyType = Thunk[Proxy]
def get_proxy_slot(
obj: Tensor,
tracer: _ProxyTracer,
) -> _ProxyTensor:
...
) -> _ProxyTensor: ...
@overload
@ -265,8 +261,7 @@ def get_proxy_slot(
obj: Tensor,
tracer: _ProxyTracer,
default: U,
) -> Union[_ProxyTensor, U]:
...
) -> Union[_ProxyTensor, U]: ...
@overload
@ -275,16 +270,14 @@ def get_proxy_slot(
tracer: _ProxyTracer,
default: U,
transform: Callable[[_ProxyTensor], R],
) -> Union[R, U]:
...
) -> Union[R, U]: ...
@overload
def get_proxy_slot(
obj: _AnyScriptObjectType,
tracer: _ProxyTracer,
) -> Proxy:
...
) -> Proxy: ...
@overload
@ -292,8 +285,7 @@ def get_proxy_slot(
obj: _AnyScriptObjectType,
tracer: _ProxyTracer,
default: U,
) -> Union[Proxy, U]:
...
) -> Union[Proxy, U]: ...
@overload
@ -302,16 +294,14 @@ def get_proxy_slot(
tracer: _ProxyTracer,
default: U,
transform: Callable[[Proxy], R],
) -> Union[R, U]:
...
) -> Union[R, U]: ...
@overload
def get_proxy_slot(
obj: PySymType,
tracer: _ProxyTracer,
) -> _PySymProxyType:
...
) -> _PySymProxyType: ...
@overload
@ -319,8 +309,7 @@ def get_proxy_slot(
obj: PySymType,
tracer: _ProxyTracer,
default: T,
) -> Union[T, _PySymProxyType]:
...
) -> Union[T, _PySymProxyType]: ...
@overload
@ -329,8 +318,7 @@ def get_proxy_slot(
tracer: _ProxyTracer,
default: U,
transform: Callable[[_PySymProxyType], R],
) -> Union[R, U]:
...
) -> Union[R, U]: ...
# the default argument is what to return if the slot is not set.
@ -717,22 +705,21 @@ def fetch_sym_proxy(
@overload
def fetch_object_proxy(tracer: _ProxyTracer, t: Tensor) -> Union[_ProxyTensor, Tensor]:
...
def fetch_object_proxy(
tracer: _ProxyTracer, t: Tensor
) -> Union[_ProxyTensor, Tensor]: ...
@overload
def fetch_object_proxy(
tracer: _ProxyTracer, t: _AnyScriptObjectType
) -> Union[Proxy, _AnyScriptObjectType]:
...
) -> Union[Proxy, _AnyScriptObjectType]: ...
@overload
def fetch_object_proxy(
tracer: _ProxyTracer, t: PySymType
) -> Union[_PySymProxyType, PySymType]:
...
) -> Union[_PySymProxyType, PySymType]: ...
def fetch_object_proxy(
@ -815,7 +802,10 @@ def proxy_call(
if func is torch.ops.aten.is_nonzero.default:
with proxy_mode:
torch._check(args[0].numel() == 1, lambda: "Boolean value of Tensor with more than one value is ambiguous") # type: ignore[attr-defined]
torch._check(
args[0].numel() == 1, # type: ignore[attr-defined]
lambda: "Boolean value of Tensor with more than one value is ambiguous",
)
return (args[0] != 0).item() # type: ignore[attr-defined]
tracer = proxy_mode.tracer
@ -1079,18 +1069,15 @@ class PythonKeyTracer(Tracer):
return super().create_arg(a) # type: ignore[return-value]
@overload
def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]:
...
def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: ...
@overload
def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]:
...
def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: ...
@overload
def unwrap_proxy(
self, e: _AnyScriptObjectType
) -> Union[Proxy, _AnyScriptObjectType]:
...
) -> Union[Proxy, _AnyScriptObjectType]: ...
def unwrap_proxy(self, e: T) -> object:
if isinstance(e, Tensor):
@ -1608,7 +1595,10 @@ class DecompositionInterpreter(fx.Interpreter):
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
def placeholder(
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[object, ...],
kwargs: dict[str, object],
) -> object:
out = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
@ -1617,7 +1607,10 @@ class DecompositionInterpreter(fx.Interpreter):
return out
def get_attr(
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[object, ...],
kwargs: dict[str, object],
) -> object:
out = super().get_attr(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
@ -1627,7 +1620,10 @@ class DecompositionInterpreter(fx.Interpreter):
# call_function, call_method, call_module get traced automatically by the outer mode.
def output(
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[object, ...],
kwargs: dict[str, object],
) -> object:
out = super().output(target, args, kwargs) # type: ignore[arg-type]
@ -1989,14 +1985,14 @@ class _MakefxTracer:
# adding new modes in _MakefxTracer.
self.fake_tensor_mode: Optional[FakeTensorMode] = None
self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext()
self.proxy_function_mode: Union[
nullcontext, PreDispatchTorchFunctionMode
] = nullcontext()
self.proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode] = (
nullcontext()
)
self.fx_tracer: Optional[PythonKeyTracer] = None
self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext()
self.torch_fn_metadata_mode: Union[
nullcontext, TorchFunctionMetadataMode
] = nullcontext()
self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = (
nullcontext()
)
self.stack_trace = stack_trace
def _checkpoint_modes(self) -> list[Any]:
@ -2071,9 +2067,9 @@ class _MakefxTracer:
allow_non_fake_inputs=self._allow_non_fake_inputs,
shape_env=shape_env,
)
assert (
fake_tensor_mode.shape_env is not None
), "shape_env should be set if tracing with 'symbolic'"
assert fake_tensor_mode.shape_env is not None, (
"shape_env should be set if tracing with 'symbolic'"
)
self.fake_tensor_mode = fake_tensor_mode
else:
if not self.tracing_mode == "real":
@ -2161,9 +2157,9 @@ class _MakefxTracer:
return self.fake_tensor_mode.from_tensor(x, source=source)
# NB: don't match on bools
elif type(x) is int and self.tracing_mode == "symbolic":
assert (
self.fake_tensor_mode.shape_env is not None
), "shape_env should be set if tracing with 'symbolic'"
assert self.fake_tensor_mode.shape_env is not None, (
"shape_env should be set if tracing with 'symbolic'"
)
return self.fake_tensor_mode.shape_env.create_symintnode(
self.fake_tensor_mode.shape_env.create_symbol(
x, source, positive=None
@ -2176,9 +2172,9 @@ class _MakefxTracer:
self.fake_tensor_mode, x
)
assert not isinstance(
x, FakeScriptObject
), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
assert not isinstance(x, FakeScriptObject), (
f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
)
return x
wrap_fn_map = {
@ -2344,9 +2340,9 @@ def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
torch._C._TorchDispatchModeKey.PROXY
)
mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
assert (
pre_dispatch_mode is None or mode is None
), f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
assert pre_dispatch_mode is None or mode is None, (
f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
)
return pre_dispatch_mode or mode

View File

@ -460,7 +460,7 @@ def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value)
# Here, we allow the value of each field to be mapped, so that we appropriately
# compare the two values.
def compare_vars(
map_value: Callable[[str, Any], Any]
map_value: Callable[[str, Any], Any],
) -> list[tuple[str, str, str]]:
env1_set, env2_set = set(env1_vars), set(env2_vars)

View File

@ -103,7 +103,7 @@ class AnnotateTypesWithSchema(Transformer):
for i, atom in enumerate(atoms):
if not hasattr(module_itr, atom):
raise RuntimeError(
f'Node referenced nonextent target {".".join(atoms[:i])}!'
f"Node referenced nonextent target {'.'.join(atoms[:i])}!"
)
module_itr = getattr(module_itr, atom)

View File

@ -149,9 +149,9 @@ class SymNode:
# This is technically not TV, but this assert is expensive so
# let's only do it when we're already doing expensive things
computed_hint = compute_hint()
assert (
hint == computed_hint
), f"{hint} != {computed_hint} (for {self.expr})"
assert hint == computed_hint, (
f"{hint} != {computed_hint} (for {self.expr})"
)
else:
hint = compute_hint()
self._hint = hint
@ -460,7 +460,9 @@ class SymNode:
return self.float_pow(other)
def is_non_overlapping_and_dense(self, sizes, strides):
return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined]
return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(
to_node(self, 1)
) # type: ignore[attr-defined]
def int_(self):
return self.guard_int("", 0) # NB: uses Python backtrace

View File

@ -182,7 +182,9 @@ CURRENT_NODE_KEY = "current_node"
def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None:
log.debug(
"lru_cache_stats %s: %s", wrapped_f.__name__, wrapped_f.cumulative_cache_info() # type: ignore[attr-defined]
"lru_cache_stats %s: %s",
wrapped_f.__name__, # type: ignore[attr-defined]
wrapped_f.cumulative_cache_info(), # type: ignore[attr-defined]
)
@ -497,9 +499,9 @@ def check_consistent(new: _T, old: _T) -> None:
torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)")
# NB: bool is subclass of int
elif isinstance(new, scalar_types) and not isinstance(new, bool):
assert isinstance(old, scalar_types) and not isinstance(
old, bool
), f"{old} != {new}"
assert isinstance(old, scalar_types) and not isinstance(old, bool), (
f"{old} != {new}"
)
torch._check(old == new, lambda: f"{old} != {new} (old != new)")
@ -629,9 +631,9 @@ def rebind_unbacked(
raw_u1 = new_raw_u1
if not isinstance(raw_u1, sympy.Symbol):
assert (
not raw_u1.free_symbols
), f"should have been constant, but got {raw_u1}"
assert not raw_u1.free_symbols, (
f"should have been constant, but got {raw_u1}"
)
continue
# The old and new could be the same if you improperly hit the memo
@ -1975,12 +1977,12 @@ class EqualityConstraint(Constraint):
def _assert_symbol_context(symbolic_context: object) -> TypeGuard[SymbolicContext]:
assert isinstance(
symbolic_context, SymbolicContext
), "Invalid symbolic_context object"
assert (
type(symbolic_context) is not SymbolicContext
), "Illegal usage of symbolic_context ABC"
assert isinstance(symbolic_context, SymbolicContext), (
"Invalid symbolic_context object"
)
assert type(symbolic_context) is not SymbolicContext, (
"Illegal usage of symbolic_context ABC"
)
return True
@ -2519,9 +2521,9 @@ def _lru_cache(
prior_version = self._version_counter
prior_key = self._get_key()
else:
assert (
prior_key == self._get_key()
), "ShapeEnv cache key changed without version being updated!"
assert prior_key == self._get_key(), (
"ShapeEnv cache key changed without version being updated!"
)
return fn_cache(self, *args, **kwargs)
@ -2772,9 +2774,9 @@ class DynamicDimConstraintPrinter(PythonPrinter):
def _print_Symbol(self, expr: sympy.Symbol) -> str:
assert isinstance(expr, sympy.Symbol), str(type(expr))
assert self.symbol_to_source.get(
expr
), f"Unknown symbol {expr} created by constraints solver"
assert self.symbol_to_source.get(expr), (
f"Unknown symbol {expr} created by constraints solver"
)
return self.symbol_to_source[expr][0].name()
@ -2792,9 +2794,9 @@ class DimConstraints:
source_name_to_debug_name: Mapping[str, str],
) -> None:
# We try to solve systems of inequalities with 1 free variable.
self._univariate_inequalities: dict[
sympy.Symbol, set[SympyBoolean]
] = defaultdict(set)
self._univariate_inequalities: dict[sympy.Symbol, set[SympyBoolean]] = (
defaultdict(set)
)
# Among them, we prioritize solving for a free variable that has equalities.
# NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
# and removing a symbol from the former => removing it from the latter.
@ -2877,9 +2879,10 @@ class DimConstraints:
# With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we
# would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution!
base, divisor = args
base, divisor = self.rewrite_with_congruences(
s, base
), self.rewrite_with_congruences(s, divisor)
base, divisor = (
self.rewrite_with_congruences(s, base),
self.rewrite_with_congruences(s, divisor),
)
mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(
self._var_to_val
)
@ -2896,9 +2899,10 @@ class DimConstraints:
# NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d
# and eliminating b % d as above.
base, divisor = args
base, divisor = self.rewrite_with_congruences(
s, base
), self.rewrite_with_congruences(s, divisor)
base, divisor = (
self.rewrite_with_congruences(s, base),
self.rewrite_with_congruences(s, divisor),
)
mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(
self._var_to_val
)
@ -3060,9 +3064,9 @@ class DimConstraints:
(arg for arg in solution.args if isinstance(arg, sympy.Eq)),
solution,
)
assert isinstance(
solution, sympy.Eq
), f"Expected an equality constraint for {s}, got {solution}"
assert isinstance(solution, sympy.Eq), (
f"Expected an equality constraint for {s}, got {solution}"
)
symbol, val = solution.args
assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
# because this is univariate, the solution is a specialization
@ -3340,7 +3344,8 @@ class DimConstraints:
"max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type, index]
}
if not _check_same_range(
result, name_to_dim[mroot] # type: ignore[index, arg-type]
result,
name_to_dim[mroot], # type: ignore[index, arg-type]
): # ignore if unchanged
modified_root_values[mroot] = result # type: ignore[index]
break
@ -4124,9 +4129,9 @@ class ShapeEnv:
if not isinstance(b, SymInt):
assert a == b
else:
assert isinstance(
b.node.expr, sympy.Symbol
), "constraining non-Symbols NYI"
assert isinstance(b.node.expr, sympy.Symbol), (
"constraining non-Symbols NYI"
)
assert b.node.shape_env is self
self.replacements[b.node.expr] = sympy.Integer(a)
else:
@ -4139,9 +4144,9 @@ class ShapeEnv:
self.replacements[a.node.expr] = sympy.Integer(b)
else:
assert a.node.shape_env is b.node.shape_env
assert isinstance(
b.node.expr, sympy.Symbol
), "constraining non-Symbols NYI"
assert isinstance(b.node.expr, sympy.Symbol), (
"constraining non-Symbols NYI"
)
new_var = self._find(a.node.expr)
self.replacements[b.node.expr] = new_var
@ -4234,9 +4239,9 @@ class ShapeEnv:
# If translation validation is enabled, all arguments must have its
# own FX node.
assert all(
a is not None for a in args
), f"missing arg in FX graph ({op.__name__}): {args}"
assert all(a is not None for a in args), (
f"missing arg in FX graph ({op.__name__}): {args}"
)
node = self.fx_node_cache[node_key] = self.graph.call_function(op, args)
self.name_to_node[node.name] = node
@ -4354,9 +4359,9 @@ class ShapeEnv:
source: Source,
symbolic_context: SymbolicContext,
) -> list[sympy.Expr]:
assert all(
not is_symbolic(val) for val in tensor_size
), f"Expect size to be a plain tuple of ints but got {tensor_size}"
assert all(not is_symbolic(val) for val in tensor_size), (
f"Expect size to be a plain tuple of ints but got {tensor_size}"
)
from torch._dynamo.source import TensorProperty, TensorPropertySource
_assert_symbol_context(symbolic_context)
@ -4398,7 +4403,11 @@ class ShapeEnv:
source: Source,
*,
symbolic_context: Optional[SymbolicContext] = None,
) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]:
) -> tuple[
tuple[IntLikeType, ...],
tuple[IntLikeType, ...],
IntLikeType,
]:
"""
Returns a list of symbolic sizes and strides for the given tensor.
We try our best to express stride in terms of the sizes, so as to not
@ -4463,9 +4472,9 @@ class ShapeEnv:
) -> IntLikeType:
assert isinstance(maybe_sym, (int, torch.SymInt))
if is_symbolic(maybe_sym):
assert (
maybe_sym.node.shape_env is not self
), "expect the symbol is created from an shape env other than current one."
assert maybe_sym.node.shape_env is not self, (
"expect the symbol is created from an shape env other than current one."
)
return maybe_sym.node.require_hint()
return maybe_sym
@ -4481,7 +4490,11 @@ class ShapeEnv:
source: Source,
*,
symbolic_context: Optional[SymbolicContext] = None,
) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]:
) -> tuple[
tuple[IntLikeType, ...],
tuple[IntLikeType, ...],
IntLikeType,
]:
dim = len(ex_size)
# Reimplement the legacy behavior
@ -5045,9 +5058,9 @@ class ShapeEnv:
sloc,
)
else:
self.var_to_range[
sympy_expr
] = self._default_unspecified_value_range()
self.var_to_range[sympy_expr] = (
self._default_unspecified_value_range()
)
self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc)
# Small performance optimization: if we have a min-max constraint,
@ -5238,9 +5251,9 @@ class ShapeEnv:
shape_env = replay_shape_env_events(self.events)
self.check_equal(shape_env)
assert len(placeholders) == len(
sources
), f"len({placeholders}) != len({sources})"
assert len(placeholders) == len(sources), (
f"len({placeholders}) != len({sources})"
)
Tensorlike = (torch.Tensor, FakeTensorMeta)
def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext:
@ -5336,9 +5349,9 @@ class ShapeEnv:
symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict(
list
)
symbol_to_constraints: defaultdict[
sympy.Symbol, set[Constraint]
] = collections.defaultdict(set)
symbol_to_constraints: defaultdict[sympy.Symbol, set[Constraint]] = (
collections.defaultdict(set)
)
constraint_violations: list[tuple[bool, str, Callable[[], str]]] = []
printers: list[_ShapeGuardPrinter] = []
@ -6528,7 +6541,7 @@ class ShapeEnv:
f"Caused by: {sloc}\n"
'For more information, run with TORCH_LOGS="dynamic"\n'
"For extended logs when we create symbols, also add "
f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n"
f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{",".join(map(str, expr.free_symbols))}"\n'
"If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n"
"For more debugging help, see "
"https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n"
@ -6662,9 +6675,9 @@ class ShapeEnv:
)
self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a])
tgt_bound = self.bound_sympy(tgt)
assert tgt_bound.issubset(
src_bound
), f"{tgt_bound=} not a subset of {src_bound=}"
assert tgt_bound.issubset(src_bound), (
f"{tgt_bound=} not a subset of {src_bound=}"
)
# TODO: Should we propagate size-like-ness?
#
@ -6751,9 +6764,9 @@ class ShapeEnv:
for source in self.var_to_sources.get(a, []):
if user_tb:
self.user_specialization_stacks[source] = user_tb
self.framework_specialization_stacks[
source
] = CapturedTraceback.extract(cpp=True)
self.framework_specialization_stacks[source] = (
CapturedTraceback.extract(cpp=True)
)
if config.print_specializations:
self.log.warning(
@ -6820,9 +6833,9 @@ class ShapeEnv:
free = list(expr.free_symbols)
assert (
len(free) > 0
), f"The expression should not be static by this point: {expr}"
assert len(free) > 0, (
f"The expression should not be static by this point: {expr}"
)
# In case of really gnarly expression, we don't blow up
if len(free) > 5:
return

View File

@ -203,9 +203,7 @@ try:
return _Z3Ops.to_real(result) if cast_result_to_real else result
def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
return z3.If(
self.floor(number) < number, self.floor(number + 1), number
) # type: ignore[return-value]
return z3.If(self.floor(number) < number, self.floor(number + 1), number) # type: ignore[return-value]
def trunc(self, number: z3.ArithRef) -> z3.ArithRef:
return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value]
@ -363,9 +361,9 @@ try:
return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type]
# Adds the Z3 expression corresponding to the first argument
# as a validator input.
assert (
len(args) == 1
), f"expected 1 argument on assertion. Got: {len(args)} "
assert len(args) == 1, (
f"expected 1 argument on assertion. Got: {len(args)} "
)
self.validator.add_source_expr(args[0]) # type: ignore[arg-type]
# Translates SymPy expressions into Z3 expressions.
@ -536,9 +534,9 @@ try:
def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
z3expr = SympyToZ3(self).run(e)
assert isinstance(
z3expr, z3.BoolRef
), f"expected boolean expression. Got: {z3expr}"
assert isinstance(z3expr, z3.BoolRef), (
f"expected boolean expression. Got: {z3expr}"
)
return z3expr
def add_source_expr(self, e: z3.BoolRef) -> None:

View File

@ -449,7 +449,7 @@ class CodeGen:
# This code-path used in Python < 3.9
return origin_typename
return f'{origin_typename}[{",".join(args)}]'
return f"{origin_typename}[{','.join(args)}]"
else:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python 3.9+
@ -573,7 +573,7 @@ class CodeGen:
summary_str = parsed_stack_trace.get_summary_str()
else:
summary_str = ""
body.append(f'\n {dim(f"# {summary_str}")}\n')
body.append(f"\n {dim(f'# {summary_str}')}\n")
elif prev_stacktrace != "":
prev_stacktrace = ""
no_stacktrace_msg = "# No stacktrace found for following nodes"
@ -842,7 +842,7 @@ class _PyTreeCodeGen(CodeGen):
if len(has_annotation) > 0:
fn_definition += "\n " + "".join(has_annotation) + "\n"
fn_definition += f"""
{', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
{", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
return fn_definition
def generate_output(self, output_args):
@ -1877,7 +1877,9 @@ class Graph:
# through `insert_pdb`:
gm.graph.on_generate_code(
lambda current_trans: (
lambda body: insert_pdb(current_trans(body) if current_trans else body)
lambda body: insert_pdb(
current_trans(body) if current_trans else body
)
)
)
@ -1916,7 +1918,7 @@ class Graph:
@contextmanager
def _override_sym_repr(
override: Callable[["torch.types.PySymType"], str]
override: Callable[["torch.types.PySymType"], str],
) -> Iterator[None]:
tmp = CodeGen._sym_repr
try:

View File

@ -324,9 +324,9 @@ def _print_readable(
colored=False,
):
graph = module.graph
assert graph is not None and isinstance(
graph, torch.fx.Graph
), "print_readable must be used on a module with a graph"
assert graph is not None and isinstance(graph, torch.fx.Graph), (
"print_readable must be used on a module with a graph"
)
verbose_python_code = graph.python_code(
root_module="self",
@ -873,9 +873,9 @@ class {module_name}(torch.nn.Module):
for node in self.graph.nodes
if "stack_trace" in node.meta
}
dict_without_graph[
"_graphmodule_graph_node_meta_stack_trace"
] = node_meta_stack_trace
dict_without_graph["_graphmodule_graph_node_meta_stack_trace"] = (
node_meta_stack_trace
)
generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
python_code = self.recompile()

View File

@ -51,7 +51,9 @@ class Interpreter:
method equivalents). We could subclass Interpreter like so::
class NegSigmSwapInterpreter(Interpreter):
def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
def call_function(
self, target: Target, args: Tuple, kwargs: Dict
) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(target, args, kwargs)
@ -405,7 +407,7 @@ class Interpreter:
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}"
f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr
@ -468,14 +470,20 @@ class Transformer(Interpreter):
class NegSigmSwapXformer(Transformer):
def call_function(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self,
target: "Target",
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(target, args, kwargs)
def call_method(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self,
target: "Target",
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
) -> Any:
if target == "neg":
call_self, *args_tail = args

View File

@ -514,9 +514,9 @@ class Node(_NodeBase):
idx (int): The index of the element in ``self.args`` to be inserted before.
arg (Argument): The new argument value to insert into ``args``
"""
assert (
0 <= idx <= len(self.args)
), "insert_args index must be between 0 and len(self.args)"
assert 0 <= idx <= len(self.args), (
"insert_args index must be between 0 and len(self.args)"
)
args_left = self.args[:idx]
args_right = self.args[idx:]
@ -747,13 +747,13 @@ class Node(_NodeBase):
# Check if an impure module.
if self.op == "call_module":
assert (
self.graph.owning_module is not None
), "self.graph.owning_module not set for purity check"
assert self.graph.owning_module is not None, (
"self.graph.owning_module not set for purity check"
)
target_mod = self.graph.owning_module.get_submodule(self.target)
assert (
target_mod is not None
), f"Did not find expected submodule target {self.target}"
assert target_mod is not None, (
f"Did not find expected submodule target {self.target}"
)
return getattr(target_mod, "_is_impure", False)
return False

View File

@ -770,9 +770,9 @@ class _MinimizerBase:
node_name = node.name
if node_name is not None and isinstance(node_name, tuple):
node_name = node_name[0]
assert node_name is not None and isinstance(
node_name, str
), f"minimize: node_name: {node_name}"
assert node_name is not None and isinstance(node_name, str), (
f"minimize: node_name: {node_name}"
)
report.append(f"Add node: {node_name}")

View File

@ -93,9 +93,9 @@ def loop_pass(
predicate (Callable[Object, bool], optional):
"""
assert (n_iter is not None) ^ (
predicate is not None
), "Exactly one of `n_iter`or `predicate` must be specified."
assert (n_iter is not None) ^ (predicate is not None), (
"Exactly one of `n_iter`or `predicate` must be specified."
)
@wraps(base_pass)
def new_pass(source):

View File

@ -397,7 +397,9 @@ def insert_deferred_runtime_asserts(
nn_module_stack=node.meta.get("nn_module_stack"),
),
):
expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type]
expr_to_proxy[sym_expr] = _sympy_interp(
expr_to_proxy, sym_expr
) # type: ignore[arg-type]
# won't try DCE-ing tensor compute here
hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type]
node.replace_all_uses_with(hash_node)

View File

@ -199,9 +199,9 @@ def split_by_tags(
mx = max((c.order for c in upstream_components), default=0)
# Expect the component for `node` has higher order then its upstream components.
assert (
comp.order >= mx
), f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}"
assert comp.order >= mx, (
f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}"
)
# Map a input of `node` to nodes in the component's graph.
def remap_func(x):

View File

@ -36,9 +36,9 @@ def topo_sort(nodes: NodeList) -> NodeList:
if indegree_map[n] == 0:
candidates.put(n)
assert len(nodes) == len(
sorted_nodes
), "topological sorted nodes doesn't have same length as input nodes"
assert len(nodes) == len(sorted_nodes), (
"topological sorted nodes doesn't have same length as input nodes"
)
return sorted_nodes
@ -127,13 +127,13 @@ def fuse_as_graphmodule(
# assumption: nodes are already sorted in topo order
for node in nodes:
assert (
node.graph.owning_module is gm
), f"{node} doesn't belong to passed in graph module {gm._get_name()}"
assert node.graph.owning_module is gm, (
f"{node} doesn't belong to passed in graph module {gm._get_name()}"
)
assert not node._erased, f"{node} has been removed from owning graph"
assert (
node in gm.graph._find_nodes_lookup_table
), f"{node} is not found in graph module {gm._get_name()}"
assert node in gm.graph._find_nodes_lookup_table, (
f"{node} is not found in graph module {gm._get_name()}"
)
# validates partition doesn't introduce dependency circles in the graph
assert validate_partition(nodes), "Invalid partition, found dependency cycles"

View File

@ -96,9 +96,9 @@ class SubgraphMatcher:
for node in pattern.nodes:
if node.op != "output":
assert (
len(node.users) > 0
), "SubgraphMatcher cannot be initialized with an pattern with dead code"
assert len(node.users) > 0, (
"SubgraphMatcher cannot be initialized with an pattern with dead code"
)
# TODO: assert pattern is a connected graph
@ -192,9 +192,9 @@ class SubgraphMatcher:
return non_overlapping_matches
def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
assert not (
isinstance(pn, Node) and isinstance(gn, Node)
), "pn and gn cannot both be Node"
assert not (isinstance(pn, Node) and isinstance(gn, Node)), (
"pn and gn cannot both be Node"
)
if isinstance(pn, Node) and not isinstance(gn, Node):
if pn.op == "placeholder":

View File

@ -18,17 +18,17 @@ def _split_to_graph_and_name_node_map(
if n.op == "output":
assert gm._out_spec is not None
output = tree_unflatten(n.args[0], gm._out_spec)
assert isinstance(
output, tuple
), "Expecting the pattern graph to return a tuple"
assert (
len(output) >= 2
), "Expecting the pattern graph to have at least two outputs"
assert isinstance(output, tuple), (
"Expecting the pattern graph to return a tuple"
)
assert len(output) >= 2, (
"Expecting the pattern graph to have at least two outputs"
)
*out, name_node_map = output
flattened, out_spec = tree_flatten(out)
assert isinstance(
name_node_map, dict
), "Expecting the input graph to have a dict output as the last element"
assert isinstance(name_node_map, dict), (
"Expecting the input graph to have a dict output as the last element"
)
n.args = (flattened,)
orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined]
gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined]
@ -53,12 +53,14 @@ class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
relu = F.relu(conv)
return relu, {"conv": conv, "relu": relu}
def target_graph(x, weight):
conv = F.conv2d(x, weight)
relu = F.relu(conv)
relu *= 2
return relu
pattern_gm = export_for_training(pattern, example_inputs).module()
target_gm = export_for_training(target_graph, example_inputs).module()
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)

View File

@ -654,9 +654,9 @@ class MetaProxy(Proxy):
meta_proxy = arg
break
assert (
meta_proxy is not None
), "No MetaProxy found in arguments, but one is expected."
assert meta_proxy is not None, (
"No MetaProxy found in arguments, but one is expected."
)
proxy = super().__torch_function__(orig_method, types, args, kwargs)
with meta_proxy.fake_mode:
@ -739,14 +739,14 @@ for method in magic_methods:
return tracer.create_proxy("call_function", target, args, kwargs)
impl.__name__ = method
as_magic = f'__{method.strip("_")}__'
as_magic = f"__{method.strip('_')}__"
setattr(Proxy, as_magic, impl)
_scope(method)
def _define_reflectable(orig_method_name):
method_name = f'__r{orig_method_name.strip("_")}__'
method_name = f"__r{orig_method_name.strip('_')}__"
def impl(self, rhs):
target = getattr(operator, orig_method_name)

View File

@ -307,9 +307,9 @@ def _replace_pattern(
elif callable(replacement):
common_replacement_graph = symbolic_trace(replacement).graph
else:
assert (
replacement_callback is not None
), "Must provide either a replacement GraphModule or a replacement callback"
assert replacement_callback is not None, (
"Must provide either a replacement GraphModule or a replacement callback"
)
common_replacement_graph = None
# As we progressively replace nodes, we'll need to keep track of how the match results should change
@ -322,9 +322,9 @@ def _replace_pattern(
match, original_graph, pattern_graph
)
else:
assert (
common_replacement_graph is not None
), "Must provide either a replacement GraphModule or a replacement callback"
assert common_replacement_graph is not None, (
"Must provide either a replacement GraphModule or a replacement callback"
)
replacement_graph = common_replacement_graph
replacement_placeholders = [
n for n in replacement_graph.nodes if n.op == "placeholder"

View File

@ -18,7 +18,15 @@ from torch.nn.modules.utils import (
_builtin_table: Optional[dict[int, str]] = None
_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special) # type: ignore[attr-defined] # noqa: B950
_modules_containing_builtins = (
torch,
torch._C._nn,
torch._C._fft, # type: ignore[attr-defined]
torch._C._linalg, # type: ignore[attr-defined]
torch._C._nested, # type: ignore[attr-defined]
torch._C._sparse, # type: ignore[attr-defined]
torch._C._special, # type: ignore[attr-defined]
)
_builtin_ops = [
# Pairs of (function, op_name)
@ -94,7 +102,10 @@ _builtin_ops = [
(torch.autograd.grad, "aten::grad"),
(torch.autograd.backward, "aten::backward"),
(torch._C._infer_size, "aten::_infer_size"),
(torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"), # type: ignore[attr-defined]
(
torch.nn.functional._no_grad_embedding_renorm_, # type: ignore[attr-defined]
"aten::_no_grad_embedding_renorm_",
),
(torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"),
(torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"),
(torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),

View File

@ -4,9 +4,9 @@ from torch._ops import OpOverload, OpOverloadPacket
def _register_decomposition(op: OpOverload, graph: torch._C.Graph):
assert not isinstance(
op, OpOverloadPacket
), f"Must pass specific op overload, not overload packet, found {op}"
assert not isinstance(op, OpOverloadPacket), (
f"Must pass specific op overload, not overload packet, found {op}"
)
assert isinstance(op, OpOverload)
torch._C._jit_register_decomposition_for_schema(op._schema, graph)

View File

@ -23,13 +23,13 @@ def check_decomposition_has_type_annotations(f):
inspect_empty = inspect._empty # type: ignore[attr-defined]
sig = inspect.signature(f)
for param in sig.parameters.values():
assert (
param.annotation != inspect_empty
), f"No signature on param {param.name} for function {f.name}"
assert param.annotation != inspect_empty, (
f"No signature on param {param.name} for function {f.name}"
)
assert (
sig.return_annotation != inspect_empty
), f"No return annotation for function {f.name}"
assert sig.return_annotation != inspect_empty, (
f"No return annotation for function {f.name}"
)
def signatures_match(decomposition_sig, torch_op_sig):
@ -75,9 +75,9 @@ def register_decomposition(
assert isinstance(aten_op, torch._ops.OpOverload)
# Need unique name for jit function serialization
assert (
f.__name__ not in function_name_set
), f"Duplicated function name {f.__name__}"
assert f.__name__ not in function_name_set, (
f"Duplicated function name {f.__name__}"
)
function_name_set.add(f.__name__)
scripted_func = torch.jit.script(f)

View File

@ -588,9 +588,9 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
# recursively scripting them.
for name, sub_concrete_type in concrete_type.get_modules():
orig_value = getattr(nn_module, name)
assert isinstance(
orig_value, Module
), f"Expected Module but got {type(orig_value)}"
assert isinstance(orig_value, Module), (
f"Expected Module but got {type(orig_value)}"
)
module_type = sub_concrete_type.jit_type
if isinstance(module_type, torch._C.InterfaceType):
# use the interface inference rule to compile the module

View File

@ -318,10 +318,10 @@ class ScriptMeta(type):
else:
return infer_methods_to_compile(module)
self.__dict__[
"_actual_script_module"
] = torch.jit._recursive.create_script_module(
self, make_stubs, share_types=not added_methods_in_init
self.__dict__["_actual_script_module"] = (
torch.jit._recursive.create_script_module(
self, make_stubs, share_types=not added_methods_in_init
)
)
# Delete the Python attributes that now shadow the ScriptModule

View File

@ -280,15 +280,15 @@ def max_pool2d(
dilation: list[int],
ceil_mode: bool,
):
assert (
len(kernel_size) == 1 or len(kernel_size) == 2
), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
assert len(kernel_size) == 1 or len(kernel_size) == 2, (
"max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
)
kH = kernel_size[0]
kW = kH if len(kernel_size) == 1 else kernel_size[1]
assert (
len(stride) == 0 or len(stride) == 1 or len(stride) == 2
), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
assert len(stride) == 0 or len(stride) == 1 or len(stride) == 2, (
"max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
)
dH = kH if len(stride) == 0 else stride[0]
if len(stride) == 0:
dW = kW
@ -297,15 +297,15 @@ def max_pool2d(
else:
dW = stride[1]
assert (
len(padding) == 1 or len(padding) == 2
), "max_pool2d: padding must either be a single int, or a tuple of two ints"
assert len(padding) == 1 or len(padding) == 2, (
"max_pool2d: padding must either be a single int, or a tuple of two ints"
)
padH = padding[0]
padW = padH if len(padding) == 1 else padding[1]
assert (
len(dilation) == 1 or len(dilation) == 2
), "max_pool2d: dilation must be either a single int, or a tuple of two ints"
assert len(dilation) == 1 or len(dilation) == 2, (
"max_pool2d: dilation must be either a single int, or a tuple of two ints"
)
dilationH = dilation[0]
dilationW = dilationH if len(dilation) == 1 else dilation[1]
@ -367,17 +367,17 @@ def upsample_nearest2d(
assert 0, "Either output_size or scale_factors must be presented"
if output_size is not None:
assert (
scale_factors is None
), "Must specify exactly one of output_size and scale_factors"
assert scale_factors is None, (
"Must specify exactly one of output_size and scale_factors"
)
assert len(output_size) == 2
out.append(output_size[0])
out.append(output_size[1])
if scale_factors is not None:
assert (
output_size is None
), "Must specify exactly one of output_size and scale_factors"
assert output_size is None, (
"Must specify exactly one of output_size and scale_factors"
)
assert len(scale_factors) == 2
out.append(int(input[2] * scale_factors[0]))
out.append(int(input[3] * scale_factors[1]))
@ -540,9 +540,9 @@ def check_cat_shape_except_dim(
assert first_dims == second_dims, "Tensors must have same number of dimensions"
for dim in range(0, first_dims):
if dim != dimension:
assert (
first[dim] == second[dim]
), "Sizes of tensors must match except in dimension"
assert first[dim] == second[dim], (
"Sizes of tensors must match except in dimension"
)
def cat(tensors: list[list[int]], dim: int):
@ -1088,9 +1088,9 @@ def topk(self: list[int], k: int, dim: int = -1) -> tuple[list[int], list[int]]:
if len(self) == 0:
result: list[int] = []
else:
assert (
k <= self[dim]
), f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
assert k <= self[dim], (
f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
)
result = _copy(self)
result[dim] = k
return result, result

View File

@ -1205,7 +1205,10 @@ def trace_module(
# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {"forward": example_forward_input, "weighted_kernel_sum": example_weight}
inputs = {
"forward": example_forward_input,
"weighted_kernel_sum": example_weight,
}
module = torch.jit.trace_module(n, inputs)
"""

View File

@ -309,14 +309,14 @@ defined as ``prod(x[:i])``.""",
operation_args, operation_kwargs = args_and_kwargs[func.__name__]
arg_declarations = [
"\n ".join(
argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines()
argument_declarations.get(a, f"{a.split('__', 1)[0]}: TBD.").splitlines()
)
for a in operation_args
]
kwarg_declarations = [
"\n ".join(
argument_declarations.get(
a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.'
a.split("=", 1)[0], f"{a.split('__', 1)[0]}: TBD."
)
.format(default=a.split("=", 1)[1])
.splitlines()
@ -745,9 +745,9 @@ def _sparse_csr_segment_reduction_helper(
) -> Tensor:
# Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True
# FIXME: when dense dimensions are implemented for CSR tensors
assert (
keepdim
), "reduction operations on CSR tensors with keepdim=False is unsupported"
assert keepdim, (
"reduction operations on CSR tensors with keepdim=False is unsupported"
)
reduce = op.__name__
valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
if reduce not in valid_reductions:
@ -781,9 +781,9 @@ def _sparse_csr_segment_reduction_helper(
)
new_shape = [1, mask_input.size(1)]
else:
assert (
dims[0] == 1
), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
assert dims[0] == 1, (
"Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
)
# all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
# except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
new_crow_indices = torch.cat(
@ -1598,9 +1598,9 @@ def _std_var(
mask: Optional[Tensor],
take_sqrt: Optional[bool],
) -> Tensor:
assert (
unbiased is None or correction_opt is None
), "Only one of unbiased and correction may be given"
assert unbiased is None or correction_opt is None, (
"Only one of unbiased and correction may be given"
)
correction = 1.0
if unbiased is not None:
correction = 1.0 if unbiased else 0.0
@ -1636,7 +1636,11 @@ def _std_var(
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
else:
total = sum(
x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask # type: ignore[possibly-undefined]
x * x.conj(),
dim,
keepdim=keepdim,
dtype=compute_dtype,
mask=inmask, # type: ignore[possibly-undefined]
)
if not keepdim:
count = count.reshape(total.shape)

View File

@ -25,7 +25,7 @@ def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]:
>>> # xdoctest: +SKIP
>>> from torch.masked import MaskedTensor
>>> data = torch.arange(6).reshape(2,3)
>>> data = torch.arange(6).reshape(2, 3)
>>> mask = torch.tensor([[True, False, False], [True, True, False]])
>>> mt = MaskedTensor(data, mask)
>>> is_masked_tensor(mt)

View File

@ -5,6 +5,7 @@ Metal is Apple's API for programming metal GPU (graphics processor unit). Using
performance can be achieved, by running work on the metal GPU(s).
See https://developer.apple.com/documentation/metalperformanceshaders for more details.
"""
from typing import Union
import torch

View File

@ -198,7 +198,7 @@ def snapshot() -> dict[str, Any]:
def attach_out_of_memory_observer(
observer: Callable[[int, int, int, int], None]
observer: Callable[[int, int, int, int], None],
) -> None:
r"""Attach an out-of-memory observer to MTIA memory allocator"""
torch._C._mtia_attachOutOfMemoryObserver(observer)

View File

@ -14,6 +14,7 @@ memory.
Because of the similarity of APIs we do not document most of this package
contents, and we recommend referring to very good docs of the original module.
"""
import multiprocessing
import sys