[BE][PYFMT] migrate PYFMT for torch/[e-n]*/ to ruff format (#144553)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144553
Approved by: https://github.com/ezyang
ghstack dependencies: #144551
Author: Xuehai Pan
Date: 2025-06-16 14:35:28 +08:00
Committed by: PyTorch MergeBot
Parent: 95cb42c45d
Commit: 2e0e08588e
48 changed files with 548 additions and 497 deletions
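
Most of the churn below comes from a handful of places where `ruff format` deviates from black. The most frequent one in this diff is assert wrapping: black used to parenthesize the condition and leave the message trailing after the closing paren, while ruff keeps the condition intact and parenthesizes the message instead. A toy, self-contained illustration (the names are made up, not taken from the diff):

```python
# Toy example only; both asserts pass so the snippet runs cleanly.
value, expected = 4, 4

# black style (old): the condition is wrapped, the message trails the closing paren
assert (
    value == expected
), f"expected {expected}, got {value}"

# ruff format style (new): the condition stays intact, the message is parenthesized
assert value == expected, (
    f"expected {expected}, got {value}"
)
```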


@ -60,7 +60,6 @@ USE_BLACK_FILELIST = re.compile(
"torch/[b-c]*/**", "torch/[b-c]*/**",
# torch/d*/** # torch/d*/**
# torch/[e-m]*/** # torch/[e-m]*/**
"torch/[e-m]*/**",
# torch/optim/** # torch/optim/**
# torch/[p-z]*/** # torch/[p-z]*/**
"torch/[p-z]*/**", "torch/[p-z]*/**",

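The hunk above removes `"torch/[e-m]*/**"` from `USE_BLACK_FILELIST`, which is what switches those directories over to `ruff format`. As a rough sketch of how a glob filelist like this selects paths (illustrative only, using `fnmatch` rather than the real formatter adapter):

```python
from fnmatch import fnmatch

# Entries still on the black filelist after this change (illustrative subset).
remaining = ["torch/[b-c]*/**", "torch/[p-z]*/**"]

for path in [
    "torch/backends/cuda/__init__.py",  # still matches [b-c]   -> black
    "torch/export/__init__.py",         # no longer matched     -> ruff format
    "torch/fx/graph.py",                # no longer matched     -> ruff format
    "torch/utils/data/dataset.py",      # matches [p-z]         -> black
]:
    formatter = "black" if any(fnmatch(path, pat) for pat in remaining) else "ruff format"
    print(f"{path}: {formatter}")
```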

@ -358,22 +358,24 @@ def save(
import torch import torch
import io import io
class MyModule(torch.nn.Module): class MyModule(torch.nn.Module):
def forward(self, x): def forward(self, x):
return x + 10 return x + 10
ep = torch.export.export(MyModule(), (torch.randn(5),)) ep = torch.export.export(MyModule(), (torch.randn(5),))
# Save to file # Save to file
torch.export.save(ep, 'exported_program.pt2') torch.export.save(ep, "exported_program.pt2")
# Save to io.BytesIO buffer # Save to io.BytesIO buffer
buffer = io.BytesIO() buffer = io.BytesIO()
torch.export.save(ep, buffer) torch.export.save(ep, buffer)
# Save with extra files # Save with extra files
extra_files = {'foo.txt': b'bar'.decode('utf-8')} extra_files = {"foo.txt": b"bar".decode("utf-8")}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files) torch.export.save(ep, "exported_program.pt2", extra_files=extra_files)
""" """
if not isinstance(ep, ExportedProgram): if not isinstance(ep, ExportedProgram):
@ -427,18 +429,18 @@ def load(
import io import io
# Load ExportedProgram from file # Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2') ep = torch.export.load("exported_program.pt2")
# Load ExportedProgram from io.BytesIO object # Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f: with open("exported_program.pt2", "rb") as f:
buffer = io.BytesIO(f.read()) buffer = io.BytesIO(f.read())
buffer.seek(0) buffer.seek(0)
ep = torch.export.load(buffer) ep = torch.export.load(buffer)
# Load with extra files. # Load with extra files.
extra_files = {'foo.txt': ''} # values will be replaced with data extra_files = {"foo.txt": ""} # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files) ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
print(extra_files['foo.txt']) print(extra_files["foo.txt"])
print(ep(torch.randn(5))) print(ep(torch.randn(5)))
""" """
if isinstance(f, (str, os.PathLike)): if isinstance(f, (str, os.PathLike)):
@ -572,24 +574,29 @@ def register_dataclass(
import torch import torch
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
class InputDataClass: class InputDataClass:
feature: torch.Tensor feature: torch.Tensor
bias: int bias: int
@dataclass @dataclass
class OutputDataClass: class OutputDataClass:
res: torch.Tensor res: torch.Tensor
torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass) torch.export.register_dataclass(OutputDataClass)
class Mod(torch.nn.Module): class Mod(torch.nn.Module):
def forward(self, x: InputDataClass) -> OutputDataClass: def forward(self, x: InputDataClass) -> OutputDataClass:
res = x.feature + x.bias res = x.feature + x.bias
return OutputDataClass(res=res) return OutputDataClass(res=res)
ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), ))
ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),))
print(ep) print(ep)
""" """


@ -43,7 +43,7 @@ def prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[int, str])
continue continue
res += f""" res += f"""
File {str_to_filename[frame['filename']]}, lineno {frame['line']}, in {frame['name']}""" # type: ignore[index] File {str_to_filename[frame["filename"]]}, lineno {frame["line"]}, in {frame["name"]}""" # type: ignore[index]
res += f"\n {stack[-1]['loc']}" res += f"\n {stack[-1]['loc']}"
return res return res
@ -327,13 +327,13 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
# We don't want to log all expression_created logs, only # We don't want to log all expression_created logs, only
# the ones that are relevant to the # the ones that are relevant to the
# guards/propagate_real_tensor # guards/propagate_real_tensor
self.expression_created_logs[ self.expression_created_logs[metadata[key]["result_id"]] = (
metadata[key]["result_id"] ExpressionCreatedNode(
] = ExpressionCreatedNode(
metadata[key]["result_id"], metadata[key]["result_id"],
metadata[key].get("argument_ids", []), metadata[key].get("argument_ids", []),
record, record,
) )
)
return return
elif key == "propagate_real_tensors_provenance": elif key == "propagate_real_tensors_provenance":
@ -374,10 +374,13 @@ def draft_export(
capture_structured_log = CaptureStructuredTrace() capture_structured_log = CaptureStructuredTrace()
with torch._functorch.config.patch( with (
torch._functorch.config.patch(
fake_tensor_propagate_real_tensors=True, fake_tensor_propagate_real_tensors=True,
generate_fake_kernels_from_real_mismatches=True, generate_fake_kernels_from_real_mismatches=True,
), capture_structured_log: ),
capture_structured_log,
):
try: try:
new_shapes = None new_shapes = None
ep = _export( ep = _export(
@ -424,11 +427,11 @@ def draft_export(
continue continue
elif log_name == "propagate_real_tensors_provenance": elif log_name == "propagate_real_tensors_provenance":
log_contents[ log_contents["occurrences"] = (
"occurrences" capture_structured_log.log_record.get_log_count(
] = capture_structured_log.log_record.get_log_count(
(log_name, log_contents) (log_name, log_contents)
) )
)
failure_type = FailureType.DATA_DEPENDENT_ERROR failure_type = FailureType.DATA_DEPENDENT_ERROR
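
Several hunks here, and in the export tracing code further down, show a second difference: a `with` statement whose context managers do not fit on one line is wrapped by parenthesizing the whole group, one manager per line, instead of black's old layout that broke inside the first manager's call parentheses. A self-contained sketch of the two layouts (toy context manager; the parenthesized form needs Python 3.10+):

```python
import contextlib

@contextlib.contextmanager
def managed(name):
    # Toy context manager, used only to make the example runnable.
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")

# old layout (as on the left-hand side of the hunks above): break inside the
# first manager's parentheses, remaining managers trail on the closing line
with managed(
    "config_patch"
), managed("structured_log"):
    pass

# new layout (ruff format): the whole group is parenthesized, one manager per line
with (
    managed("config_patch"),
    managed("structured_log"),
):
    pass
```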


@ -26,9 +26,9 @@ def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]:
if user.op == "output": if user.op == "output":
continue continue
assert ( assert user.op == "call_function" and user.target == operator.getitem, (
user.op == "call_function" and user.target == operator.getitem f"Expected getitem node as user for {node}, instead got {user}"
), f"Expected getitem node as user for {node}, instead got {user}" )
getitem_users.update(list(user.users.keys())) getitem_users.update(list(user.users.keys()))
return getitem_users return getitem_users
@ -63,9 +63,9 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
log.debug("Trying to remove pytrees for module call %s", curr_module_node) log.debug("Trying to remove pytrees for module call %s", curr_module_node)
curr_module_users = list(curr_module_node.users.keys()) curr_module_users = list(curr_module_node.users.keys())
assert ( assert len(curr_module_users) == 1, (
len(curr_module_users) == 1 f"Expected only one user for module node, instead got {list(curr_module_users)}"
), f"Expected only one user for module node, instead got {list(curr_module_users)}" )
flatten_node = curr_module_users[0] flatten_node = curr_module_users[0]
assert ( assert (
flatten_node.op == "call_function" flatten_node.op == "call_function"


@ -268,9 +268,9 @@ def _extract_fake_inputs(gm, args, kwargs):
if detected_fake_mode: if detected_fake_mode:
if detected_shape_env: if detected_shape_env:
assert ( assert detected_shape_env is detected_fake_mode.shape_env, (
detected_shape_env is detected_fake_mode.shape_env "Detected shape env does not match fake mode's shape env"
), "Detected shape env does not match fake mode's shape env" )
fake_mode = detected_fake_mode fake_mode = detected_fake_mode
elif detected_shape_env: elif detected_shape_env:
fake_mode = FakeTensorMode(shape_env=detected_shape_env, export=True) fake_mode = FakeTensorMode(shape_env=detected_shape_env, export=True)
@ -864,13 +864,19 @@ def _export_to_aten_ir(
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
# otherwise aot_export_module will error out because it sees a mix of fake_modes. # otherwise aot_export_module will error out because it sees a mix of fake_modes.
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
with torch.nn.utils.stateless._reparametrize_module( with (
torch.nn.utils.stateless._reparametrize_module(
mod, mod,
fake_params_buffers, fake_params_buffers,
tie_weights=True, tie_weights=True,
strict=True, strict=True,
stack_weights=True, stack_weights=True,
), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context(), custom_triton_ops_decomposition_ctx(): # type: ignore[attr-defined] ),
grad_safe_guard,
_ignore_backend_decomps(),
_compiling_state_context(),
custom_triton_ops_decomposition_ctx(),
):
gm, graph_signature = transform(aot_export_module)( gm, graph_signature = transform(aot_export_module)(
mod, mod,
fake_args, fake_args,
@ -1229,9 +1235,9 @@ def _get_module_call_graph(
""" """
gm: torch.fx.GraphModule = export_artifact.aten.gm gm: torch.fx.GraphModule = export_artifact.aten.gm
export_graph_signature: ExportGraphSignature = export_artifact.aten.sig export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
module_call_specs: dict[ module_call_specs: dict[str, dict[str, TreeSpec]] = (
str, dict[str, TreeSpec] export_artifact.module_call_specs
] = export_artifact.module_call_specs )
in_spec: TreeSpec = export_artifact.in_spec in_spec: TreeSpec = export_artifact.in_spec
out_spec: TreeSpec = export_artifact.out_spec out_spec: TreeSpec = export_artifact.out_spec
@ -1365,7 +1371,8 @@ def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
).module() ).module()
elif isinstance(traced_callable, torch.ScriptMethod) and isinstance( elif isinstance(traced_callable, torch.ScriptMethod) and isinstance(
traced_callable.owner(), (torch._C.ScriptModule, torch.nn.Module) # type: ignore[operator] traced_callable.owner(), # type: ignore[operator]
(torch._C.ScriptModule, torch.nn.Module),
): ):
with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator] with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator]
return _export( return _export(
@ -1430,9 +1437,9 @@ def _strict_export(
attr = getattr(gm_torch_level, node.target) attr = getattr(gm_torch_level, node.target)
# Checks if it is not a HigherOrderOp branch or a module # Checks if it is not a HigherOrderOp branch or a module
if not isinstance(attr, torch.nn.Module): if not isinstance(attr, torch.nn.Module):
assert ( assert dynamo_fake_mode is not None, (
dynamo_fake_mode is not None "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
), "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders." )
node.meta["val"] = dynamo_fake_mode.from_tensor( node.meta["val"] = dynamo_fake_mode.from_tensor(
attr, static_shapes=True attr, static_shapes=True
) )
@ -1749,13 +1756,17 @@ def _export_to_aten_ir_make_fx(
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
# otherwise aot_export_module will error out because it sees a mix of fake_modes. # otherwise aot_export_module will error out because it sees a mix of fake_modes.
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
with torch.nn.utils.stateless._reparametrize_module( with (
torch.nn.utils.stateless._reparametrize_module(
mod, mod,
fake_params_buffers, fake_params_buffers,
tie_weights=True, tie_weights=True,
strict=True, strict=True,
stack_weights=True, stack_weights=True,
), _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined] ),
_ignore_backend_decomps(),
_compiling_state_context(),
):
gm, graph_signature = transform(_make_fx_helper)( gm, graph_signature = transform(_make_fx_helper)(
mod, mod,
fake_args, fake_args,
@ -1944,22 +1955,27 @@ def _non_strict_export(
# We also need to attach dynamo configs as these will be used in HOOs that # We also need to attach dynamo configs as these will be used in HOOs that
# use torch.compile, like cond # use torch.compile, like cond
dynamo_config = dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG) dynamo_config = dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
dynamo_config[ dynamo_config["do_not_emit_runtime_asserts"] = (
"do_not_emit_runtime_asserts" False # We want to emit runtime asserts
] = False # We want to emit runtime asserts )
with fake_mode, _NonStrictTorchFunctionHandler(), tracing( with (
tx fake_mode,
), torch._dynamo.config.patch(dynamo_config): _NonStrictTorchFunctionHandler(),
with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as ( tracing(tx),
torch._dynamo.config.patch(dynamo_config),
):
with (
_fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
patched_mod, patched_mod,
new_fake_args, new_fake_args,
new_fake_kwargs, new_fake_kwargs,
new_fake_constant_attrs, new_fake_constant_attrs,
map_fake_to_real, map_fake_to_real,
), _fakify_module_inputs( ),
fake_args, fake_kwargs, fake_mode _fakify_module_inputs(fake_args, fake_kwargs, fake_mode),
), _override_builtin_ops(): _override_builtin_ops(),
):
aten_export_artifact = _to_aten_func( # type: ignore[operator] aten_export_artifact = _to_aten_func( # type: ignore[operator]
patched_mod, patched_mod,
new_fake_args, new_fake_args,
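
The hunks above also show the third recurring change: for an over-long assignment, black used to split inside the subscript on the left-hand side, whereas ruff keeps the target on one line and parenthesizes the right-hand side. A toy illustration (the names are stand-ins, not the real APIs):

```python
log_contents: dict[str, int] = {}

def get_log_count(name: str) -> int:
    # Stand-in counter, made up for the example.
    return 1

# black style (old): the subscript on the left-hand side is split
log_contents[
    "occurrences"
] = get_log_count("propagate_real_tensors_provenance")

# ruff format style (new): the target stays intact, the value is parenthesized
log_contents["occurrences"] = (
    get_log_count("propagate_real_tensors_provenance")
)
print(log_contents)
```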


@ -666,7 +666,7 @@ class ShapesCollection:
Example:: Example::
args = ({"x": tensor_x, "others": [tensor_y, tensor_z]}) args = {"x": tensor_x, "others": [tensor_y, tensor_z]}
dim = torch.export.Dim(...) dim = torch.export.Dim(...)
dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes = torch.export.ShapesCollection()
@ -682,7 +682,7 @@ class ShapesCollection:
Example:: Example::
args = ({"x": tensor_x, "others": [int_x, int_y]}) args = {"x": tensor_x, "others": [int_x, int_y]}
# Wrap all ints with _IntWrapper # Wrap all ints with _IntWrapper
mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args) mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args)
@ -700,18 +700,18 @@ class ShapesCollection:
self._shapes = {} self._shapes = {}
def __setitem__(self, t, shape): def __setitem__(self, t, shape):
assert isinstance( assert isinstance(t, (torch.Tensor, _IntWrapper)), (
t, (torch.Tensor, _IntWrapper) f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}"
), f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}" )
# TODO(avik): check that shape is indeed a Shape # TODO(avik): check that shape is indeed a Shape
t_id = id(t) t_id = id(t)
if t_id in self._shapes: if t_id in self._shapes:
_shape = self._shapes[t_id] _shape = self._shapes[t_id]
assert ( assert shape == _shape, (
shape == _shape f"Shapes assigned to input do not match: expected {_shape}, got {shape}"
), f"Shapes assigned to input do not match: expected {_shape}, got {shape}" )
else: else:
self._shapes[id(t)] = shape self._shapes[id(t)] = shape
@ -786,9 +786,9 @@ class AdditionalInputs:
""" """
assert type(args) is tuple, f"Representative args {args} must be a tuple" assert type(args) is tuple, f"Representative args {args} must be a tuple"
assert ( assert kwargs is None or type(kwargs) is dict, (
kwargs is None or type(kwargs) is dict f"Representative kwargs {kwargs} must be None or a dict"
), f"Representative kwargs {kwargs} must be None or a dict" )
self._examples.append((args, kwargs)) self._examples.append((args, kwargs))
def dynamic_shapes(self, m, args, kwargs=None): def dynamic_shapes(self, m, args, kwargs=None):
@ -1075,7 +1075,8 @@ def _process_dynamic_shapes(
i, i,
dim.__name__, dim.__name__,
StrictMinMaxConstraint( StrictMinMaxConstraint(
vr=ValueRanges(lower=dim.value, upper=dim.value), warn_only=False # type: ignore[attr-defined] vr=ValueRanges(lower=dim.value, upper=dim.value), # type: ignore[attr-defined]
warn_only=False,
), ),
) )
else: else:
@ -1085,7 +1086,8 @@ def _process_dynamic_shapes(
i, i,
dim.__name__, dim.__name__,
StrictMinMaxConstraint( StrictMinMaxConstraint(
vr=ValueRanges(lower=dim.min, upper=dim.max), warn_only=False # type: ignore[attr-defined] vr=ValueRanges(lower=dim.min, upper=dim.max), # type: ignore[attr-defined]
warn_only=False,
), ),
) )
return constraint return constraint
@ -1161,7 +1163,7 @@ def _process_dynamic_shapes(
def _get_dim_name_mapping( def _get_dim_name_mapping(
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None] dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
): ):
name_to_dim = {} name_to_dim = {}
for dim in tree_flatten( for dim in tree_flatten(


@ -137,16 +137,11 @@ class _ExportPackage:
"decoder": ExportMethod( "decoder": ExportMethod(
overloads={ overloads={
"prefill": ExportedProgram(...), "prefill": ExportedProgram(...),
"decode": ExportedProgram(...) "decode": ExportedProgram(...),
}, },
fallbacks=[] fallbacks=[],
), ),
"encoder": ExportMethod( "encoder": ExportMethod(overloads={}, fallbacks=[ExportedProgram(...)]),
overloads={},
fallbacks=[
ExportedProgram(...)
]
)
}, },
) )
``` ```
@ -212,15 +207,18 @@ class _ExportPackage:
``` ```
package = ExportPackage() package = ExportPackage()
def prefill(x, xa, kv_cache): def prefill(x, xa, kv_cache):
assert x.shape[1] == 3 assert x.shape[1] == 3
assert kv_cache == {} assert kv_cache == {}
def decode(x, xa, kv_cache): def decode(x, xa, kv_cache):
assert x.shape[1] > 1 assert x.shape[1] > 1
assert len(kv_cache) > 0 assert len(kv_cache) > 0
return {...} # dynamic shape specs here return {...} # dynamic shape specs here
exporter = ( exporter = (
package.exporter(decoder) package.exporter(decoder)
.define_overload("prefill", prefill) .define_overload("prefill", prefill)


@ -272,7 +272,7 @@ def _override_composite_implicit_decomp(cia_ops_to_callable):
def _split_decomp_table_to_cia_and_python_decomp( def _split_decomp_table_to_cia_and_python_decomp(
decomp_table: dict[torch._ops.OperatorBase, Callable] decomp_table: dict[torch._ops.OperatorBase, Callable],
) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]: ) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]:
all_preservable_cia_ops = set(_collect_all_valid_cia_ops()) all_preservable_cia_ops = set(_collect_all_valid_cia_ops())
cia_ops_to_callable = {} cia_ops_to_callable = {}
@ -443,9 +443,14 @@ def _decompose_and_get_gm_with_new_signature_constants(
tx = TracingContext(fake_mode) tx = TracingContext(fake_mode)
with fake_mode, _override_composite_implicit_decomp( with (
fake_mode,
_override_composite_implicit_decomp(
cia_to_decomp, cia_to_decomp,
), _enable_graph_inputs_of_type_nn_module(ep.example_inputs), tracing(tx): ),
_enable_graph_inputs_of_type_nn_module(ep.example_inputs),
tracing(tx),
):
retracing_args_unwrapped = pytree.tree_unflatten( retracing_args_unwrapped = pytree.tree_unflatten(
retracing_args, mod._in_spec retracing_args, mod._in_spec
) )
@ -573,9 +578,12 @@ def _decompose_and_get_gm_with_new_signature_constants(
if decompose_custom_triton_ops if decompose_custom_triton_ops
else _disable_custom_triton_op_functional_decomposition else _disable_custom_triton_op_functional_decomposition
) )
with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp( with (
cia_to_decomp _ignore_backend_decomps(),
), custom_triton_ops_decomposition_ctx(): fake_mode,
_override_composite_implicit_decomp(cia_to_decomp),
custom_triton_ops_decomposition_ctx(),
):
gm, graph_signature = aot_export_module( gm, graph_signature = aot_export_module(
ep.graph_module, ep.graph_module,
fake_args, fake_args,
@ -1514,9 +1522,9 @@ class ExportedProgram:
if node.op != "placeholder": if node.op != "placeholder":
break break
assert i < len( assert i < len(old_signature.input_specs), (
old_signature.input_specs "Number of inputs changed after transformation"
), "Number of inputs changed after transformation" )
old_input_spec = old_signature.input_specs[i] old_input_spec = old_signature.input_specs[i]
arg = ( arg = (
old_input_spec.arg old_input_spec.arg
@ -1539,9 +1547,9 @@ class ExportedProgram:
new_output_specs = [] new_output_specs = []
for i, node in enumerate(output_node.args[0]): for i, node in enumerate(output_node.args[0]):
assert i < len( assert i < len(old_signature.output_specs), (
old_signature.output_specs "Number of outputs changed after transformation"
), "Number of outputs changed after transformation" )
old_output_spec = old_signature.output_specs[i] old_output_spec = old_signature.output_specs[i]
arg = ( arg = (
old_output_spec.arg old_output_spec.arg
@ -1599,9 +1607,9 @@ class ExportedProgram:
# TODO: remove this # TODO: remove this
@final @final
def _validate(self): def _validate(self):
assert ( assert len(self.verifiers) > 0, (
len(self.verifiers) > 0 "ExportedProgram must have at least one verifier."
), "ExportedProgram must have at least one verifier." )
for v in self.verifiers: for v in self.verifiers:
v().check(self) v().check(self)


@ -95,9 +95,9 @@ class InputSpec:
def __post_init__(self): def __post_init__(self):
if self.kind == InputKind.BUFFER: if self.kind == InputKind.BUFFER:
assert ( assert self.persistent is not None, (
self.persistent is not None "Failed to specify persistent flag on BUFFER."
), "Failed to specify persistent flag on BUFFER." )
assert isinstance( assert isinstance(
self.arg, self.arg,
( (
@ -187,12 +187,14 @@ class ExportGraphSignature:
self.my_parameter = nn.Parameter(torch.tensor(2.0)) self.my_parameter = nn.Parameter(torch.tensor(2.0))
# Define two buffers # Define two buffers
self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer("my_buffer1", torch.tensor(3.0))
self.register_buffer('my_buffer2', torch.tensor(4.0)) self.register_buffer("my_buffer2", torch.tensor(4.0))
def forward(self, x1, x2): def forward(self, x1, x2):
# Use the parameter, buffers, and both inputs in the forward method # Use the parameter, buffers, and both inputs in the forward method
output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 output = (
x1 + self.my_parameter
) * self.my_buffer1 + x2 * self.my_buffer2
# Mutate one of the buffers (e.g., increment it by 1) # Mutate one of the buffers (e.g., increment it by 1)
self.my_buffer2.add_(1.0) # In-place addition self.my_buffer2.add_(1.0) # In-place addition
@ -520,9 +522,9 @@ def _make_argument_spec(node, token_names) -> ArgumentSpec:
# For const outputs we just directly return this # For const outputs we just directly return this
return ConstantArgument(name="", value=node) return ConstantArgument(name="", value=node)
assert ( assert "val" in node.meta, (
"val" in node.meta f"{node} is not a constant or a node with a 'val' metadata field"
), f"{node} is not a constant or a node with a 'val' metadata field" )
val = node.meta["val"] val = node.meta["val"]
if node.name in token_names: if node.name in token_names:
return TokenArgument(name=node.name) return TokenArgument(name=node.name)
@ -565,9 +567,21 @@ def _convert_to_export_graph_signature(
user_outputs = set(graph_signature.user_outputs) user_outputs = set(graph_signature.user_outputs)
buffer_mutations = graph_signature.buffers_to_mutate buffer_mutations = graph_signature.buffers_to_mutate
user_input_mutations = graph_signature.user_inputs_to_mutate user_input_mutations = graph_signature.user_inputs_to_mutate
grad_params = graph_signature.backward_signature.gradients_to_parameter if is_joint else {} # type: ignore[union-attr] grad_params = (
grad_user_inputs = graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {} # type: ignore[union-attr] graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr]
loss_output = graph_signature.backward_signature.loss_output if is_joint else None # type: ignore[union-attr] if is_joint
else {}
)
grad_user_inputs = (
graph_signature.backward_signature.gradients_to_user_inputs # type: ignore[union-attr]
if is_joint
else {}
)
loss_output = (
graph_signature.backward_signature.loss_output # type: ignore[union-attr]
if is_joint
else None
)
input_tokens = graph_signature.input_tokens input_tokens = graph_signature.input_tokens
output_tokens = graph_signature.output_tokens output_tokens = graph_signature.output_tokens
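
The graph-signature hunks above show what happens to long conditional expressions that previously carried a single trailing `# type: ignore`: the expression is parenthesized and split across lines, and the ignore comment ends up on the sub-expression it actually applies to. A self-contained sketch with made-up types:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class BackwardSignature:
    gradients_to_parameter: dict[str, str]

@dataclass
class Signature:
    backward_signature: Optional[BackwardSignature]

sig = Signature(backward_signature=None)
is_joint = sig.backward_signature is not None

# old: everything on one line, with one trailing "# type: ignore[union-attr]"
# new: the conditional is parenthesized and split, so the ignore comment can sit
# on the exact line that needs it
grad_params = (
    sig.backward_signature.gradients_to_parameter  # type: ignore[union-attr]
    if is_joint
    else {}
)
print(grad_params)
```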


@ -155,9 +155,9 @@ class PT2ArchiveReader:
def __init__(self, archive_path_or_buffer: FileLike): def __init__(self, archive_path_or_buffer: FileLike):
self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type] self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type]
assert ( assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, (
self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE "Invalid archive format"
), "Invalid archive format" )
def __enter__(self) -> "PT2ArchiveReader": def __enter__(self) -> "PT2ArchiveReader":
return self return self


@ -104,9 +104,9 @@ def _assign_attr(
assert isinstance(from_obj, torch.Tensor) assert isinstance(from_obj, torch.Tensor)
to_module.register_buffer(field, from_obj, persistent=persistent) to_module.register_buffer(field, from_obj, persistent=persistent)
elif attr_kind == _AttrKind.CONSTANT: elif attr_kind == _AttrKind.CONSTANT:
assert not isinstance( assert not isinstance(from_obj, FakeScriptObject), (
from_obj, FakeScriptObject "FakeScriptObject should only exist during tracing."
), "FakeScriptObject should only exist during tracing." )
assert isinstance( assert isinstance(
from_obj, from_obj,
( (
@ -461,9 +461,9 @@ class UnflattenedModule(torch.nn.Module):
# add constants that are aliased and don't appear in graph signature # add constants that are aliased and don't appear in graph signature
for const_name, const in export_module.constants.items(): for const_name, const in export_module.constants.items():
if const_name not in consts_targets: if const_name not in consts_targets:
assert ( assert id(const) in consts_map, (
id(const) in consts_map "Constants should be either aliased or appear in graph signature"
), "Constants should be either aliased or appear in graph signature" )
ph_name, _ = consts_map[id(const)][0] ph_name, _ = consts_map[id(const)][0]
add_to_consts_map(id(const), ph_name, const_name) add_to_consts_map(id(const), ph_name, const_name)
added_params_buffers.add(s.target) added_params_buffers.add(s.target)
@ -1041,9 +1041,9 @@ class _ModuleFrame:
if arg.name in self.seen_nodes: if arg.name in self.seen_nodes:
flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
self.node_to_placeholder[ self.node_to_placeholder[self.seen_nodes[arg.name]] = (
self.seen_nodes[arg.name] flat_arg_node
] = flat_arg_node )
with self.parent.graph.inserting_before(self.parent_call_module): with self.parent.graph.inserting_before(self.parent_call_module):
input_nodes: list[Optional[torch.fx.Node]] = [] input_nodes: list[Optional[torch.fx.Node]] = []
@ -1125,8 +1125,7 @@ class _ModuleFrame:
if x in self.node_to_placeholder: if x in self.node_to_placeholder:
return self.node_to_placeholder[x] return self.node_to_placeholder[x]
elif ( elif (
x.op == "placeholder" x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None
or self.module_call_graph.get(self.fqn) is None
# allow placeholder creation if we are not preserving module call signature # allow placeholder creation if we are not preserving module call signature
): ):
self.add_placeholder(x) self.add_placeholder(x)


@ -82,9 +82,7 @@ Example:
>>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j]) >>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])
>>> torch.fft.fft(t) >>> torch.fft.fft(t)
tensor([12.+16.j, -8.+0.j, -4.-4.j, 0.-8.j]) tensor([12.+16.j, -8.+0.j, -4.-4.j, 0.-8.j])
""".format( """.format(**common_args),
**common_args
),
) )
ifft = _add_docstr( ifft = _add_docstr(
@ -125,9 +123,7 @@ Example:
>>> t = torch.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) >>> t = torch.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
>>> torch.fft.ifft(t) >>> torch.fft.ifft(t)
tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j]) tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j])
""".format( """.format(**common_args),
**common_args
),
) )
fft2 = _add_docstr( fft2 = _add_docstr(
@ -188,9 +184,7 @@ Example:
>>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1) >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
>>> torch.testing.assert_close(fft2, two_ffts, check_stride=False) >>> torch.testing.assert_close(fft2, two_ffts, check_stride=False)
""".format( """.format(**common_args),
**common_args
),
) )
ifft2 = _add_docstr( ifft2 = _add_docstr(
@ -243,9 +237,7 @@ Example:
>>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1) >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
>>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False) >>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False)
""".format( """.format(**common_args),
**common_args
),
) )
fftn = _add_docstr( fftn = _add_docstr(
@ -305,9 +297,7 @@ Example:
>>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1) >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
>>> torch.testing.assert_close(fftn, two_ffts, check_stride=False) >>> torch.testing.assert_close(fftn, two_ffts, check_stride=False)
""".format( """.format(**common_args),
**common_args
),
) )
ifftn = _add_docstr( ifftn = _add_docstr(
@ -359,9 +349,7 @@ Example:
>>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1) >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
>>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False) >>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False)
""".format( """.format(**common_args),
**common_args
),
) )
rfft = _add_docstr( rfft = _add_docstr(
@ -417,9 +405,7 @@ Example:
Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted. Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted.
At the Nyquist frequency ``T[-2] == T[2]`` is it's own symmetric pair, At the Nyquist frequency ``T[-2] == T[2]`` is it's own symmetric pair,
and therefore must always be real-valued. and therefore must always be real-valued.
""".format( """.format(**common_args),
**common_args
),
) )
irfft = _add_docstr( irfft = _add_docstr(
@ -496,9 +482,7 @@ Example:
>>> roundtrip = torch.fft.irfft(T, t.numel()) >>> roundtrip = torch.fft.irfft(T, t.numel())
>>> torch.testing.assert_close(roundtrip, t, check_stride=False) >>> torch.testing.assert_close(roundtrip, t, check_stride=False)
""".format( """.format(**common_args),
**common_args
),
) )
rfft2 = _add_docstr( rfft2 = _add_docstr(
@ -565,9 +549,7 @@ Example:
>>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0) >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
>>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False) >>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False)
""".format( """.format(**common_args),
**common_args
),
) )
irfft2 = _add_docstr( irfft2 = _add_docstr(
@ -649,9 +631,7 @@ Example:
torch.Size([10, 9]) torch.Size([10, 9])
>>> torch.testing.assert_close(roundtrip, t, check_stride=False) >>> torch.testing.assert_close(roundtrip, t, check_stride=False)
""".format( """.format(**common_args),
**common_args
),
) )
rfftn = _add_docstr( rfftn = _add_docstr(
@ -718,9 +698,7 @@ Example:
>>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0) >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
>>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False) >>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False)
""".format( """.format(**common_args),
**common_args
),
) )
irfftn = _add_docstr( irfftn = _add_docstr(
@ -801,9 +779,7 @@ Example:
torch.Size([10, 9]) torch.Size([10, 9])
>>> torch.testing.assert_close(roundtrip, t, check_stride=False) >>> torch.testing.assert_close(roundtrip, t, check_stride=False)
""".format( """.format(**common_args),
**common_args
),
) )
hfft = _add_docstr( hfft = _add_docstr(
@ -894,9 +870,7 @@ Example:
>>> torch.fft.hfft(T[:3]) >>> torch.fft.hfft(T[:3])
tensor([0.1250, 0.2809, 0.6250, 0.9691]) tensor([0.1250, 0.2809, 0.6250, 0.9691])
""".format( """.format(**common_args),
**common_args
),
) )
ihfft = _add_docstr( ihfft = _add_docstr(
@ -951,9 +925,7 @@ Example:
>>> torch.fft.ifft(t) >>> torch.fft.ifft(t)
tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j, tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j,
-0.5000+0.6882j]) -0.5000+0.6882j])
""".format( """.format(**common_args),
**common_args
),
) )
hfft2 = _add_docstr( hfft2 = _add_docstr(
@ -1025,9 +997,7 @@ Example:
>>> torch.allclose(roundtrip, T) >>> torch.allclose(roundtrip, T)
True True
""".format( """.format(**common_args),
**common_args
),
) )
ihfft2 = _add_docstr( ihfft2 = _add_docstr(
@ -1092,9 +1062,7 @@ Example:
>>> torch.allclose(t, two_ffts) >>> torch.allclose(t, two_ffts)
True True
""".format( """.format(**common_args),
**common_args
),
) )
hfftn = _add_docstr( hfftn = _add_docstr(
@ -1187,9 +1155,7 @@ Example:
>>> torch.allclose(roundtrip, T) >>> torch.allclose(roundtrip, T)
True True
""".format( """.format(**common_args),
**common_args
),
) )
ihfftn = _add_docstr( ihfftn = _add_docstr(
@ -1259,9 +1225,7 @@ Example:
>>> torch.allclose(ihfftn, two_iffts) >>> torch.allclose(ihfftn, two_iffts)
True True
""".format( """.format(**common_args),
**common_args
),
) )
fftfreq = _add_docstr( fftfreq = _add_docstr(
@ -1310,9 +1274,7 @@ Example:
>>> torch.fft.fftfreq(4) >>> torch.fft.fftfreq(4)
tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) tensor([ 0.0000, 0.2500, -0.5000, -0.2500])
""".format( """.format(**factory_common_args),
**factory_common_args
),
) )
rfftfreq = _add_docstr( rfftfreq = _add_docstr(
@ -1361,9 +1323,7 @@ Example:
>>> torch.fft.fftfreq(4) >>> torch.fft.fftfreq(4)
tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) tensor([ 0.0000, 0.2500, -0.5000, -0.2500])
""".format( """.format(**factory_common_args),
**factory_common_args
),
) )
fftshift = _add_docstr( fftshift = _add_docstr(
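
Every `torch.fft` docstring above ends with the same change: the `.format(**common_args)` call that black had expanded over three lines collapses back onto the line that closes the docstring. A generic, runnable illustration:

```python
common_args = {"n": "signal length", "dim": "the transformed dimension"}

# black style (old): the format call is expanded
doc_old = """
Computes the FFT along {dim} using {n} points.
""".format(
    **common_args
)

# ruff format style (new): the call collapses onto the closing line of the string
doc_new = """
Computes the FFT along {dim} using {n} points.
""".format(**common_args)

assert doc_old == doc_new
print(doc_new.strip())
```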


@ -271,9 +271,9 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
... ...
ValueError: foo ValueError: foo
""" """
assert isinstance( assert isinstance(result, Exception), (
result, Exception f"{result} is of type {type(result)}, not an Exception."
), f"{result} is of type {type(result)}, not an Exception." )
def raise_error(fut_result): def raise_error(fut_result):
raise fut_result raise fut_result


@ -253,9 +253,9 @@ class _TensorPickleData:
for k in MetaTensorDesc._UNSERIALIZABLE: for k in MetaTensorDesc._UNSERIALIZABLE:
if k in ("fake_mode", "view_func"): if k in ("fake_mode", "view_func"):
continue continue
assert ( assert getattr(self.metadata, k) is None, (
getattr(self.metadata, k) is None f"not None: {k}: {getattr(self.metadata, k)}"
), f"not None: {k}: {getattr(self.metadata, k)}" )
def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor: def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor:
# TODO: make common w/ _output_from_cache_entry() in fake_tensor.py? # TODO: make common w/ _output_from_cache_entry() in fake_tensor.py?


@ -755,9 +755,9 @@ class Tracer(TracerBase):
self.root = root self.root = root
assert hasattr( assert hasattr(type(root), self.traced_func_name), (
type(root), self.traced_func_name f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" )
fn = getattr(type(root), self.traced_func_name) fn = getattr(type(root), self.traced_func_name)
self.root_module_name = root._get_name() self.root_module_name = root._get_name()
@ -1164,9 +1164,9 @@ def _maybe_revert_all_patches():
finally: finally:
if current_patcher is not None: if current_patcher is not None:
patches_made = current_patcher.reapply_all_patches() patches_made = current_patcher.reapply_all_patches()
assert ( assert patches_made == patches_removed, (
patches_made == patches_removed "CURRENT_PATCHER was changed during a revert_all_patches"
), "CURRENT_PATCHER was changed during a revert_all_patches" )
def _patch_wrapped_functions(patcher: _Patcher): def _patch_wrapped_functions(patcher: _Patcher):
@ -1248,9 +1248,9 @@ def wrap(fn_or_name: Union[str, Callable]):
assert not isinstance(fn_or_name, str) # to make mypy happy assert not isinstance(fn_or_name, str) # to make mypy happy
fn_name = fn_or_name.__name__ fn_name = fn_or_name.__name__
else: else:
assert isinstance( assert isinstance(fn_or_name, str), (
fn_or_name, str "fn_or_name must be a global function or string name"
), "fn_or_name must be a global function or string name" )
fn_name = fn_or_name fn_name = fn_or_name
currentframe = inspect.currentframe() currentframe = inspect.currentframe()
@ -1308,7 +1308,9 @@ def symbolic_trace(
return out return out
f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}) f = fx.symbolic_trace(
f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}
)
assert f({"a": 1, "b": 2, "c": 4}) == 7 assert f({"a": 1, "b": 2, "c": 4}) == 7


@ -450,9 +450,9 @@ class Partitioner:
device = find_device_based_on_size(node) device = find_device_based_on_size(node)
occupied_devices.append(device) occupied_devices.append(device)
# Update partition and its left mem size # Update partition and its left mem size
partition_to_left_mem_bytes[ partition_to_left_mem_bytes[partition] = (
partition device.available_mem_bytes
] = device.available_mem_bytes )
# Update available mem for the current partition # Update available mem for the current partition
partition.logical_device_ids.append(device.logical_id) partition.logical_device_ids.append(device.logical_id)
else: else:
@ -475,9 +475,9 @@ class Partitioner:
total_size_of_input_nodes = get_extra_size_of( total_size_of_input_nodes = get_extra_size_of(
node, partition.nodes node, partition.nodes
) )
partition_to_left_mem_bytes[ partition_to_left_mem_bytes[partition] = (
partition device.available_mem_bytes
] = device.available_mem_bytes )
partition.logical_device_ids.append(device.logical_id) partition.logical_device_ids.append(device.logical_id)
partition.add_node(node) partition.add_node(node)
partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes
@ -509,9 +509,9 @@ class Partitioner:
no_device_partitions, no_device_partitions,
) = get_device_partition_stats(self.partitions, self.devices) ) = get_device_partition_stats(self.partitions, self.devices)
assert ( assert len(no_device_partitions) == 0, (
len(no_device_partitions) == 0 f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}"
), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" )
# Devices that hold partitions # Devices that hold partitions
used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0] used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]


@ -368,12 +368,12 @@ def optimize_for_inference(
supports_mkldnn = MklSupport.YES supports_mkldnn = MklSupport.YES
sample_parameter = next(cur_module.parameters(), None) sample_parameter = next(cur_module.parameters(), None)
if sample_parameter is not None: if sample_parameter is not None:
assert ( assert sample_parameter.dtype == torch.float, (
sample_parameter.dtype == torch.float "this pass is only for torch.float modules"
), "this pass is only for torch.float modules" )
assert sample_parameter.device == torch.device( assert sample_parameter.device == torch.device("cpu"), (
"cpu" "this pass is only for CPU modules"
), "this pass is only for CPU modules" )
elif node.op == "call_function": elif node.op == "call_function":
if node.target in mkldnn_supported: if node.target in mkldnn_supported:
supports_mkldnn = MklSupport.YES supports_mkldnn = MklSupport.YES


@ -182,22 +182,19 @@ def is_sym_node(node: _HasMeta) -> bool:
@overload @overload
def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ...
...
@overload @overload
def set_proxy_slot( def set_proxy_slot(
obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy
) -> None: ) -> None: ...
...
@overload @overload
def set_proxy_slot( def set_proxy_slot(
obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType
) -> None: ) -> None: ...
...
def set_proxy_slot( def set_proxy_slot(
@ -256,8 +253,7 @@ _PySymProxyType = Thunk[Proxy]
def get_proxy_slot( def get_proxy_slot(
obj: Tensor, obj: Tensor,
tracer: _ProxyTracer, tracer: _ProxyTracer,
) -> _ProxyTensor: ) -> _ProxyTensor: ...
...
@overload @overload
@ -265,8 +261,7 @@ def get_proxy_slot(
obj: Tensor, obj: Tensor,
tracer: _ProxyTracer, tracer: _ProxyTracer,
default: U, default: U,
) -> Union[_ProxyTensor, U]: ) -> Union[_ProxyTensor, U]: ...
...
@overload @overload
@ -275,16 +270,14 @@ def get_proxy_slot(
tracer: _ProxyTracer, tracer: _ProxyTracer,
default: U, default: U,
transform: Callable[[_ProxyTensor], R], transform: Callable[[_ProxyTensor], R],
) -> Union[R, U]: ) -> Union[R, U]: ...
...
@overload @overload
def get_proxy_slot( def get_proxy_slot(
obj: _AnyScriptObjectType, obj: _AnyScriptObjectType,
tracer: _ProxyTracer, tracer: _ProxyTracer,
) -> Proxy: ) -> Proxy: ...
...
@overload @overload
@ -292,8 +285,7 @@ def get_proxy_slot(
obj: _AnyScriptObjectType, obj: _AnyScriptObjectType,
tracer: _ProxyTracer, tracer: _ProxyTracer,
default: U, default: U,
) -> Union[Proxy, U]: ) -> Union[Proxy, U]: ...
...
@overload @overload
@ -302,16 +294,14 @@ def get_proxy_slot(
tracer: _ProxyTracer, tracer: _ProxyTracer,
default: U, default: U,
transform: Callable[[Proxy], R], transform: Callable[[Proxy], R],
) -> Union[R, U]: ) -> Union[R, U]: ...
...
@overload @overload
def get_proxy_slot( def get_proxy_slot(
obj: PySymType, obj: PySymType,
tracer: _ProxyTracer, tracer: _ProxyTracer,
) -> _PySymProxyType: ) -> _PySymProxyType: ...
...
@overload @overload
@ -319,8 +309,7 @@ def get_proxy_slot(
obj: PySymType, obj: PySymType,
tracer: _ProxyTracer, tracer: _ProxyTracer,
default: T, default: T,
) -> Union[T, _PySymProxyType]: ) -> Union[T, _PySymProxyType]: ...
...
@overload @overload
@ -329,8 +318,7 @@ def get_proxy_slot(
tracer: _ProxyTracer, tracer: _ProxyTracer,
default: U, default: U,
transform: Callable[[_PySymProxyType], R], transform: Callable[[_PySymProxyType], R],
) -> Union[R, U]: ) -> Union[R, U]: ...
...
# the default argument is what to return if the slot is not set. # the default argument is what to return if the slot is not set.
@ -717,22 +705,21 @@ def fetch_sym_proxy(
@overload @overload
def fetch_object_proxy(tracer: _ProxyTracer, t: Tensor) -> Union[_ProxyTensor, Tensor]: def fetch_object_proxy(
... tracer: _ProxyTracer, t: Tensor
) -> Union[_ProxyTensor, Tensor]: ...
@overload @overload
def fetch_object_proxy( def fetch_object_proxy(
tracer: _ProxyTracer, t: _AnyScriptObjectType tracer: _ProxyTracer, t: _AnyScriptObjectType
) -> Union[Proxy, _AnyScriptObjectType]: ) -> Union[Proxy, _AnyScriptObjectType]: ...
...
@overload @overload
def fetch_object_proxy( def fetch_object_proxy(
tracer: _ProxyTracer, t: PySymType tracer: _ProxyTracer, t: PySymType
) -> Union[_PySymProxyType, PySymType]: ) -> Union[_PySymProxyType, PySymType]: ...
...
def fetch_object_proxy( def fetch_object_proxy(
@ -815,7 +802,10 @@ def proxy_call(
if func is torch.ops.aten.is_nonzero.default: if func is torch.ops.aten.is_nonzero.default:
with proxy_mode: with proxy_mode:
torch._check(args[0].numel() == 1, lambda: "Boolean value of Tensor with more than one value is ambiguous") # type: ignore[attr-defined] torch._check(
args[0].numel() == 1, # type: ignore[attr-defined]
lambda: "Boolean value of Tensor with more than one value is ambiguous",
)
return (args[0] != 0).item() # type: ignore[attr-defined] return (args[0] != 0).item() # type: ignore[attr-defined]
tracer = proxy_mode.tracer tracer = proxy_mode.tracer
@ -1079,18 +1069,15 @@ class PythonKeyTracer(Tracer):
return super().create_arg(a) # type: ignore[return-value] return super().create_arg(a) # type: ignore[return-value]
@overload @overload
def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: ...
...
@overload @overload
def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: ...
...
@overload @overload
def unwrap_proxy( def unwrap_proxy(
self, e: _AnyScriptObjectType self, e: _AnyScriptObjectType
) -> Union[Proxy, _AnyScriptObjectType]: ) -> Union[Proxy, _AnyScriptObjectType]: ...
...
def unwrap_proxy(self, e: T) -> object: def unwrap_proxy(self, e: T) -> object:
if isinstance(e, Tensor): if isinstance(e, Tensor):
@ -1608,7 +1595,10 @@ class DecompositionInterpreter(fx.Interpreter):
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
def placeholder( def placeholder(
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] self,
target: str, # type: ignore[override]
args: tuple[object, ...],
kwargs: dict[str, object],
) -> object: ) -> object:
out = super().placeholder(target, args, kwargs) # type: ignore[arg-type] out = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer) proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
@ -1617,7 +1607,10 @@ class DecompositionInterpreter(fx.Interpreter):
return out return out
def get_attr( def get_attr(
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] self,
target: str, # type: ignore[override]
args: tuple[object, ...],
kwargs: dict[str, object],
) -> object: ) -> object:
out = super().get_attr(target, args, kwargs) # type: ignore[arg-type] out = super().get_attr(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer) proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
@ -1627,7 +1620,10 @@ class DecompositionInterpreter(fx.Interpreter):
# call_function, call_method, call_module get traced automatically by the outer mode. # call_function, call_method, call_module get traced automatically by the outer mode.
def output( def output(
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] self,
target: str, # type: ignore[override]
args: tuple[object, ...],
kwargs: dict[str, object],
) -> object: ) -> object:
out = super().output(target, args, kwargs) # type: ignore[arg-type] out = super().output(target, args, kwargs) # type: ignore[arg-type]
@ -1989,14 +1985,14 @@ class _MakefxTracer:
# adding new modes in _MakefxTracer. # adding new modes in _MakefxTracer.
self.fake_tensor_mode: Optional[FakeTensorMode] = None self.fake_tensor_mode: Optional[FakeTensorMode] = None
self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext() self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext()
self.proxy_function_mode: Union[ self.proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode] = (
nullcontext, PreDispatchTorchFunctionMode nullcontext()
] = nullcontext() )
self.fx_tracer: Optional[PythonKeyTracer] = None self.fx_tracer: Optional[PythonKeyTracer] = None
self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext() self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext()
self.torch_fn_metadata_mode: Union[ self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = (
nullcontext, TorchFunctionMetadataMode nullcontext()
] = nullcontext() )
self.stack_trace = stack_trace self.stack_trace = stack_trace
def _checkpoint_modes(self) -> list[Any]: def _checkpoint_modes(self) -> list[Any]:
@ -2071,9 +2067,9 @@ class _MakefxTracer:
allow_non_fake_inputs=self._allow_non_fake_inputs, allow_non_fake_inputs=self._allow_non_fake_inputs,
shape_env=shape_env, shape_env=shape_env,
) )
assert ( assert fake_tensor_mode.shape_env is not None, (
fake_tensor_mode.shape_env is not None "shape_env should be set if tracing with 'symbolic'"
), "shape_env should be set if tracing with 'symbolic'" )
self.fake_tensor_mode = fake_tensor_mode self.fake_tensor_mode = fake_tensor_mode
else: else:
if not self.tracing_mode == "real": if not self.tracing_mode == "real":
@ -2161,9 +2157,9 @@ class _MakefxTracer:
return self.fake_tensor_mode.from_tensor(x, source=source) return self.fake_tensor_mode.from_tensor(x, source=source)
# NB: don't match on bools # NB: don't match on bools
elif type(x) is int and self.tracing_mode == "symbolic": elif type(x) is int and self.tracing_mode == "symbolic":
assert ( assert self.fake_tensor_mode.shape_env is not None, (
self.fake_tensor_mode.shape_env is not None "shape_env should be set if tracing with 'symbolic'"
), "shape_env should be set if tracing with 'symbolic'" )
return self.fake_tensor_mode.shape_env.create_symintnode( return self.fake_tensor_mode.shape_env.create_symintnode(
self.fake_tensor_mode.shape_env.create_symbol( self.fake_tensor_mode.shape_env.create_symbol(
x, source, positive=None x, source, positive=None
@ -2176,9 +2172,9 @@ class _MakefxTracer:
self.fake_tensor_mode, x self.fake_tensor_mode, x
) )
assert not isinstance( assert not isinstance(x, FakeScriptObject), (
x, FakeScriptObject f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." )
return x return x
wrap_fn_map = { wrap_fn_map = {
@ -2344,9 +2340,9 @@ def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
torch._C._TorchDispatchModeKey.PROXY torch._C._TorchDispatchModeKey.PROXY
) )
mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
assert ( assert pre_dispatch_mode is None or mode is None, (
pre_dispatch_mode is None or mode is None f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
), f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}" )
return pre_dispatch_mode or mode return pre_dispatch_mode or mode
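
The proxy-tracing hunks above are dominated by one more ruff default: an `@overload` stub whose body is a bare `...` has the ellipsis folded onto the signature line. A minimal runnable sketch:

```python
from typing import Union, overload

# black style (old):
#
#     @overload
#     def as_float(x: int) -> float:
#         ...
#
# ruff format style (new): the sole "..." body collapses onto the def line.
@overload
def as_float(x: int) -> float: ...
@overload
def as_float(x: str) -> float: ...
def as_float(x: Union[int, str]) -> float:
    return float(x)

print(as_float(2), as_float("3.5"))
```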


@ -460,7 +460,7 @@ def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value)
# Here, we allow the value of each field to be mapped, so that we appropriately # Here, we allow the value of each field to be mapped, so that we appropriately
# compare the two values. # compare the two values.
def compare_vars( def compare_vars(
map_value: Callable[[str, Any], Any] map_value: Callable[[str, Any], Any],
) -> list[tuple[str, str, str]]: ) -> list[tuple[str, str, str]]:
env1_set, env2_set = set(env1_vars), set(env2_vars) env1_set, env2_set = set(env1_vars), set(env2_vars)


@ -103,7 +103,7 @@ class AnnotateTypesWithSchema(Transformer):
for i, atom in enumerate(atoms): for i, atom in enumerate(atoms):
if not hasattr(module_itr, atom): if not hasattr(module_itr, atom):
raise RuntimeError( raise RuntimeError(
f'Node referenced nonextent target {".".join(atoms[:i])}!' f"Node referenced nonextent target {'.'.join(atoms[:i])}!"
) )
module_itr = getattr(module_itr, atom) module_itr = getattr(module_itr, atom)


@ -149,9 +149,9 @@ class SymNode:
# This is technically not TV, but this assert is expensive so # This is technically not TV, but this assert is expensive so
# let's only do it when we're already doing expensive things # let's only do it when we're already doing expensive things
computed_hint = compute_hint() computed_hint = compute_hint()
assert ( assert hint == computed_hint, (
hint == computed_hint f"{hint} != {computed_hint} (for {self.expr})"
), f"{hint} != {computed_hint} (for {self.expr})" )
else: else:
hint = compute_hint() hint = compute_hint()
self._hint = hint self._hint = hint
@ -460,7 +460,9 @@ class SymNode:
return self.float_pow(other) return self.float_pow(other)
def is_non_overlapping_and_dense(self, sizes, strides): def is_non_overlapping_and_dense(self, sizes, strides):
return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(
to_node(self, 1)
) # type: ignore[attr-defined]
def int_(self): def int_(self):
return self.guard_int("", 0) # NB: uses Python backtrace return self.guard_int("", 0) # NB: uses Python backtrace


@ -182,7 +182,9 @@ CURRENT_NODE_KEY = "current_node"
def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None: def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None:
log.debug( log.debug(
"lru_cache_stats %s: %s", wrapped_f.__name__, wrapped_f.cumulative_cache_info() # type: ignore[attr-defined] "lru_cache_stats %s: %s",
wrapped_f.__name__, # type: ignore[attr-defined]
wrapped_f.cumulative_cache_info(), # type: ignore[attr-defined]
) )
@ -497,9 +499,9 @@ def check_consistent(new: _T, old: _T) -> None:
torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)") torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)")
# NB: bool is subclass of int # NB: bool is subclass of int
elif isinstance(new, scalar_types) and not isinstance(new, bool): elif isinstance(new, scalar_types) and not isinstance(new, bool):
assert isinstance(old, scalar_types) and not isinstance( assert isinstance(old, scalar_types) and not isinstance(old, bool), (
old, bool f"{old} != {new}"
), f"{old} != {new}" )
torch._check(old == new, lambda: f"{old} != {new} (old != new)") torch._check(old == new, lambda: f"{old} != {new} (old != new)")
@ -629,9 +631,9 @@ def rebind_unbacked(
raw_u1 = new_raw_u1 raw_u1 = new_raw_u1
if not isinstance(raw_u1, sympy.Symbol): if not isinstance(raw_u1, sympy.Symbol):
assert ( assert not raw_u1.free_symbols, (
not raw_u1.free_symbols f"should have been constant, but got {raw_u1}"
), f"should have been constant, but got {raw_u1}" )
continue continue
# The old and new could be the same if you improperly hit the memo # The old and new could be the same if you improperly hit the memo
@ -1975,12 +1977,12 @@ class EqualityConstraint(Constraint):
def _assert_symbol_context(symbolic_context: object) -> TypeGuard[SymbolicContext]: def _assert_symbol_context(symbolic_context: object) -> TypeGuard[SymbolicContext]:
assert isinstance( assert isinstance(symbolic_context, SymbolicContext), (
symbolic_context, SymbolicContext "Invalid symbolic_context object"
), "Invalid symbolic_context object" )
assert ( assert type(symbolic_context) is not SymbolicContext, (
type(symbolic_context) is not SymbolicContext "Illegal usage of symbolic_context ABC"
), "Illegal usage of symbolic_context ABC" )
return True return True
@ -2519,9 +2521,9 @@ def _lru_cache(
prior_version = self._version_counter prior_version = self._version_counter
prior_key = self._get_key() prior_key = self._get_key()
else: else:
assert ( assert prior_key == self._get_key(), (
prior_key == self._get_key() "ShapeEnv cache key changed without version being updated!"
), "ShapeEnv cache key changed without version being updated!" )
return fn_cache(self, *args, **kwargs) return fn_cache(self, *args, **kwargs)
@ -2772,9 +2774,9 @@ class DynamicDimConstraintPrinter(PythonPrinter):
def _print_Symbol(self, expr: sympy.Symbol) -> str: def _print_Symbol(self, expr: sympy.Symbol) -> str:
assert isinstance(expr, sympy.Symbol), str(type(expr)) assert isinstance(expr, sympy.Symbol), str(type(expr))
assert self.symbol_to_source.get( assert self.symbol_to_source.get(expr), (
expr f"Unknown symbol {expr} created by constraints solver"
), f"Unknown symbol {expr} created by constraints solver" )
return self.symbol_to_source[expr][0].name() return self.symbol_to_source[expr][0].name()
@ -2792,9 +2794,9 @@ class DimConstraints:
source_name_to_debug_name: Mapping[str, str], source_name_to_debug_name: Mapping[str, str],
) -> None: ) -> None:
# We try to solve systems of inequalities with 1 free variable. # We try to solve systems of inequalities with 1 free variable.
self._univariate_inequalities: dict[ self._univariate_inequalities: dict[sympy.Symbol, set[SympyBoolean]] = (
sympy.Symbol, set[SympyBoolean] defaultdict(set)
] = defaultdict(set) )
# Among them, we prioritize solving for a free variable that has equalities. # Among them, we prioritize solving for a free variable that has equalities.
# NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys() # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
# and removing a symbol from the former => removing it from the latter. # and removing a symbol from the former => removing it from the latter.
@ -2877,9 +2879,10 @@ class DimConstraints:
# With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we
# would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution! # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution!
base, divisor = args base, divisor = args
base, divisor = self.rewrite_with_congruences( base, divisor = (
s, base self.rewrite_with_congruences(s, base),
), self.rewrite_with_congruences(s, divisor) self.rewrite_with_congruences(s, divisor),
)
mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(
self._var_to_val self._var_to_val
) )
@ -2896,9 +2899,10 @@ class DimConstraints:
# NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d
# and eliminating b % d as above. # and eliminating b % d as above.
base, divisor = args base, divisor = args
base, divisor = self.rewrite_with_congruences( base, divisor = (
s, base self.rewrite_with_congruences(s, base),
), self.rewrite_with_congruences(s, divisor) self.rewrite_with_congruences(s, divisor),
)
mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(
self._var_to_val self._var_to_val
) )
@ -3060,9 +3064,9 @@ class DimConstraints:
(arg for arg in solution.args if isinstance(arg, sympy.Eq)), (arg for arg in solution.args if isinstance(arg, sympy.Eq)),
solution, solution,
) )
assert isinstance( assert isinstance(solution, sympy.Eq), (
solution, sympy.Eq f"Expected an equality constraint for {s}, got {solution}"
), f"Expected an equality constraint for {s}, got {solution}" )
symbol, val = solution.args symbol, val = solution.args
assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}" assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
# because this is univariate, the solution is a specialization # because this is univariate, the solution is a specialization
@ -3340,7 +3344,8 @@ class DimConstraints:
"max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type, index] "max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type, index]
} }
if not _check_same_range( if not _check_same_range(
result, name_to_dim[mroot] # type: ignore[index, arg-type] result,
name_to_dim[mroot], # type: ignore[index, arg-type]
): # ignore if unchanged ): # ignore if unchanged
modified_root_values[mroot] = result # type: ignore[index] modified_root_values[mroot] = result # type: ignore[index]
break break
@ -4124,9 +4129,9 @@ class ShapeEnv:
if not isinstance(b, SymInt): if not isinstance(b, SymInt):
assert a == b assert a == b
else: else:
assert isinstance( assert isinstance(b.node.expr, sympy.Symbol), (
b.node.expr, sympy.Symbol "constraining non-Symbols NYI"
), "constraining non-Symbols NYI" )
assert b.node.shape_env is self assert b.node.shape_env is self
self.replacements[b.node.expr] = sympy.Integer(a) self.replacements[b.node.expr] = sympy.Integer(a)
else: else:
@ -4139,9 +4144,9 @@ class ShapeEnv:
self.replacements[a.node.expr] = sympy.Integer(b) self.replacements[a.node.expr] = sympy.Integer(b)
else: else:
assert a.node.shape_env is b.node.shape_env assert a.node.shape_env is b.node.shape_env
assert isinstance( assert isinstance(b.node.expr, sympy.Symbol), (
b.node.expr, sympy.Symbol "constraining non-Symbols NYI"
), "constraining non-Symbols NYI" )
new_var = self._find(a.node.expr) new_var = self._find(a.node.expr)
self.replacements[b.node.expr] = new_var self.replacements[b.node.expr] = new_var
@ -4234,9 +4239,9 @@ class ShapeEnv:
# If translation validation is enabled, all arguments must have its # If translation validation is enabled, all arguments must have its
# own FX node. # own FX node.
assert all( assert all(a is not None for a in args), (
a is not None for a in args f"missing arg in FX graph ({op.__name__}): {args}"
), f"missing arg in FX graph ({op.__name__}): {args}" )
node = self.fx_node_cache[node_key] = self.graph.call_function(op, args) node = self.fx_node_cache[node_key] = self.graph.call_function(op, args)
self.name_to_node[node.name] = node self.name_to_node[node.name] = node
@ -4354,9 +4359,9 @@ class ShapeEnv:
source: Source, source: Source,
symbolic_context: SymbolicContext, symbolic_context: SymbolicContext,
) -> list[sympy.Expr]: ) -> list[sympy.Expr]:
assert all( assert all(not is_symbolic(val) for val in tensor_size), (
not is_symbolic(val) for val in tensor_size f"Expect size to be a plain tuple of ints but got {tensor_size}"
), f"Expect size to be a plain tuple of ints but got {tensor_size}" )
from torch._dynamo.source import TensorProperty, TensorPropertySource from torch._dynamo.source import TensorProperty, TensorPropertySource
_assert_symbol_context(symbolic_context) _assert_symbol_context(symbolic_context)
@ -4398,7 +4403,11 @@ class ShapeEnv:
source: Source, source: Source,
*, *,
symbolic_context: Optional[SymbolicContext] = None, symbolic_context: Optional[SymbolicContext] = None,
) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]: ) -> tuple[
tuple[IntLikeType, ...],
tuple[IntLikeType, ...],
IntLikeType,
]:
""" """
Returns a list of symbolic sizes and strides for the given tensor. Returns a list of symbolic sizes and strides for the given tensor.
We try our best to express stride in terms of the sizes, so as to not We try our best to express stride in terms of the sizes, so as to not
@ -4463,9 +4472,9 @@ class ShapeEnv:
) -> IntLikeType: ) -> IntLikeType:
assert isinstance(maybe_sym, (int, torch.SymInt)) assert isinstance(maybe_sym, (int, torch.SymInt))
if is_symbolic(maybe_sym): if is_symbolic(maybe_sym):
assert ( assert maybe_sym.node.shape_env is not self, (
maybe_sym.node.shape_env is not self "expect the symbol is created from an shape env other than current one."
), "expect the symbol is created from an shape env other than current one." )
return maybe_sym.node.require_hint() return maybe_sym.node.require_hint()
return maybe_sym return maybe_sym
@ -4481,7 +4490,11 @@ class ShapeEnv:
source: Source, source: Source,
*, *,
symbolic_context: Optional[SymbolicContext] = None, symbolic_context: Optional[SymbolicContext] = None,
) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]: ) -> tuple[
tuple[IntLikeType, ...],
tuple[IntLikeType, ...],
IntLikeType,
]:
dim = len(ex_size) dim = len(ex_size)
# Reimplement the legacy behavior # Reimplement the legacy behavior
@ -5045,9 +5058,9 @@ class ShapeEnv:
sloc, sloc,
) )
else: else:
self.var_to_range[ self.var_to_range[sympy_expr] = (
sympy_expr self._default_unspecified_value_range()
] = self._default_unspecified_value_range() )
self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc) self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc)
# Small performance optimization: if we have a min-max constraint, # Small performance optimization: if we have a min-max constraint,
@ -5238,9 +5251,9 @@ class ShapeEnv:
shape_env = replay_shape_env_events(self.events) shape_env = replay_shape_env_events(self.events)
self.check_equal(shape_env) self.check_equal(shape_env)
assert len(placeholders) == len( assert len(placeholders) == len(sources), (
sources f"len({placeholders}) != len({sources})"
), f"len({placeholders}) != len({sources})" )
Tensorlike = (torch.Tensor, FakeTensorMeta) Tensorlike = (torch.Tensor, FakeTensorMeta)
def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext: def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext:
@ -5336,9 +5349,9 @@ class ShapeEnv:
symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict( symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict(
list list
) )
symbol_to_constraints: defaultdict[ symbol_to_constraints: defaultdict[sympy.Symbol, set[Constraint]] = (
sympy.Symbol, set[Constraint] collections.defaultdict(set)
] = collections.defaultdict(set) )
constraint_violations: list[tuple[bool, str, Callable[[], str]]] = [] constraint_violations: list[tuple[bool, str, Callable[[], str]]] = []
printers: list[_ShapeGuardPrinter] = [] printers: list[_ShapeGuardPrinter] = []
@ -6528,7 +6541,7 @@ class ShapeEnv:
f"Caused by: {sloc}\n" f"Caused by: {sloc}\n"
'For more information, run with TORCH_LOGS="dynamic"\n' 'For more information, run with TORCH_LOGS="dynamic"\n'
"For extended logs when we create symbols, also add " "For extended logs when we create symbols, also add "
f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n" f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{",".join(map(str, expr.free_symbols))}"\n'
"If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n"
"For more debugging help, see " "For more debugging help, see "
"https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n"
@ -6662,9 +6675,9 @@ class ShapeEnv:
) )
self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a]) self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a])
tgt_bound = self.bound_sympy(tgt) tgt_bound = self.bound_sympy(tgt)
assert tgt_bound.issubset( assert tgt_bound.issubset(src_bound), (
src_bound f"{tgt_bound=} not a subset of {src_bound=}"
), f"{tgt_bound=} not a subset of {src_bound=}" )
# TODO: Should we propagate size-like-ness? # TODO: Should we propagate size-like-ness?
# #
@ -6751,9 +6764,9 @@ class ShapeEnv:
for source in self.var_to_sources.get(a, []): for source in self.var_to_sources.get(a, []):
if user_tb: if user_tb:
self.user_specialization_stacks[source] = user_tb self.user_specialization_stacks[source] = user_tb
self.framework_specialization_stacks[ self.framework_specialization_stacks[source] = (
source CapturedTraceback.extract(cpp=True)
] = CapturedTraceback.extract(cpp=True) )
if config.print_specializations: if config.print_specializations:
self.log.warning( self.log.warning(
@ -6820,9 +6833,9 @@ class ShapeEnv:
free = list(expr.free_symbols) free = list(expr.free_symbols)
assert ( assert len(free) > 0, (
len(free) > 0 f"The expression should not be static by this point: {expr}"
), f"The expression should not be static by this point: {expr}" )
# In case of really gnarly expression, we don't blow up # In case of really gnarly expression, we don't blow up
if len(free) > 5: if len(free) > 5:
return return
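For reference, the most frequent rewrite in the hunks above is the assert layout: the old formatting wrapped the condition in parentheses and left the message dangling after the closing parenthesis, while the new formatting keeps the condition on the assert line and parenthesizes the message instead. A minimal runnable sketch of the two shapes, with a made-up condition and message (not taken from the diff):

prior_key = ("shape", 3)
current_key = ("shape", 3)

# old shape: the condition is wrapped, the message trails the closing parenthesis
assert (
    prior_key == current_key
), "cache key changed unexpectedly"

# new shape: the condition stays inline, the message is what gets parenthesized
assert prior_key == current_key, (
    "cache key changed unexpectedly"
)

Both spellings raise the same AssertionError with the same message when the condition fails; only the line breaking differs.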

@ -203,9 +203,7 @@ try:
return _Z3Ops.to_real(result) if cast_result_to_real else result return _Z3Ops.to_real(result) if cast_result_to_real else result
def ceil(self, number: z3.ArithRef) -> z3.ArithRef: def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
return z3.If( return z3.If(self.floor(number) < number, self.floor(number + 1), number) # type: ignore[return-value]
self.floor(number) < number, self.floor(number + 1), number
) # type: ignore[return-value]
def trunc(self, number: z3.ArithRef) -> z3.ArithRef: def trunc(self, number: z3.ArithRef) -> z3.ArithRef:
return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value] return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value]
@ -363,9 +361,9 @@ try:
return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type] return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type]
# Adds the Z3 expression corresponding to the first argument # Adds the Z3 expression corresponding to the first argument
# as a validator input. # as a validator input.
assert ( assert len(args) == 1, (
len(args) == 1 f"expected 1 argument on assertion. Got: {len(args)} "
), f"expected 1 argument on assertion. Got: {len(args)} " )
self.validator.add_source_expr(args[0]) # type: ignore[arg-type] self.validator.add_source_expr(args[0]) # type: ignore[arg-type]
# Translates SymPy expressions into Z3 expressions. # Translates SymPy expressions into Z3 expressions.
@ -536,9 +534,9 @@ try:
def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef: def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
z3expr = SympyToZ3(self).run(e) z3expr = SympyToZ3(self).run(e)
assert isinstance( assert isinstance(z3expr, z3.BoolRef), (
z3expr, z3.BoolRef f"expected boolean expression. Got: {z3expr}"
), f"expected boolean expression. Got: {z3expr}" )
return z3expr return z3expr
def add_source_expr(self, e: z3.BoolRef) -> None: def add_source_expr(self, e: z3.BoolRef) -> None:

@ -449,7 +449,7 @@ class CodeGen:
# This code-path used in Python < 3.9 # This code-path used in Python < 3.9
return origin_typename return origin_typename
return f'{origin_typename}[{",".join(args)}]' return f"{origin_typename}[{','.join(args)}]"
else: else:
# Bare type, such as `typing.Tuple` with no subscript # Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python 3.9+ # This code-path used in Python 3.9+
@ -573,7 +573,7 @@ class CodeGen:
summary_str = parsed_stack_trace.get_summary_str() summary_str = parsed_stack_trace.get_summary_str()
else: else:
summary_str = "" summary_str = ""
body.append(f'\n {dim(f"# {summary_str}")}\n') body.append(f"\n {dim(f'# {summary_str}')}\n")
elif prev_stacktrace != "": elif prev_stacktrace != "":
prev_stacktrace = "" prev_stacktrace = ""
no_stacktrace_msg = "# No stacktrace found for following nodes" no_stacktrace_msg = "# No stacktrace found for following nodes"
@ -842,7 +842,7 @@ class _PyTreeCodeGen(CodeGen):
if len(has_annotation) > 0: if len(has_annotation) > 0:
fn_definition += "\n " + "".join(has_annotation) + "\n" fn_definition += "\n " + "".join(has_annotation) + "\n"
fn_definition += f""" fn_definition += f"""
{', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
return fn_definition return fn_definition
def generate_output(self, output_args): def generate_output(self, output_args):
@ -1877,7 +1877,9 @@ class Graph:
# through `insert_pdb`: # through `insert_pdb`:
gm.graph.on_generate_code( gm.graph.on_generate_code(
lambda current_trans: ( lambda current_trans: (
lambda body: insert_pdb(current_trans(body) if current_trans else body) lambda body: insert_pdb(
current_trans(body) if current_trans else body
)
) )
) )
@ -1916,7 +1918,7 @@ class Graph:
@contextmanager @contextmanager
def _override_sym_repr( def _override_sym_repr(
override: Callable[["torch.types.PySymType"], str] override: Callable[["torch.types.PySymType"], str],
) -> Iterator[None]: ) -> Iterator[None]:
tmp = CodeGen._sym_repr tmp = CodeGen._sym_repr
try: try:
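The f-string edits in this file, like the quoting changes at a few earlier sites in the diff, apply one rule: string quotes are normalized toward double quotes, and a literal nested inside an f-string replacement field is flipped to the other quote style whenever it would clash with the enclosing quotes, so the rendered text never changes. A small self-contained sketch with hypothetical values:

args = ["int", "str"]
origin_typename = "Union"

old_style = f'{origin_typename}[{",".join(args)}]'  # single quotes outside, double quotes inside
new_style = f"{origin_typename}[{','.join(args)}]"  # double quotes outside, single quotes inside
assert old_style == new_style == "Union[int,str]"   # only the quoting moved; the output is identical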

@ -324,9 +324,9 @@ def _print_readable(
colored=False, colored=False,
): ):
graph = module.graph graph = module.graph
assert graph is not None and isinstance( assert graph is not None and isinstance(graph, torch.fx.Graph), (
graph, torch.fx.Graph "print_readable must be used on a module with a graph"
), "print_readable must be used on a module with a graph" )
verbose_python_code = graph.python_code( verbose_python_code = graph.python_code(
root_module="self", root_module="self",
@ -873,9 +873,9 @@ class {module_name}(torch.nn.Module):
for node in self.graph.nodes for node in self.graph.nodes
if "stack_trace" in node.meta if "stack_trace" in node.meta
} }
dict_without_graph[ dict_without_graph["_graphmodule_graph_node_meta_stack_trace"] = (
"_graphmodule_graph_node_meta_stack_trace" node_meta_stack_trace
] = node_meta_stack_trace )
generated_module_name = f"fx-generated._{exporter.get_unique_id()}" generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
python_code = self.recompile() python_code = self.recompile()
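The assignment rewrite just above, like the similar ones in the earlier symbolic-shapes hunks, follows a single pattern: instead of breaking the line inside the left-hand subscript or type annotation, the target is kept on one line and the right-hand value is wrapped in parentheses. A minimal sketch with invented names:

from collections import defaultdict

# old shape: the left-hand annotation is what gets split across lines
lookup_by_symbol: dict[
    str, set[int]
] = defaultdict(set)

# new shape: the target stays intact and the value is parenthesized instead
lookup_by_symbol: dict[str, set[int]] = (
    defaultdict(set)
)

Both spellings are equivalent at runtime; the break simply moves to the right-hand side.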

@ -51,7 +51,9 @@ class Interpreter:
method equivalents). We could subclass Interpreter like so:: method equivalents). We could subclass Interpreter like so::
class NegSigmSwapInterpreter(Interpreter): class NegSigmSwapInterpreter(Interpreter):
def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any: def call_function(
self, target: Target, args: Tuple, kwargs: Dict
) -> Any:
if target == torch.sigmoid: if target == torch.sigmoid:
return torch.neg(*args, **kwargs) return torch.neg(*args, **kwargs)
return super().call_function(target, args, kwargs) return super().call_function(target, args, kwargs)
@ -405,7 +407,7 @@ class Interpreter:
for i, atom in enumerate(target_atoms): for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom): if not hasattr(attr_itr, atom):
raise RuntimeError( raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}" f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}"
) )
attr_itr = getattr(attr_itr, atom) attr_itr = getattr(attr_itr, atom)
return attr_itr return attr_itr
@ -468,14 +470,20 @@ class Transformer(Interpreter):
class NegSigmSwapXformer(Transformer): class NegSigmSwapXformer(Transformer):
def call_function( def call_function(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] self,
target: "Target",
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
) -> Any: ) -> Any:
if target == torch.sigmoid: if target == torch.sigmoid:
return torch.neg(*args, **kwargs) return torch.neg(*args, **kwargs)
return super().call_function(target, args, kwargs) return super().call_function(target, args, kwargs)
def call_method( def call_method(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] self,
target: "Target",
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
) -> Any: ) -> Any:
if target == "neg": if target == "neg":
call_self, *args_tail = args call_self, *args_tail = args

@ -514,9 +514,9 @@ class Node(_NodeBase):
idx (int): The index of the element in ``self.args`` to be inserted before. idx (int): The index of the element in ``self.args`` to be inserted before.
arg (Argument): The new argument value to insert into ``args`` arg (Argument): The new argument value to insert into ``args``
""" """
assert ( assert 0 <= idx <= len(self.args), (
0 <= idx <= len(self.args) "insert_args index must be between 0 and len(self.args)"
), "insert_args index must be between 0 and len(self.args)" )
args_left = self.args[:idx] args_left = self.args[:idx]
args_right = self.args[idx:] args_right = self.args[idx:]
@ -747,13 +747,13 @@ class Node(_NodeBase):
# Check if an impure module. # Check if an impure module.
if self.op == "call_module": if self.op == "call_module":
assert ( assert self.graph.owning_module is not None, (
self.graph.owning_module is not None "self.graph.owning_module not set for purity check"
), "self.graph.owning_module not set for purity check" )
target_mod = self.graph.owning_module.get_submodule(self.target) target_mod = self.graph.owning_module.get_submodule(self.target)
assert ( assert target_mod is not None, (
target_mod is not None f"Did not find expected submodule target {self.target}"
), f"Did not find expected submodule target {self.target}" )
return getattr(target_mod, "_is_impure", False) return getattr(target_mod, "_is_impure", False)
return False return False

@ -770,9 +770,9 @@ class _MinimizerBase:
node_name = node.name node_name = node.name
if node_name is not None and isinstance(node_name, tuple): if node_name is not None and isinstance(node_name, tuple):
node_name = node_name[0] node_name = node_name[0]
assert node_name is not None and isinstance( assert node_name is not None and isinstance(node_name, str), (
node_name, str f"minimize: node_name: {node_name}"
), f"minimize: node_name: {node_name}" )
report.append(f"Add node: {node_name}") report.append(f"Add node: {node_name}")

@ -93,9 +93,9 @@ def loop_pass(
predicate (Callable[Object, bool], optional): predicate (Callable[Object, bool], optional):
""" """
assert (n_iter is not None) ^ ( assert (n_iter is not None) ^ (predicate is not None), (
predicate is not None "Exactly one of `n_iter`or `predicate` must be specified."
), "Exactly one of `n_iter`or `predicate` must be specified." )
@wraps(base_pass) @wraps(base_pass)
def new_pass(source): def new_pass(source):

@ -397,7 +397,9 @@ def insert_deferred_runtime_asserts(
nn_module_stack=node.meta.get("nn_module_stack"), nn_module_stack=node.meta.get("nn_module_stack"),
), ),
): ):
expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] expr_to_proxy[sym_expr] = _sympy_interp(
expr_to_proxy, sym_expr
) # type: ignore[arg-type]
# won't try DCE-ing tensor compute here # won't try DCE-ing tensor compute here
hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type] hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type]
node.replace_all_uses_with(hash_node) node.replace_all_uses_with(hash_node)

@ -199,9 +199,9 @@ def split_by_tags(
mx = max((c.order for c in upstream_components), default=0) mx = max((c.order for c in upstream_components), default=0)
# Expect the component for `node` has higher order then its upstream components. # Expect the component for `node` has higher order then its upstream components.
assert ( assert comp.order >= mx, (
comp.order >= mx f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}"
), f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}" )
# Map a input of `node` to nodes in the component's graph. # Map a input of `node` to nodes in the component's graph.
def remap_func(x): def remap_func(x):

@ -36,9 +36,9 @@ def topo_sort(nodes: NodeList) -> NodeList:
if indegree_map[n] == 0: if indegree_map[n] == 0:
candidates.put(n) candidates.put(n)
assert len(nodes) == len( assert len(nodes) == len(sorted_nodes), (
sorted_nodes "topological sorted nodes doesn't have same length as input nodes"
), "topological sorted nodes doesn't have same length as input nodes" )
return sorted_nodes return sorted_nodes
@ -127,13 +127,13 @@ def fuse_as_graphmodule(
# assumption: nodes are already sorted in topo order # assumption: nodes are already sorted in topo order
for node in nodes: for node in nodes:
assert ( assert node.graph.owning_module is gm, (
node.graph.owning_module is gm f"{node} doesn't belong to passed in graph module {gm._get_name()}"
), f"{node} doesn't belong to passed in graph module {gm._get_name()}" )
assert not node._erased, f"{node} has been removed from owning graph" assert not node._erased, f"{node} has been removed from owning graph"
assert ( assert node in gm.graph._find_nodes_lookup_table, (
node in gm.graph._find_nodes_lookup_table f"{node} is not found in graph module {gm._get_name()}"
), f"{node} is not found in graph module {gm._get_name()}" )
# validates partition doesn't introduce dependency circles in the graph # validates partition doesn't introduce dependency circles in the graph
assert validate_partition(nodes), "Invalid partition, found dependency cycles" assert validate_partition(nodes), "Invalid partition, found dependency cycles"

@ -96,9 +96,9 @@ class SubgraphMatcher:
for node in pattern.nodes: for node in pattern.nodes:
if node.op != "output": if node.op != "output":
assert ( assert len(node.users) > 0, (
len(node.users) > 0 "SubgraphMatcher cannot be initialized with an pattern with dead code"
), "SubgraphMatcher cannot be initialized with an pattern with dead code" )
# TODO: assert pattern is a connected graph # TODO: assert pattern is a connected graph
@ -192,9 +192,9 @@ class SubgraphMatcher:
return non_overlapping_matches return non_overlapping_matches
def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool: def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
assert not ( assert not (isinstance(pn, Node) and isinstance(gn, Node)), (
isinstance(pn, Node) and isinstance(gn, Node) "pn and gn cannot both be Node"
), "pn and gn cannot both be Node" )
if isinstance(pn, Node) and not isinstance(gn, Node): if isinstance(pn, Node) and not isinstance(gn, Node):
if pn.op == "placeholder": if pn.op == "placeholder":

@ -18,17 +18,17 @@ def _split_to_graph_and_name_node_map(
if n.op == "output": if n.op == "output":
assert gm._out_spec is not None assert gm._out_spec is not None
output = tree_unflatten(n.args[0], gm._out_spec) output = tree_unflatten(n.args[0], gm._out_spec)
assert isinstance( assert isinstance(output, tuple), (
output, tuple "Expecting the pattern graph to return a tuple"
), "Expecting the pattern graph to return a tuple" )
assert ( assert len(output) >= 2, (
len(output) >= 2 "Expecting the pattern graph to have at least two outputs"
), "Expecting the pattern graph to have at least two outputs" )
*out, name_node_map = output *out, name_node_map = output
flattened, out_spec = tree_flatten(out) flattened, out_spec = tree_flatten(out)
assert isinstance( assert isinstance(name_node_map, dict), (
name_node_map, dict "Expecting the input graph to have a dict output as the last element"
), "Expecting the input graph to have a dict output as the last element" )
n.args = (flattened,) n.args = (flattened,)
orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined] orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined]
gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined] gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined]
@ -53,12 +53,14 @@ class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
relu = F.relu(conv) relu = F.relu(conv)
return relu, {"conv": conv, "relu": relu} return relu, {"conv": conv, "relu": relu}
def target_graph(x, weight): def target_graph(x, weight):
conv = F.conv2d(x, weight) conv = F.conv2d(x, weight)
relu = F.relu(conv) relu = F.relu(conv)
relu *= 2 relu *= 2
return relu return relu
pattern_gm = export_for_training(pattern, example_inputs).module() pattern_gm = export_for_training(pattern, example_inputs).module()
target_gm = export_for_training(target_graph, example_inputs).module() target_gm = export_for_training(target_graph, example_inputs).module()
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)

@ -654,9 +654,9 @@ class MetaProxy(Proxy):
meta_proxy = arg meta_proxy = arg
break break
assert ( assert meta_proxy is not None, (
meta_proxy is not None "No MetaProxy found in arguments, but one is expected."
), "No MetaProxy found in arguments, but one is expected." )
proxy = super().__torch_function__(orig_method, types, args, kwargs) proxy = super().__torch_function__(orig_method, types, args, kwargs)
with meta_proxy.fake_mode: with meta_proxy.fake_mode:
@ -739,14 +739,14 @@ for method in magic_methods:
return tracer.create_proxy("call_function", target, args, kwargs) return tracer.create_proxy("call_function", target, args, kwargs)
impl.__name__ = method impl.__name__ = method
as_magic = f'__{method.strip("_")}__' as_magic = f"__{method.strip('_')}__"
setattr(Proxy, as_magic, impl) setattr(Proxy, as_magic, impl)
_scope(method) _scope(method)
def _define_reflectable(orig_method_name): def _define_reflectable(orig_method_name):
method_name = f'__r{orig_method_name.strip("_")}__' method_name = f"__r{orig_method_name.strip('_')}__"
def impl(self, rhs): def impl(self, rhs):
target = getattr(operator, orig_method_name) target = getattr(operator, orig_method_name)

@ -307,9 +307,9 @@ def _replace_pattern(
elif callable(replacement): elif callable(replacement):
common_replacement_graph = symbolic_trace(replacement).graph common_replacement_graph = symbolic_trace(replacement).graph
else: else:
assert ( assert replacement_callback is not None, (
replacement_callback is not None "Must provide either a replacement GraphModule or a replacement callback"
), "Must provide either a replacement GraphModule or a replacement callback" )
common_replacement_graph = None common_replacement_graph = None
# As we progressively replace nodes, we'll need to keep track of how the match results should change # As we progressively replace nodes, we'll need to keep track of how the match results should change
@ -322,9 +322,9 @@ def _replace_pattern(
match, original_graph, pattern_graph match, original_graph, pattern_graph
) )
else: else:
assert ( assert common_replacement_graph is not None, (
common_replacement_graph is not None "Must provide either a replacement GraphModule or a replacement callback"
), "Must provide either a replacement GraphModule or a replacement callback" )
replacement_graph = common_replacement_graph replacement_graph = common_replacement_graph
replacement_placeholders = [ replacement_placeholders = [
n for n in replacement_graph.nodes if n.op == "placeholder" n for n in replacement_graph.nodes if n.op == "placeholder"

@ -18,7 +18,15 @@ from torch.nn.modules.utils import (
_builtin_table: Optional[dict[int, str]] = None _builtin_table: Optional[dict[int, str]] = None
_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special) # type: ignore[attr-defined] # noqa: B950 _modules_containing_builtins = (
torch,
torch._C._nn,
torch._C._fft, # type: ignore[attr-defined]
torch._C._linalg, # type: ignore[attr-defined]
torch._C._nested, # type: ignore[attr-defined]
torch._C._sparse, # type: ignore[attr-defined]
torch._C._special, # type: ignore[attr-defined]
)
_builtin_ops = [ _builtin_ops = [
# Pairs of (function, op_name) # Pairs of (function, op_name)
@ -94,7 +102,10 @@ _builtin_ops = [
(torch.autograd.grad, "aten::grad"), (torch.autograd.grad, "aten::grad"),
(torch.autograd.backward, "aten::backward"), (torch.autograd.backward, "aten::backward"),
(torch._C._infer_size, "aten::_infer_size"), (torch._C._infer_size, "aten::_infer_size"),
(torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"), # type: ignore[attr-defined] (
torch.nn.functional._no_grad_embedding_renorm_, # type: ignore[attr-defined]
"aten::_no_grad_embedding_renorm_",
),
(torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"), (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"),
(torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"), (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"),
(torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"), (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),
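Exploding the tuple above into one element per line has a practical side effect beyond line length: the type: ignore comments can be attached to the individual attributes that need them rather than one blanket comment (plus a noqa: B950) covering the whole line. A hedged sketch of the same idea using standard-library names:

import math

# old shape (kept as a comment here): one line, one blanket suppression for everything on it
# _FUNCS = (math.sqrt, math.isqrt, math.dist)  # type: ignore[attr-defined]  # noqa: B950

# new shape: one element per line, so a targeted suppression sits on exactly the entry that needs it
_FUNCS = (
    math.sqrt,
    math.isqrt,  # a per-entry "# type: ignore[...]" would go here if this attribute tripped the checker
    math.dist,
)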

@ -4,9 +4,9 @@ from torch._ops import OpOverload, OpOverloadPacket
def _register_decomposition(op: OpOverload, graph: torch._C.Graph): def _register_decomposition(op: OpOverload, graph: torch._C.Graph):
assert not isinstance( assert not isinstance(op, OpOverloadPacket), (
op, OpOverloadPacket f"Must pass specific op overload, not overload packet, found {op}"
), f"Must pass specific op overload, not overload packet, found {op}" )
assert isinstance(op, OpOverload) assert isinstance(op, OpOverload)
torch._C._jit_register_decomposition_for_schema(op._schema, graph) torch._C._jit_register_decomposition_for_schema(op._schema, graph)

@ -23,13 +23,13 @@ def check_decomposition_has_type_annotations(f):
inspect_empty = inspect._empty # type: ignore[attr-defined] inspect_empty = inspect._empty # type: ignore[attr-defined]
sig = inspect.signature(f) sig = inspect.signature(f)
for param in sig.parameters.values(): for param in sig.parameters.values():
assert ( assert param.annotation != inspect_empty, (
param.annotation != inspect_empty f"No signature on param {param.name} for function {f.name}"
), f"No signature on param {param.name} for function {f.name}" )
assert ( assert sig.return_annotation != inspect_empty, (
sig.return_annotation != inspect_empty f"No return annotation for function {f.name}"
), f"No return annotation for function {f.name}" )
def signatures_match(decomposition_sig, torch_op_sig): def signatures_match(decomposition_sig, torch_op_sig):
@ -75,9 +75,9 @@ def register_decomposition(
assert isinstance(aten_op, torch._ops.OpOverload) assert isinstance(aten_op, torch._ops.OpOverload)
# Need unique name for jit function serialization # Need unique name for jit function serialization
assert ( assert f.__name__ not in function_name_set, (
f.__name__ not in function_name_set f"Duplicated function name {f.__name__}"
), f"Duplicated function name {f.__name__}" )
function_name_set.add(f.__name__) function_name_set.add(f.__name__)
scripted_func = torch.jit.script(f) scripted_func = torch.jit.script(f)

@ -588,9 +588,9 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
# recursively scripting them. # recursively scripting them.
for name, sub_concrete_type in concrete_type.get_modules(): for name, sub_concrete_type in concrete_type.get_modules():
orig_value = getattr(nn_module, name) orig_value = getattr(nn_module, name)
assert isinstance( assert isinstance(orig_value, Module), (
orig_value, Module f"Expected Module but got {type(orig_value)}"
), f"Expected Module but got {type(orig_value)}" )
module_type = sub_concrete_type.jit_type module_type = sub_concrete_type.jit_type
if isinstance(module_type, torch._C.InterfaceType): if isinstance(module_type, torch._C.InterfaceType):
# use the interface inference rule to compile the module # use the interface inference rule to compile the module

@ -318,11 +318,11 @@ class ScriptMeta(type):
else: else:
return infer_methods_to_compile(module) return infer_methods_to_compile(module)
self.__dict__[ self.__dict__["_actual_script_module"] = (
"_actual_script_module" torch.jit._recursive.create_script_module(
] = torch.jit._recursive.create_script_module(
self, make_stubs, share_types=not added_methods_in_init self, make_stubs, share_types=not added_methods_in_init
) )
)
# Delete the Python attributes that now shadow the ScriptModule # Delete the Python attributes that now shadow the ScriptModule
# ones, so that __getattr__ and __setattr__ will properly find # ones, so that __getattr__ and __setattr__ will properly find

@ -280,15 +280,15 @@ def max_pool2d(
dilation: list[int], dilation: list[int],
ceil_mode: bool, ceil_mode: bool,
): ):
assert ( assert len(kernel_size) == 1 or len(kernel_size) == 2, (
len(kernel_size) == 1 or len(kernel_size) == 2 "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints" )
kH = kernel_size[0] kH = kernel_size[0]
kW = kH if len(kernel_size) == 1 else kernel_size[1] kW = kH if len(kernel_size) == 1 else kernel_size[1]
assert ( assert len(stride) == 0 or len(stride) == 1 or len(stride) == 2, (
len(stride) == 0 or len(stride) == 1 or len(stride) == 2 "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" )
dH = kH if len(stride) == 0 else stride[0] dH = kH if len(stride) == 0 else stride[0]
if len(stride) == 0: if len(stride) == 0:
dW = kW dW = kW
@ -297,15 +297,15 @@ def max_pool2d(
else: else:
dW = stride[1] dW = stride[1]
assert ( assert len(padding) == 1 or len(padding) == 2, (
len(padding) == 1 or len(padding) == 2 "max_pool2d: padding must either be a single int, or a tuple of two ints"
), "max_pool2d: padding must either be a single int, or a tuple of two ints" )
padH = padding[0] padH = padding[0]
padW = padH if len(padding) == 1 else padding[1] padW = padH if len(padding) == 1 else padding[1]
assert ( assert len(dilation) == 1 or len(dilation) == 2, (
len(dilation) == 1 or len(dilation) == 2 "max_pool2d: dilation must be either a single int, or a tuple of two ints"
), "max_pool2d: dilation must be either a single int, or a tuple of two ints" )
dilationH = dilation[0] dilationH = dilation[0]
dilationW = dilationH if len(dilation) == 1 else dilation[1] dilationW = dilationH if len(dilation) == 1 else dilation[1]
@ -367,17 +367,17 @@ def upsample_nearest2d(
assert 0, "Either output_size or scale_factors must be presented" assert 0, "Either output_size or scale_factors must be presented"
if output_size is not None: if output_size is not None:
assert ( assert scale_factors is None, (
scale_factors is None "Must specify exactly one of output_size and scale_factors"
), "Must specify exactly one of output_size and scale_factors" )
assert len(output_size) == 2 assert len(output_size) == 2
out.append(output_size[0]) out.append(output_size[0])
out.append(output_size[1]) out.append(output_size[1])
if scale_factors is not None: if scale_factors is not None:
assert ( assert output_size is None, (
output_size is None "Must specify exactly one of output_size and scale_factors"
), "Must specify exactly one of output_size and scale_factors" )
assert len(scale_factors) == 2 assert len(scale_factors) == 2
out.append(int(input[2] * scale_factors[0])) out.append(int(input[2] * scale_factors[0]))
out.append(int(input[3] * scale_factors[1])) out.append(int(input[3] * scale_factors[1]))
@ -540,9 +540,9 @@ def check_cat_shape_except_dim(
assert first_dims == second_dims, "Tensors must have same number of dimensions" assert first_dims == second_dims, "Tensors must have same number of dimensions"
for dim in range(0, first_dims): for dim in range(0, first_dims):
if dim != dimension: if dim != dimension:
assert ( assert first[dim] == second[dim], (
first[dim] == second[dim] "Sizes of tensors must match except in dimension"
), "Sizes of tensors must match except in dimension" )
def cat(tensors: list[list[int]], dim: int): def cat(tensors: list[list[int]], dim: int):
@ -1088,9 +1088,9 @@ def topk(self: list[int], k: int, dim: int = -1) -> tuple[list[int], list[int]]:
if len(self) == 0: if len(self) == 0:
result: list[int] = [] result: list[int] = []
else: else:
assert ( assert k <= self[dim], (
k <= self[dim] f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
), f"k ({k}) is too big for dimension {dim} of size {self[dim]}" )
result = _copy(self) result = _copy(self)
result[dim] = k result[dim] = k
return result, result return result, result

@ -1205,7 +1205,10 @@ def trace_module(
# Trace specific methods on a module (specified in `inputs`), constructs # Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {"forward": example_forward_input, "weighted_kernel_sum": example_weight} inputs = {
"forward": example_forward_input,
"weighted_kernel_sum": example_weight,
}
module = torch.jit.trace_module(n, inputs) module = torch.jit.trace_module(n, inputs)
""" """

@ -309,14 +309,14 @@ defined as ``prod(x[:i])``.""",
operation_args, operation_kwargs = args_and_kwargs[func.__name__] operation_args, operation_kwargs = args_and_kwargs[func.__name__]
arg_declarations = [ arg_declarations = [
"\n ".join( "\n ".join(
argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines() argument_declarations.get(a, f"{a.split('__', 1)[0]}: TBD.").splitlines()
) )
for a in operation_args for a in operation_args
] ]
kwarg_declarations = [ kwarg_declarations = [
"\n ".join( "\n ".join(
argument_declarations.get( argument_declarations.get(
a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.' a.split("=", 1)[0], f"{a.split('__', 1)[0]}: TBD."
) )
.format(default=a.split("=", 1)[1]) .format(default=a.split("=", 1)[1])
.splitlines() .splitlines()
@ -745,9 +745,9 @@ def _sparse_csr_segment_reduction_helper(
) -> Tensor: ) -> Tensor:
# Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True
# FIXME: when dense dimensions are implemented for CSR tensors # FIXME: when dense dimensions are implemented for CSR tensors
assert ( assert keepdim, (
keepdim "reduction operations on CSR tensors with keepdim=False is unsupported"
), "reduction operations on CSR tensors with keepdim=False is unsupported" )
reduce = op.__name__ reduce = op.__name__
valid_reductions = ["sum", "prod", "mean", "amax", "amin"] valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
if reduce not in valid_reductions: if reduce not in valid_reductions:
@ -781,9 +781,9 @@ def _sparse_csr_segment_reduction_helper(
) )
new_shape = [1, mask_input.size(1)] new_shape = [1, mask_input.size(1)]
else: else:
assert ( assert dims[0] == 1, (
dims[0] == 1 "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1." )
# all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1 # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
# except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0 # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
new_crow_indices = torch.cat( new_crow_indices = torch.cat(
@ -1598,9 +1598,9 @@ def _std_var(
mask: Optional[Tensor], mask: Optional[Tensor],
take_sqrt: Optional[bool], take_sqrt: Optional[bool],
) -> Tensor: ) -> Tensor:
assert ( assert unbiased is None or correction_opt is None, (
unbiased is None or correction_opt is None "Only one of unbiased and correction may be given"
), "Only one of unbiased and correction may be given" )
correction = 1.0 correction = 1.0
if unbiased is not None: if unbiased is not None:
correction = 1.0 if unbiased else 0.0 correction = 1.0 if unbiased else 0.0
@ -1636,7 +1636,11 @@ def _std_var(
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype) total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
else: else:
total = sum( total = sum(
x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask # type: ignore[possibly-undefined] x * x.conj(),
dim,
keepdim=keepdim,
dtype=compute_dtype,
mask=inmask, # type: ignore[possibly-undefined]
) )
if not keepdim: if not keepdim:
count = count.reshape(total.shape) count = count.reshape(total.shape)

@ -25,7 +25,7 @@ def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]:
>>> # xdoctest: +SKIP >>> # xdoctest: +SKIP
>>> from torch.masked import MaskedTensor >>> from torch.masked import MaskedTensor
>>> data = torch.arange(6).reshape(2,3) >>> data = torch.arange(6).reshape(2, 3)
>>> mask = torch.tensor([[True, False, False], [True, True, False]]) >>> mask = torch.tensor([[True, False, False], [True, True, False]])
>>> mt = MaskedTensor(data, mask) >>> mt = MaskedTensor(data, mask)
>>> is_masked_tensor(mt) >>> is_masked_tensor(mt)

@ -5,6 +5,7 @@ Metal is Apple's API for programming metal GPU (graphics processor unit). Using
performance can be achieved, by running work on the metal GPU(s). performance can be achieved, by running work on the metal GPU(s).
See https://developer.apple.com/documentation/metalperformanceshaders for more details. See https://developer.apple.com/documentation/metalperformanceshaders for more details.
""" """
from typing import Union from typing import Union
import torch import torch

@ -198,7 +198,7 @@ def snapshot() -> dict[str, Any]:
def attach_out_of_memory_observer( def attach_out_of_memory_observer(
observer: Callable[[int, int, int, int], None] observer: Callable[[int, int, int, int], None],
) -> None: ) -> None:
r"""Attach an out-of-memory observer to MTIA memory allocator""" r"""Attach an out-of-memory observer to MTIA memory allocator"""
torch._C._mtia_attachOutOfMemoryObserver(observer) torch._C._mtia_attachOutOfMemoryObserver(observer)

@ -14,6 +14,7 @@ memory.
Because of the similarity of APIs we do not document most of this package Because of the similarity of APIs we do not document most of this package
contents, and we recommend referring to very good docs of the original module. contents, and we recommend referring to very good docs of the original module.
""" """
import multiprocessing import multiprocessing
import sys import sys