[BE][PYFMT] migrate PYFMT for torch/[e-n]*/ to ruff format (#144553)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144553
Approved by: https://github.com/ezyang
ghstack dependencies: #144551
Committed by PyTorch MergeBot
Parent: 95cb42c45d
Commit: 2e0e08588e
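Most hunks below are the same mechanical change: ruff format keeps the asserted condition on the `assert` line and wraps only the message, where black used to parenthesize the condition. A minimal sketch of the pattern, reusing one of the assertions from the diff; the `sample_parameter` stand-in is hypothetical and only there to make the snippet self-contained:

```
import torch

# Hypothetical stand-in so the pattern below runs on its own; in the diff this
# assertion operates on a module parameter inside optimize_for_inference.
sample_parameter = torch.nn.Parameter(torch.zeros(2))

# Before (black): the condition is wrapped in parentheses across lines.
assert (
    sample_parameter.dtype == torch.float
), "this pass is only for torch.float modules"

# After (ruff format): the condition stays inline and only the message is wrapped.
# The two spellings are semantically identical at runtime.
assert sample_parameter.dtype == torch.float, (
    "this pass is only for torch.float modules"
)
```

The same rewrite also collapses multi-line subscript assignments into `x[...] = (...)` form and turns chained `with a(), b():` statements into parenthesized context-manager groups, as the hunks below show.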
@@ -60,7 +60,6 @@ USE_BLACK_FILELIST = re.compile(
"torch/[b-c]*/**",
# torch/d*/**
# torch/[e-m]*/**
"torch/[e-m]*/**",
# torch/optim/**
# torch/[p-z]*/**
"torch/[p-z]*/**",
@@ -358,22 +358,24 @@ def save(
import torch
import io


class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10


ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, 'exported_program.pt2')
torch.export.save(ep, "exported_program.pt2")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
extra_files = {"foo.txt": b"bar".decode("utf-8")}
torch.export.save(ep, "exported_program.pt2", extra_files=extra_files)

"""
if not isinstance(ep, ExportedProgram):
@@ -427,18 +429,18 @@ def load(
import io

# Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2')
ep = torch.export.load("exported_program.pt2")

# Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f:
with open("exported_program.pt2", "rb") as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {'foo.txt': ''} # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
print(extra_files['foo.txt'])
extra_files = {"foo.txt": ""} # values will be replaced with data
ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
print(extra_files["foo.txt"])
print(ep(torch.randn(5)))
"""
if isinstance(f, (str, os.PathLike)):
@@ -572,24 +574,29 @@ def register_dataclass(
import torch
from dataclasses import dataclass


@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int


@dataclass
class OutputDataClass:
    res: torch.Tensor


torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)


class Mod(torch.nn.Module):
    def forward(self, x: InputDataClass) -> OutputDataClass:
        res = x.feature + x.bias
        return OutputDataClass(res=res)

ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), ))

ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),))
print(ep)

"""
@@ -43,7 +43,7 @@ def prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[int, str])
continue

res += f"""
File {str_to_filename[frame['filename']]}, lineno {frame['line']}, in {frame['name']}""" # type: ignore[index]
File {str_to_filename[frame["filename"]]}, lineno {frame["line"]}, in {frame["name"]}""" # type: ignore[index]

res += f"\n {stack[-1]['loc']}"
return res
@@ -327,12 +327,12 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
# We don't want to log all expression_created logs, only
# the ones that are relevant to the
# guards/propagate_real_tensor
self.expression_created_logs[
metadata[key]["result_id"]
] = ExpressionCreatedNode(
metadata[key]["result_id"],
metadata[key].get("argument_ids", []),
record,
self.expression_created_logs[metadata[key]["result_id"]] = (
ExpressionCreatedNode(
metadata[key]["result_id"],
metadata[key].get("argument_ids", []),
record,
)
)
return
@@ -374,10 +374,13 @@ def draft_export(

capture_structured_log = CaptureStructuredTrace()

with torch._functorch.config.patch(
fake_tensor_propagate_real_tensors=True,
generate_fake_kernels_from_real_mismatches=True,
), capture_structured_log:
with (
torch._functorch.config.patch(
fake_tensor_propagate_real_tensors=True,
generate_fake_kernels_from_real_mismatches=True,
),
capture_structured_log,
):
try:
new_shapes = None
ep = _export(
@@ -424,10 +427,10 @@ def draft_export(
continue

elif log_name == "propagate_real_tensors_provenance":
log_contents[
"occurrences"
] = capture_structured_log.log_record.get_log_count(
(log_name, log_contents)
log_contents["occurrences"] = (
capture_structured_log.log_record.get_log_count(
(log_name, log_contents)
)
)

failure_type = FailureType.DATA_DEPENDENT_ERROR
@@ -26,9 +26,9 @@ def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]:
if user.op == "output":
continue

assert (
user.op == "call_function" and user.target == operator.getitem
), f"Expected getitem node as user for {node}, instead got {user}"
assert user.op == "call_function" and user.target == operator.getitem, (
f"Expected getitem node as user for {node}, instead got {user}"
)
getitem_users.update(list(user.users.keys()))
return getitem_users
@@ -63,9 +63,9 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
log.debug("Trying to remove pytrees for module call %s", curr_module_node)

curr_module_users = list(curr_module_node.users.keys())
assert (
len(curr_module_users) == 1
), f"Expected only one user for module node, instead got {list(curr_module_users)}"
assert len(curr_module_users) == 1, (
f"Expected only one user for module node, instead got {list(curr_module_users)}"
)
flatten_node = curr_module_users[0]
assert (
flatten_node.op == "call_function"
@@ -268,9 +268,9 @@ def _extract_fake_inputs(gm, args, kwargs):

if detected_fake_mode:
if detected_shape_env:
assert (
detected_shape_env is detected_fake_mode.shape_env
), "Detected shape env does not match fake mode's shape env"
assert detected_shape_env is detected_fake_mode.shape_env, (
"Detected shape env does not match fake mode's shape env"
)
fake_mode = detected_fake_mode
elif detected_shape_env:
fake_mode = FakeTensorMode(shape_env=detected_shape_env, export=True)
@@ -864,13 +864,19 @@ def _export_to_aten_ir(
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
with torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context(), custom_triton_ops_decomposition_ctx(): # type: ignore[attr-defined]
with (
torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
),
grad_safe_guard,
_ignore_backend_decomps(),
_compiling_state_context(),
custom_triton_ops_decomposition_ctx(),
):
gm, graph_signature = transform(aot_export_module)(
mod,
fake_args,
@ -1229,9 +1235,9 @@ def _get_module_call_graph(
|
||||
"""
|
||||
gm: torch.fx.GraphModule = export_artifact.aten.gm
|
||||
export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
|
||||
module_call_specs: dict[
|
||||
str, dict[str, TreeSpec]
|
||||
] = export_artifact.module_call_specs
|
||||
module_call_specs: dict[str, dict[str, TreeSpec]] = (
|
||||
export_artifact.module_call_specs
|
||||
)
|
||||
in_spec: TreeSpec = export_artifact.in_spec
|
||||
out_spec: TreeSpec = export_artifact.out_spec
|
||||
|
||||
@ -1365,7 +1371,8 @@ def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
|
||||
).module()
|
||||
|
||||
elif isinstance(traced_callable, torch.ScriptMethod) and isinstance(
|
||||
traced_callable.owner(), (torch._C.ScriptModule, torch.nn.Module) # type: ignore[operator]
|
||||
traced_callable.owner(), # type: ignore[operator]
|
||||
(torch._C.ScriptModule, torch.nn.Module),
|
||||
):
|
||||
with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator]
|
||||
return _export(
|
||||
@ -1430,9 +1437,9 @@ def _strict_export(
|
||||
attr = getattr(gm_torch_level, node.target)
|
||||
# Checks if it is not a HigherOrderOp branch or a module
|
||||
if not isinstance(attr, torch.nn.Module):
|
||||
assert (
|
||||
dynamo_fake_mode is not None
|
||||
), "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
|
||||
assert dynamo_fake_mode is not None, (
|
||||
"Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
|
||||
)
|
||||
node.meta["val"] = dynamo_fake_mode.from_tensor(
|
||||
attr, static_shapes=True
|
||||
)
|
||||
@ -1749,13 +1756,17 @@ def _export_to_aten_ir_make_fx(
|
||||
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
|
||||
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
|
||||
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
|
||||
with torch.nn.utils.stateless._reparametrize_module(
|
||||
mod,
|
||||
fake_params_buffers,
|
||||
tie_weights=True,
|
||||
strict=True,
|
||||
stack_weights=True,
|
||||
), _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined]
|
||||
with (
|
||||
torch.nn.utils.stateless._reparametrize_module(
|
||||
mod,
|
||||
fake_params_buffers,
|
||||
tie_weights=True,
|
||||
strict=True,
|
||||
stack_weights=True,
|
||||
),
|
||||
_ignore_backend_decomps(),
|
||||
_compiling_state_context(),
|
||||
):
|
||||
gm, graph_signature = transform(_make_fx_helper)(
|
||||
mod,
|
||||
fake_args,
|
||||
@ -1944,22 +1955,27 @@ def _non_strict_export(
|
||||
# We also need to attach dynamo configs as these will be used in HOOs that
|
||||
# use torch.compile, like cond
|
||||
dynamo_config = dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
|
||||
dynamo_config[
|
||||
"do_not_emit_runtime_asserts"
|
||||
] = False # We want to emit runtime asserts
|
||||
dynamo_config["do_not_emit_runtime_asserts"] = (
|
||||
False # We want to emit runtime asserts
|
||||
)
|
||||
|
||||
with fake_mode, _NonStrictTorchFunctionHandler(), tracing(
|
||||
tx
|
||||
), torch._dynamo.config.patch(dynamo_config):
|
||||
with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
|
||||
patched_mod,
|
||||
new_fake_args,
|
||||
new_fake_kwargs,
|
||||
new_fake_constant_attrs,
|
||||
map_fake_to_real,
|
||||
), _fakify_module_inputs(
|
||||
fake_args, fake_kwargs, fake_mode
|
||||
), _override_builtin_ops():
|
||||
with (
|
||||
fake_mode,
|
||||
_NonStrictTorchFunctionHandler(),
|
||||
tracing(tx),
|
||||
torch._dynamo.config.patch(dynamo_config),
|
||||
):
|
||||
with (
|
||||
_fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
|
||||
patched_mod,
|
||||
new_fake_args,
|
||||
new_fake_kwargs,
|
||||
new_fake_constant_attrs,
|
||||
map_fake_to_real,
|
||||
),
|
||||
_fakify_module_inputs(fake_args, fake_kwargs, fake_mode),
|
||||
_override_builtin_ops(),
|
||||
):
|
||||
aten_export_artifact = _to_aten_func( # type: ignore[operator]
|
||||
patched_mod,
|
||||
new_fake_args,
|
||||
|
@ -666,7 +666,7 @@ class ShapesCollection:
|
||||
|
||||
Example::
|
||||
|
||||
args = ({"x": tensor_x, "others": [tensor_y, tensor_z]})
|
||||
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}
|
||||
|
||||
dim = torch.export.Dim(...)
|
||||
dynamic_shapes = torch.export.ShapesCollection()
|
||||
@ -682,7 +682,7 @@ class ShapesCollection:
|
||||
|
||||
Example::
|
||||
|
||||
args = ({"x": tensor_x, "others": [int_x, int_y]})
|
||||
args = {"x": tensor_x, "others": [int_x, int_y]}
|
||||
# Wrap all ints with _IntWrapper
|
||||
mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args)
|
||||
|
||||
@ -700,18 +700,18 @@ class ShapesCollection:
|
||||
self._shapes = {}
|
||||
|
||||
def __setitem__(self, t, shape):
|
||||
assert isinstance(
|
||||
t, (torch.Tensor, _IntWrapper)
|
||||
), f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}"
|
||||
assert isinstance(t, (torch.Tensor, _IntWrapper)), (
|
||||
f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}"
|
||||
)
|
||||
|
||||
# TODO(avik): check that shape is indeed a Shape
|
||||
|
||||
t_id = id(t)
|
||||
if t_id in self._shapes:
|
||||
_shape = self._shapes[t_id]
|
||||
assert (
|
||||
shape == _shape
|
||||
), f"Shapes assigned to input do not match: expected {_shape}, got {shape}"
|
||||
assert shape == _shape, (
|
||||
f"Shapes assigned to input do not match: expected {_shape}, got {shape}"
|
||||
)
|
||||
else:
|
||||
self._shapes[id(t)] = shape
|
||||
|
||||
@ -766,7 +766,7 @@ class AdditionalInputs:
|
||||
|
||||
Example::
|
||||
|
||||
args0, kwargs0 = ... # example inputs for export
|
||||
args0, kwargs0 = ... # example inputs for export
|
||||
|
||||
# other representative inputs that the exported program will run on
|
||||
dynamic_shapes = torch.export.AdditionalInputs()
|
||||
@ -786,9 +786,9 @@ class AdditionalInputs:
|
||||
"""
|
||||
|
||||
assert type(args) is tuple, f"Representative args {args} must be a tuple"
|
||||
assert (
|
||||
kwargs is None or type(kwargs) is dict
|
||||
), f"Representative kwargs {kwargs} must be None or a dict"
|
||||
assert kwargs is None or type(kwargs) is dict, (
|
||||
f"Representative kwargs {kwargs} must be None or a dict"
|
||||
)
|
||||
self._examples.append((args, kwargs))
|
||||
|
||||
def dynamic_shapes(self, m, args, kwargs=None):
|
||||
@ -1075,7 +1075,8 @@ def _process_dynamic_shapes(
|
||||
i,
|
||||
dim.__name__,
|
||||
StrictMinMaxConstraint(
|
||||
vr=ValueRanges(lower=dim.value, upper=dim.value), warn_only=False # type: ignore[attr-defined]
|
||||
vr=ValueRanges(lower=dim.value, upper=dim.value), # type: ignore[attr-defined]
|
||||
warn_only=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@ -1085,7 +1086,8 @@ def _process_dynamic_shapes(
|
||||
i,
|
||||
dim.__name__,
|
||||
StrictMinMaxConstraint(
|
||||
vr=ValueRanges(lower=dim.min, upper=dim.max), warn_only=False # type: ignore[attr-defined]
|
||||
vr=ValueRanges(lower=dim.min, upper=dim.max), # type: ignore[attr-defined]
|
||||
warn_only=False,
|
||||
),
|
||||
)
|
||||
return constraint
|
||||
@ -1161,7 +1163,7 @@ def _process_dynamic_shapes(
|
||||
|
||||
|
||||
def _get_dim_name_mapping(
|
||||
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None]
|
||||
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
|
||||
):
|
||||
name_to_dim = {}
|
||||
for dim in tree_flatten(
|
||||
|
@ -137,16 +137,11 @@ class _ExportPackage:
|
||||
"decoder": ExportMethod(
|
||||
overloads={
|
||||
"prefill": ExportedProgram(...),
|
||||
"decode": ExportedProgram(...)
|
||||
"decode": ExportedProgram(...),
|
||||
},
|
||||
fallbacks=[]
|
||||
fallbacks=[],
|
||||
),
|
||||
"encoder": ExportMethod(
|
||||
overloads={},
|
||||
fallbacks=[
|
||||
ExportedProgram(...)
|
||||
]
|
||||
)
|
||||
"encoder": ExportMethod(overloads={}, fallbacks=[ExportedProgram(...)]),
|
||||
},
|
||||
)
|
||||
```
|
||||
@ -212,15 +207,18 @@ class _ExportPackage:
|
||||
```
|
||||
package = ExportPackage()
|
||||
|
||||
|
||||
def prefill(x, xa, kv_cache):
|
||||
assert x.shape[1] == 3
|
||||
assert kv_cache == {}
|
||||
|
||||
|
||||
def decode(x, xa, kv_cache):
|
||||
assert x.shape[1] > 1
|
||||
assert len(kv_cache) > 0
|
||||
return {...} # dynamic shape specs here
|
||||
|
||||
|
||||
exporter = (
|
||||
package.exporter(decoder)
|
||||
.define_overload("prefill", prefill)
|
||||
|
@ -272,7 +272,7 @@ def _override_composite_implicit_decomp(cia_ops_to_callable):
|
||||
|
||||
|
||||
def _split_decomp_table_to_cia_and_python_decomp(
|
||||
decomp_table: dict[torch._ops.OperatorBase, Callable]
|
||||
decomp_table: dict[torch._ops.OperatorBase, Callable],
|
||||
) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]:
|
||||
all_preservable_cia_ops = set(_collect_all_valid_cia_ops())
|
||||
cia_ops_to_callable = {}
|
||||
@ -443,9 +443,14 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
||||
|
||||
tx = TracingContext(fake_mode)
|
||||
|
||||
with fake_mode, _override_composite_implicit_decomp(
|
||||
cia_to_decomp,
|
||||
), _enable_graph_inputs_of_type_nn_module(ep.example_inputs), tracing(tx):
|
||||
with (
|
||||
fake_mode,
|
||||
_override_composite_implicit_decomp(
|
||||
cia_to_decomp,
|
||||
),
|
||||
_enable_graph_inputs_of_type_nn_module(ep.example_inputs),
|
||||
tracing(tx),
|
||||
):
|
||||
retracing_args_unwrapped = pytree.tree_unflatten(
|
||||
retracing_args, mod._in_spec
|
||||
)
|
||||
@ -573,9 +578,12 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
||||
if decompose_custom_triton_ops
|
||||
else _disable_custom_triton_op_functional_decomposition
|
||||
)
|
||||
with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp(
|
||||
cia_to_decomp
|
||||
), custom_triton_ops_decomposition_ctx():
|
||||
with (
|
||||
_ignore_backend_decomps(),
|
||||
fake_mode,
|
||||
_override_composite_implicit_decomp(cia_to_decomp),
|
||||
custom_triton_ops_decomposition_ctx(),
|
||||
):
|
||||
gm, graph_signature = aot_export_module(
|
||||
ep.graph_module,
|
||||
fake_args,
|
||||
@ -1514,9 +1522,9 @@ class ExportedProgram:
|
||||
if node.op != "placeholder":
|
||||
break
|
||||
|
||||
assert i < len(
|
||||
old_signature.input_specs
|
||||
), "Number of inputs changed after transformation"
|
||||
assert i < len(old_signature.input_specs), (
|
||||
"Number of inputs changed after transformation"
|
||||
)
|
||||
old_input_spec = old_signature.input_specs[i]
|
||||
arg = (
|
||||
old_input_spec.arg
|
||||
@ -1539,9 +1547,9 @@ class ExportedProgram:
|
||||
|
||||
new_output_specs = []
|
||||
for i, node in enumerate(output_node.args[0]):
|
||||
assert i < len(
|
||||
old_signature.output_specs
|
||||
), "Number of outputs changed after transformation"
|
||||
assert i < len(old_signature.output_specs), (
|
||||
"Number of outputs changed after transformation"
|
||||
)
|
||||
old_output_spec = old_signature.output_specs[i]
|
||||
arg = (
|
||||
old_output_spec.arg
|
||||
@ -1599,9 +1607,9 @@ class ExportedProgram:
|
||||
# TODO: remove this
|
||||
@final
|
||||
def _validate(self):
|
||||
assert (
|
||||
len(self.verifiers) > 0
|
||||
), "ExportedProgram must have at least one verifier."
|
||||
assert len(self.verifiers) > 0, (
|
||||
"ExportedProgram must have at least one verifier."
|
||||
)
|
||||
for v in self.verifiers:
|
||||
v().check(self)
|
||||
|
||||
|
@ -95,9 +95,9 @@ class InputSpec:
|
||||
|
||||
def __post_init__(self):
|
||||
if self.kind == InputKind.BUFFER:
|
||||
assert (
|
||||
self.persistent is not None
|
||||
), "Failed to specify persistent flag on BUFFER."
|
||||
assert self.persistent is not None, (
|
||||
"Failed to specify persistent flag on BUFFER."
|
||||
)
|
||||
assert isinstance(
|
||||
self.arg,
|
||||
(
|
||||
@ -187,15 +187,17 @@ class ExportGraphSignature:
|
||||
self.my_parameter = nn.Parameter(torch.tensor(2.0))
|
||||
|
||||
# Define two buffers
|
||||
self.register_buffer('my_buffer1', torch.tensor(3.0))
|
||||
self.register_buffer('my_buffer2', torch.tensor(4.0))
|
||||
self.register_buffer("my_buffer1", torch.tensor(3.0))
|
||||
self.register_buffer("my_buffer2", torch.tensor(4.0))
|
||||
|
||||
def forward(self, x1, x2):
|
||||
# Use the parameter, buffers, and both inputs in the forward method
|
||||
output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2
|
||||
output = (
|
||||
x1 + self.my_parameter
|
||||
) * self.my_buffer1 + x2 * self.my_buffer2
|
||||
|
||||
# Mutate one of the buffers (e.g., increment it by 1)
|
||||
self.my_buffer2.add_(1.0) # In-place addition
|
||||
self.my_buffer2.add_(1.0) # In-place addition
|
||||
|
||||
return output
|
||||
|
||||
@ -520,9 +522,9 @@ def _make_argument_spec(node, token_names) -> ArgumentSpec:
|
||||
# For const outputs we just directly return this
|
||||
return ConstantArgument(name="", value=node)
|
||||
|
||||
assert (
|
||||
"val" in node.meta
|
||||
), f"{node} is not a constant or a node with a 'val' metadata field"
|
||||
assert "val" in node.meta, (
|
||||
f"{node} is not a constant or a node with a 'val' metadata field"
|
||||
)
|
||||
val = node.meta["val"]
|
||||
if node.name in token_names:
|
||||
return TokenArgument(name=node.name)
|
||||
@ -565,9 +567,21 @@ def _convert_to_export_graph_signature(
|
||||
user_outputs = set(graph_signature.user_outputs)
|
||||
buffer_mutations = graph_signature.buffers_to_mutate
|
||||
user_input_mutations = graph_signature.user_inputs_to_mutate
|
||||
grad_params = graph_signature.backward_signature.gradients_to_parameter if is_joint else {} # type: ignore[union-attr]
|
||||
grad_user_inputs = graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {} # type: ignore[union-attr]
|
||||
loss_output = graph_signature.backward_signature.loss_output if is_joint else None # type: ignore[union-attr]
|
||||
grad_params = (
|
||||
graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr]
|
||||
if is_joint
|
||||
else {}
|
||||
)
|
||||
grad_user_inputs = (
|
||||
graph_signature.backward_signature.gradients_to_user_inputs # type: ignore[union-attr]
|
||||
if is_joint
|
||||
else {}
|
||||
)
|
||||
loss_output = (
|
||||
graph_signature.backward_signature.loss_output # type: ignore[union-attr]
|
||||
if is_joint
|
||||
else None
|
||||
)
|
||||
input_tokens = graph_signature.input_tokens
|
||||
output_tokens = graph_signature.output_tokens
|
||||
|
||||
|
@ -155,9 +155,9 @@ class PT2ArchiveReader:
|
||||
|
||||
def __init__(self, archive_path_or_buffer: FileLike):
|
||||
self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type]
|
||||
assert (
|
||||
self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE
|
||||
), "Invalid archive format"
|
||||
assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, (
|
||||
"Invalid archive format"
|
||||
)
|
||||
|
||||
def __enter__(self) -> "PT2ArchiveReader":
|
||||
return self
|
||||
|
@ -104,9 +104,9 @@ def _assign_attr(
|
||||
assert isinstance(from_obj, torch.Tensor)
|
||||
to_module.register_buffer(field, from_obj, persistent=persistent)
|
||||
elif attr_kind == _AttrKind.CONSTANT:
|
||||
assert not isinstance(
|
||||
from_obj, FakeScriptObject
|
||||
), "FakeScriptObject should only exist during tracing."
|
||||
assert not isinstance(from_obj, FakeScriptObject), (
|
||||
"FakeScriptObject should only exist during tracing."
|
||||
)
|
||||
assert isinstance(
|
||||
from_obj,
|
||||
(
|
||||
@ -461,9 +461,9 @@ class UnflattenedModule(torch.nn.Module):
|
||||
# add constants that are aliased and don't appear in graph signature
|
||||
for const_name, const in export_module.constants.items():
|
||||
if const_name not in consts_targets:
|
||||
assert (
|
||||
id(const) in consts_map
|
||||
), "Constants should be either aliased or appear in graph signature"
|
||||
assert id(const) in consts_map, (
|
||||
"Constants should be either aliased or appear in graph signature"
|
||||
)
|
||||
ph_name, _ = consts_map[id(const)][0]
|
||||
add_to_consts_map(id(const), ph_name, const_name)
|
||||
added_params_buffers.add(s.target)
|
||||
@ -1041,9 +1041,9 @@ class _ModuleFrame:
|
||||
|
||||
if arg.name in self.seen_nodes:
|
||||
flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
|
||||
self.node_to_placeholder[
|
||||
self.seen_nodes[arg.name]
|
||||
] = flat_arg_node
|
||||
self.node_to_placeholder[self.seen_nodes[arg.name]] = (
|
||||
flat_arg_node
|
||||
)
|
||||
|
||||
with self.parent.graph.inserting_before(self.parent_call_module):
|
||||
input_nodes: list[Optional[torch.fx.Node]] = []
|
||||
@ -1125,8 +1125,7 @@ class _ModuleFrame:
|
||||
if x in self.node_to_placeholder:
|
||||
return self.node_to_placeholder[x]
|
||||
elif (
|
||||
x.op == "placeholder"
|
||||
or self.module_call_graph.get(self.fqn) is None
|
||||
x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None
|
||||
# allow placeholder creation if we are not preserving module call signature
|
||||
):
|
||||
self.add_placeholder(x)
|
||||
|
@ -82,9 +82,7 @@ Example:
|
||||
>>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])
|
||||
>>> torch.fft.fft(t)
|
||||
tensor([12.+16.j, -8.+0.j, -4.-4.j, 0.-8.j])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
ifft = _add_docstr(
|
||||
@ -125,9 +123,7 @@ Example:
|
||||
>>> t = torch.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
|
||||
>>> torch.fft.ifft(t)
|
||||
tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
fft2 = _add_docstr(
|
||||
@ -188,9 +184,7 @@ Example:
|
||||
>>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
|
||||
>>> torch.testing.assert_close(fft2, two_ffts, check_stride=False)
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
ifft2 = _add_docstr(
|
||||
@ -243,9 +237,7 @@ Example:
|
||||
>>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
|
||||
>>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False)
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
fftn = _add_docstr(
|
||||
@ -305,9 +297,7 @@ Example:
|
||||
>>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
|
||||
>>> torch.testing.assert_close(fftn, two_ffts, check_stride=False)
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
ifftn = _add_docstr(
|
||||
@ -359,9 +349,7 @@ Example:
|
||||
>>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
|
||||
>>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False)
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
rfft = _add_docstr(
|
||||
@ -417,9 +405,7 @@ Example:
|
||||
Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted.
|
||||
At the Nyquist frequency ``T[-2] == T[2]`` is it's own symmetric pair,
|
||||
and therefore must always be real-valued.
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
irfft = _add_docstr(
|
||||
@ -496,9 +482,7 @@ Example:
|
||||
>>> roundtrip = torch.fft.irfft(T, t.numel())
|
||||
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
rfft2 = _add_docstr(
|
||||
@ -565,9 +549,7 @@ Example:
|
||||
>>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
|
||||
>>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False)
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
irfft2 = _add_docstr(
|
||||
@ -649,9 +631,7 @@ Example:
|
||||
torch.Size([10, 9])
|
||||
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
rfftn = _add_docstr(
|
||||
@ -718,9 +698,7 @@ Example:
|
||||
>>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
|
||||
>>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False)
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
irfftn = _add_docstr(
|
||||
@ -801,9 +779,7 @@ Example:
|
||||
torch.Size([10, 9])
|
||||
>>> torch.testing.assert_close(roundtrip, t, check_stride=False)
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
hfft = _add_docstr(
|
||||
@ -894,9 +870,7 @@ Example:
|
||||
|
||||
>>> torch.fft.hfft(T[:3])
|
||||
tensor([0.1250, 0.2809, 0.6250, 0.9691])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
ihfft = _add_docstr(
|
||||
@ -951,9 +925,7 @@ Example:
|
||||
>>> torch.fft.ifft(t)
|
||||
tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j,
|
||||
-0.5000+0.6882j])
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
hfft2 = _add_docstr(
|
||||
@ -1025,9 +997,7 @@ Example:
|
||||
>>> torch.allclose(roundtrip, T)
|
||||
True
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
ihfft2 = _add_docstr(
|
||||
@ -1092,9 +1062,7 @@ Example:
|
||||
>>> torch.allclose(t, two_ffts)
|
||||
True
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
hfftn = _add_docstr(
|
||||
@ -1187,9 +1155,7 @@ Example:
|
||||
>>> torch.allclose(roundtrip, T)
|
||||
True
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
ihfftn = _add_docstr(
|
||||
@ -1259,9 +1225,7 @@ Example:
|
||||
>>> torch.allclose(ihfftn, two_iffts)
|
||||
True
|
||||
|
||||
""".format(
|
||||
**common_args
|
||||
),
|
||||
""".format(**common_args),
|
||||
)
|
||||
|
||||
fftfreq = _add_docstr(
|
||||
@ -1310,9 +1274,7 @@ Example:
|
||||
>>> torch.fft.fftfreq(4)
|
||||
tensor([ 0.0000, 0.2500, -0.5000, -0.2500])
|
||||
|
||||
""".format(
|
||||
**factory_common_args
|
||||
),
|
||||
""".format(**factory_common_args),
|
||||
)
|
||||
|
||||
rfftfreq = _add_docstr(
|
||||
@ -1361,9 +1323,7 @@ Example:
|
||||
>>> torch.fft.fftfreq(4)
|
||||
tensor([ 0.0000, 0.2500, -0.5000, -0.2500])
|
||||
|
||||
""".format(
|
||||
**factory_common_args
|
||||
),
|
||||
""".format(**factory_common_args),
|
||||
)
|
||||
|
||||
fftshift = _add_docstr(
|
||||
|
@ -271,9 +271,9 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
|
||||
...
|
||||
ValueError: foo
|
||||
"""
|
||||
assert isinstance(
|
||||
result, Exception
|
||||
), f"{result} is of type {type(result)}, not an Exception."
|
||||
assert isinstance(result, Exception), (
|
||||
f"{result} is of type {type(result)}, not an Exception."
|
||||
)
|
||||
|
||||
def raise_error(fut_result):
|
||||
raise fut_result
|
||||
|
@ -253,9 +253,9 @@ class _TensorPickleData:
|
||||
for k in MetaTensorDesc._UNSERIALIZABLE:
|
||||
if k in ("fake_mode", "view_func"):
|
||||
continue
|
||||
assert (
|
||||
getattr(self.metadata, k) is None
|
||||
), f"not None: {k}: {getattr(self.metadata, k)}"
|
||||
assert getattr(self.metadata, k) is None, (
|
||||
f"not None: {k}: {getattr(self.metadata, k)}"
|
||||
)
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor:
|
||||
# TODO: make common w/ _output_from_cache_entry() in fake_tensor.py?
|
||||
|
@ -755,9 +755,9 @@ class Tracer(TracerBase):
|
||||
|
||||
self.root = root
|
||||
|
||||
assert hasattr(
|
||||
type(root), self.traced_func_name
|
||||
), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
|
||||
assert hasattr(type(root), self.traced_func_name), (
|
||||
f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
|
||||
)
|
||||
|
||||
fn = getattr(type(root), self.traced_func_name)
|
||||
self.root_module_name = root._get_name()
|
||||
@ -1164,9 +1164,9 @@ def _maybe_revert_all_patches():
|
||||
finally:
|
||||
if current_patcher is not None:
|
||||
patches_made = current_patcher.reapply_all_patches()
|
||||
assert (
|
||||
patches_made == patches_removed
|
||||
), "CURRENT_PATCHER was changed during a revert_all_patches"
|
||||
assert patches_made == patches_removed, (
|
||||
"CURRENT_PATCHER was changed during a revert_all_patches"
|
||||
)
|
||||
|
||||
|
||||
def _patch_wrapped_functions(patcher: _Patcher):
|
||||
@ -1248,9 +1248,9 @@ def wrap(fn_or_name: Union[str, Callable]):
|
||||
assert not isinstance(fn_or_name, str) # to make mypy happy
|
||||
fn_name = fn_or_name.__name__
|
||||
else:
|
||||
assert isinstance(
|
||||
fn_or_name, str
|
||||
), "fn_or_name must be a global function or string name"
|
||||
assert isinstance(fn_or_name, str), (
|
||||
"fn_or_name must be a global function or string name"
|
||||
)
|
||||
fn_name = fn_or_name
|
||||
|
||||
currentframe = inspect.currentframe()
|
||||
@ -1308,7 +1308,9 @@ def symbolic_trace(
|
||||
return out
|
||||
|
||||
|
||||
f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
|
||||
f = fx.symbolic_trace(
|
||||
f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}
|
||||
)
|
||||
assert f({"a": 1, "b": 2, "c": 4}) == 7
|
||||
|
||||
|
||||
|
@ -450,9 +450,9 @@ class Partitioner:
|
||||
device = find_device_based_on_size(node)
|
||||
occupied_devices.append(device)
|
||||
# Update partition and its left mem size
|
||||
partition_to_left_mem_bytes[
|
||||
partition
|
||||
] = device.available_mem_bytes
|
||||
partition_to_left_mem_bytes[partition] = (
|
||||
device.available_mem_bytes
|
||||
)
|
||||
# Update available mem for the current partition
|
||||
partition.logical_device_ids.append(device.logical_id)
|
||||
else:
|
||||
@ -475,9 +475,9 @@ class Partitioner:
|
||||
total_size_of_input_nodes = get_extra_size_of(
|
||||
node, partition.nodes
|
||||
)
|
||||
partition_to_left_mem_bytes[
|
||||
partition
|
||||
] = device.available_mem_bytes
|
||||
partition_to_left_mem_bytes[partition] = (
|
||||
device.available_mem_bytes
|
||||
)
|
||||
partition.logical_device_ids.append(device.logical_id)
|
||||
partition.add_node(node)
|
||||
partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes
|
||||
@ -509,9 +509,9 @@ class Partitioner:
|
||||
no_device_partitions,
|
||||
) = get_device_partition_stats(self.partitions, self.devices)
|
||||
|
||||
assert (
|
||||
len(no_device_partitions) == 0
|
||||
), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}"
|
||||
assert len(no_device_partitions) == 0, (
|
||||
f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}"
|
||||
)
|
||||
|
||||
# Devices that hold partitions
|
||||
used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]
|
||||
|
@ -368,12 +368,12 @@ def optimize_for_inference(
|
||||
supports_mkldnn = MklSupport.YES
|
||||
sample_parameter = next(cur_module.parameters(), None)
|
||||
if sample_parameter is not None:
|
||||
assert (
|
||||
sample_parameter.dtype == torch.float
|
||||
), "this pass is only for torch.float modules"
|
||||
assert sample_parameter.device == torch.device(
|
||||
"cpu"
|
||||
), "this pass is only for CPU modules"
|
||||
assert sample_parameter.dtype == torch.float, (
|
||||
"this pass is only for torch.float modules"
|
||||
)
|
||||
assert sample_parameter.device == torch.device("cpu"), (
|
||||
"this pass is only for CPU modules"
|
||||
)
|
||||
elif node.op == "call_function":
|
||||
if node.target in mkldnn_supported:
|
||||
supports_mkldnn = MklSupport.YES
|
||||
|
@ -182,22 +182,19 @@ def is_sym_node(node: _HasMeta) -> bool:
|
||||
|
||||
|
||||
@overload
|
||||
def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None:
|
||||
...
|
||||
def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ...
|
||||
|
||||
|
||||
@overload
|
||||
def set_proxy_slot(
|
||||
obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
|
||||
@overload
|
||||
def set_proxy_slot(
|
||||
obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
|
||||
def set_proxy_slot(
|
||||
@ -256,8 +253,7 @@ _PySymProxyType = Thunk[Proxy]
|
||||
def get_proxy_slot(
|
||||
obj: Tensor,
|
||||
tracer: _ProxyTracer,
|
||||
) -> _ProxyTensor:
|
||||
...
|
||||
) -> _ProxyTensor: ...
|
||||
|
||||
|
||||
@overload
|
||||
@ -265,8 +261,7 @@ def get_proxy_slot(
|
||||
obj: Tensor,
|
||||
tracer: _ProxyTracer,
|
||||
default: U,
|
||||
) -> Union[_ProxyTensor, U]:
|
||||
...
|
||||
) -> Union[_ProxyTensor, U]: ...
|
||||
|
||||
|
||||
@overload
|
||||
@ -275,16 +270,14 @@ def get_proxy_slot(
|
||||
tracer: _ProxyTracer,
|
||||
default: U,
|
||||
transform: Callable[[_ProxyTensor], R],
|
||||
) -> Union[R, U]:
|
||||
...
|
||||
) -> Union[R, U]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_proxy_slot(
|
||||
obj: _AnyScriptObjectType,
|
||||
tracer: _ProxyTracer,
|
||||
) -> Proxy:
|
||||
...
|
||||
) -> Proxy: ...
|
||||
|
||||
|
||||
@overload
|
||||
@ -292,8 +285,7 @@ def get_proxy_slot(
|
||||
obj: _AnyScriptObjectType,
|
||||
tracer: _ProxyTracer,
|
||||
default: U,
|
||||
) -> Union[Proxy, U]:
|
||||
...
|
||||
) -> Union[Proxy, U]: ...
|
||||
|
||||
|
||||
@overload
|
||||
@ -302,16 +294,14 @@ def get_proxy_slot(
|
||||
tracer: _ProxyTracer,
|
||||
default: U,
|
||||
transform: Callable[[Proxy], R],
|
||||
) -> Union[R, U]:
|
||||
...
|
||||
) -> Union[R, U]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_proxy_slot(
|
||||
obj: PySymType,
|
||||
tracer: _ProxyTracer,
|
||||
) -> _PySymProxyType:
|
||||
...
|
||||
) -> _PySymProxyType: ...
|
||||
|
||||
|
||||
@overload
|
||||
@ -319,8 +309,7 @@ def get_proxy_slot(
|
||||
obj: PySymType,
|
||||
tracer: _ProxyTracer,
|
||||
default: T,
|
||||
) -> Union[T, _PySymProxyType]:
|
||||
...
|
||||
) -> Union[T, _PySymProxyType]: ...
|
||||
|
||||
|
||||
@overload
|
||||
@ -329,8 +318,7 @@ def get_proxy_slot(
|
||||
tracer: _ProxyTracer,
|
||||
default: U,
|
||||
transform: Callable[[_PySymProxyType], R],
|
||||
) -> Union[R, U]:
|
||||
...
|
||||
) -> Union[R, U]: ...
|
||||
|
||||
|
||||
# the default argument is what to return if the slot is not set.
|
||||
@ -717,22 +705,21 @@ def fetch_sym_proxy(
|
||||
|
||||
|
||||
@overload
|
||||
def fetch_object_proxy(tracer: _ProxyTracer, t: Tensor) -> Union[_ProxyTensor, Tensor]:
|
||||
...
|
||||
def fetch_object_proxy(
|
||||
tracer: _ProxyTracer, t: Tensor
|
||||
) -> Union[_ProxyTensor, Tensor]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def fetch_object_proxy(
|
||||
tracer: _ProxyTracer, t: _AnyScriptObjectType
|
||||
) -> Union[Proxy, _AnyScriptObjectType]:
|
||||
...
|
||||
) -> Union[Proxy, _AnyScriptObjectType]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def fetch_object_proxy(
|
||||
tracer: _ProxyTracer, t: PySymType
|
||||
) -> Union[_PySymProxyType, PySymType]:
|
||||
...
|
||||
) -> Union[_PySymProxyType, PySymType]: ...
|
||||
|
||||
|
||||
def fetch_object_proxy(
|
||||
@ -815,7 +802,10 @@ def proxy_call(
|
||||
|
||||
if func is torch.ops.aten.is_nonzero.default:
|
||||
with proxy_mode:
|
||||
torch._check(args[0].numel() == 1, lambda: "Boolean value of Tensor with more than one value is ambiguous") # type: ignore[attr-defined]
|
||||
torch._check(
|
||||
args[0].numel() == 1, # type: ignore[attr-defined]
|
||||
lambda: "Boolean value of Tensor with more than one value is ambiguous",
|
||||
)
|
||||
return (args[0] != 0).item() # type: ignore[attr-defined]
|
||||
|
||||
tracer = proxy_mode.tracer
|
||||
@ -1079,18 +1069,15 @@ class PythonKeyTracer(Tracer):
|
||||
return super().create_arg(a) # type: ignore[return-value]
|
||||
|
||||
@overload
|
||||
def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]:
|
||||
...
|
||||
def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: ...
|
||||
|
||||
@overload
|
||||
def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]:
|
||||
...
|
||||
def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: ...
|
||||
|
||||
@overload
|
||||
def unwrap_proxy(
|
||||
self, e: _AnyScriptObjectType
|
||||
) -> Union[Proxy, _AnyScriptObjectType]:
|
||||
...
|
||||
) -> Union[Proxy, _AnyScriptObjectType]: ...
|
||||
|
||||
def unwrap_proxy(self, e: T) -> object:
|
||||
if isinstance(e, Tensor):
|
||||
@ -1608,7 +1595,10 @@ class DecompositionInterpreter(fx.Interpreter):
|
||||
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
|
||||
|
||||
def placeholder(
|
||||
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
|
||||
self,
|
||||
target: str, # type: ignore[override]
|
||||
args: tuple[object, ...],
|
||||
kwargs: dict[str, object],
|
||||
) -> object:
|
||||
out = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
|
||||
proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
|
||||
@ -1617,7 +1607,10 @@ class DecompositionInterpreter(fx.Interpreter):
|
||||
return out
|
||||
|
||||
def get_attr(
|
||||
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
|
||||
self,
|
||||
target: str, # type: ignore[override]
|
||||
args: tuple[object, ...],
|
||||
kwargs: dict[str, object],
|
||||
) -> object:
|
||||
out = super().get_attr(target, args, kwargs) # type: ignore[arg-type]
|
||||
proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
|
||||
@ -1627,7 +1620,10 @@ class DecompositionInterpreter(fx.Interpreter):
|
||||
# call_function, call_method, call_module get traced automatically by the outer mode.
|
||||
|
||||
def output(
|
||||
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
|
||||
self,
|
||||
target: str, # type: ignore[override]
|
||||
args: tuple[object, ...],
|
||||
kwargs: dict[str, object],
|
||||
) -> object:
|
||||
out = super().output(target, args, kwargs) # type: ignore[arg-type]
|
||||
|
||||
@ -1989,14 +1985,14 @@ class _MakefxTracer:
|
||||
# adding new modes in _MakefxTracer.
|
||||
self.fake_tensor_mode: Optional[FakeTensorMode] = None
|
||||
self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext()
|
||||
self.proxy_function_mode: Union[
|
||||
nullcontext, PreDispatchTorchFunctionMode
|
||||
] = nullcontext()
|
||||
self.proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode] = (
|
||||
nullcontext()
|
||||
)
|
||||
self.fx_tracer: Optional[PythonKeyTracer] = None
|
||||
self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext()
|
||||
self.torch_fn_metadata_mode: Union[
|
||||
nullcontext, TorchFunctionMetadataMode
|
||||
] = nullcontext()
|
||||
self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = (
|
||||
nullcontext()
|
||||
)
|
||||
self.stack_trace = stack_trace
|
||||
|
||||
def _checkpoint_modes(self) -> list[Any]:
|
||||
@ -2071,9 +2067,9 @@ class _MakefxTracer:
|
||||
allow_non_fake_inputs=self._allow_non_fake_inputs,
|
||||
shape_env=shape_env,
|
||||
)
|
||||
assert (
|
||||
fake_tensor_mode.shape_env is not None
|
||||
), "shape_env should be set if tracing with 'symbolic'"
|
||||
assert fake_tensor_mode.shape_env is not None, (
|
||||
"shape_env should be set if tracing with 'symbolic'"
|
||||
)
|
||||
self.fake_tensor_mode = fake_tensor_mode
|
||||
else:
|
||||
if not self.tracing_mode == "real":
|
||||
@ -2161,9 +2157,9 @@ class _MakefxTracer:
|
||||
return self.fake_tensor_mode.from_tensor(x, source=source)
|
||||
# NB: don't match on bools
|
||||
elif type(x) is int and self.tracing_mode == "symbolic":
|
||||
assert (
|
||||
self.fake_tensor_mode.shape_env is not None
|
||||
), "shape_env should be set if tracing with 'symbolic'"
|
||||
assert self.fake_tensor_mode.shape_env is not None, (
|
||||
"shape_env should be set if tracing with 'symbolic'"
|
||||
)
|
||||
return self.fake_tensor_mode.shape_env.create_symintnode(
|
||||
self.fake_tensor_mode.shape_env.create_symbol(
|
||||
x, source, positive=None
|
||||
@ -2176,9 +2172,9 @@ class _MakefxTracer:
|
||||
self.fake_tensor_mode, x
|
||||
)
|
||||
|
||||
assert not isinstance(
|
||||
x, FakeScriptObject
|
||||
), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
|
||||
assert not isinstance(x, FakeScriptObject), (
|
||||
f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
|
||||
)
|
||||
return x
|
||||
|
||||
wrap_fn_map = {
|
||||
@ -2344,9 +2340,9 @@ def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
|
||||
torch._C._TorchDispatchModeKey.PROXY
|
||||
)
|
||||
mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
|
||||
assert (
|
||||
pre_dispatch_mode is None or mode is None
|
||||
), f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
|
||||
assert pre_dispatch_mode is None or mode is None, (
|
||||
f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
|
||||
)
|
||||
return pre_dispatch_mode or mode
|
||||
|
||||
|
||||
|
@ -460,7 +460,7 @@ def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value)
|
||||
# Here, we allow the value of each field to be mapped, so that we appropriately
|
||||
# compare the two values.
|
||||
def compare_vars(
|
||||
map_value: Callable[[str, Any], Any]
|
||||
map_value: Callable[[str, Any], Any],
|
||||
) -> list[tuple[str, str, str]]:
|
||||
env1_set, env2_set = set(env1_vars), set(env2_vars)
|
||||
|
||||
|
@ -103,7 +103,7 @@ class AnnotateTypesWithSchema(Transformer):
|
||||
for i, atom in enumerate(atoms):
|
||||
if not hasattr(module_itr, atom):
|
||||
raise RuntimeError(
|
||||
f'Node referenced nonextent target {".".join(atoms[:i])}!'
|
||||
f"Node referenced nonextent target {'.'.join(atoms[:i])}!"
|
||||
)
|
||||
module_itr = getattr(module_itr, atom)
|
||||
|
||||
|
@ -149,9 +149,9 @@ class SymNode:
|
||||
# This is technically not TV, but this assert is expensive so
|
||||
# let's only do it when we're already doing expensive things
|
||||
computed_hint = compute_hint()
|
||||
assert (
|
||||
hint == computed_hint
|
||||
), f"{hint} != {computed_hint} (for {self.expr})"
|
||||
assert hint == computed_hint, (
|
||||
f"{hint} != {computed_hint} (for {self.expr})"
|
||||
)
|
||||
else:
|
||||
hint = compute_hint()
|
||||
self._hint = hint
|
||||
@ -460,7 +460,9 @@ class SymNode:
|
||||
return self.float_pow(other)
|
||||
|
||||
def is_non_overlapping_and_dense(self, sizes, strides):
|
||||
return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined]
|
||||
return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(
|
||||
to_node(self, 1)
|
||||
) # type: ignore[attr-defined]
|
||||
|
||||
def int_(self):
|
||||
return self.guard_int("", 0) # NB: uses Python backtrace
|
||||
|
@ -182,7 +182,9 @@ CURRENT_NODE_KEY = "current_node"
|
||||
|
||||
def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None:
|
||||
log.debug(
|
||||
"lru_cache_stats %s: %s", wrapped_f.__name__, wrapped_f.cumulative_cache_info() # type: ignore[attr-defined]
|
||||
"lru_cache_stats %s: %s",
|
||||
wrapped_f.__name__, # type: ignore[attr-defined]
|
||||
wrapped_f.cumulative_cache_info(), # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
|
||||
@ -497,9 +499,9 @@ def check_consistent(new: _T, old: _T) -> None:
|
||||
torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)")
|
||||
# NB: bool is subclass of int
|
||||
elif isinstance(new, scalar_types) and not isinstance(new, bool):
|
||||
assert isinstance(old, scalar_types) and not isinstance(
|
||||
old, bool
|
||||
), f"{old} != {new}"
|
||||
assert isinstance(old, scalar_types) and not isinstance(old, bool), (
|
||||
f"{old} != {new}"
|
||||
)
|
||||
torch._check(old == new, lambda: f"{old} != {new} (old != new)")
|
||||
|
||||
|
||||
@ -629,9 +631,9 @@ def rebind_unbacked(
|
||||
raw_u1 = new_raw_u1
|
||||
|
||||
if not isinstance(raw_u1, sympy.Symbol):
|
||||
assert (
|
||||
not raw_u1.free_symbols
|
||||
), f"should have been constant, but got {raw_u1}"
|
||||
assert not raw_u1.free_symbols, (
|
||||
f"should have been constant, but got {raw_u1}"
|
||||
)
|
||||
continue
|
||||
|
||||
# The old and new could be the same if you improperly hit the memo
|
||||
@ -1975,12 +1977,12 @@ class EqualityConstraint(Constraint):
|
||||
|
||||
|
||||
def _assert_symbol_context(symbolic_context: object) -> TypeGuard[SymbolicContext]:
|
||||
assert isinstance(
|
||||
symbolic_context, SymbolicContext
|
||||
), "Invalid symbolic_context object"
|
||||
assert (
|
||||
type(symbolic_context) is not SymbolicContext
|
||||
), "Illegal usage of symbolic_context ABC"
|
||||
assert isinstance(symbolic_context, SymbolicContext), (
|
||||
"Invalid symbolic_context object"
|
||||
)
|
||||
assert type(symbolic_context) is not SymbolicContext, (
|
||||
"Illegal usage of symbolic_context ABC"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@ -2519,9 +2521,9 @@ def _lru_cache(
|
||||
prior_version = self._version_counter
|
||||
prior_key = self._get_key()
|
||||
else:
|
||||
assert (
|
||||
prior_key == self._get_key()
|
||||
), "ShapeEnv cache key changed without version being updated!"
|
||||
assert prior_key == self._get_key(), (
|
||||
"ShapeEnv cache key changed without version being updated!"
|
||||
)
|
||||
|
||||
return fn_cache(self, *args, **kwargs)
|
||||
|
||||
@ -2772,9 +2774,9 @@ class DynamicDimConstraintPrinter(PythonPrinter):
|
||||
|
||||
def _print_Symbol(self, expr: sympy.Symbol) -> str:
|
||||
assert isinstance(expr, sympy.Symbol), str(type(expr))
|
||||
assert self.symbol_to_source.get(
|
||||
expr
|
||||
), f"Unknown symbol {expr} created by constraints solver"
|
||||
assert self.symbol_to_source.get(expr), (
|
||||
f"Unknown symbol {expr} created by constraints solver"
|
||||
)
|
||||
return self.symbol_to_source[expr][0].name()
|
||||
|
||||
|
||||
@ -2792,9 +2794,9 @@ class DimConstraints:
|
||||
source_name_to_debug_name: Mapping[str, str],
|
||||
) -> None:
|
||||
# We try to solve systems of inequalities with 1 free variable.
|
||||
self._univariate_inequalities: dict[
|
||||
sympy.Symbol, set[SympyBoolean]
|
||||
] = defaultdict(set)
|
||||
self._univariate_inequalities: dict[sympy.Symbol, set[SympyBoolean]] = (
|
||||
defaultdict(set)
|
||||
)
|
||||
# Among them, we prioritize solving for a free variable that has equalities.
|
||||
# NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
|
||||
# and removing a symbol from the former => removing it from the latter.
|
||||
@ -2877,9 +2879,10 @@ class DimConstraints:
|
||||
# With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we
|
||||
# would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution!
|
||||
base, divisor = args
|
||||
base, divisor = self.rewrite_with_congruences(
|
||||
s, base
|
||||
), self.rewrite_with_congruences(s, divisor)
|
||||
base, divisor = (
|
||||
self.rewrite_with_congruences(s, base),
|
||||
self.rewrite_with_congruences(s, divisor),
|
||||
)
|
||||
mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(
|
||||
self._var_to_val
|
||||
)
|
||||
@ -2896,9 +2899,10 @@ class DimConstraints:
|
||||
# NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d
|
||||
# and eliminating b % d as above.
|
||||
base, divisor = args
|
||||
base, divisor = self.rewrite_with_congruences(
|
||||
s, base
|
||||
), self.rewrite_with_congruences(s, divisor)
|
||||
base, divisor = (
|
||||
self.rewrite_with_congruences(s, base),
|
||||
self.rewrite_with_congruences(s, divisor),
|
||||
)
|
||||
mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(
|
||||
self._var_to_val
|
||||
)
|
||||
@ -3060,9 +3064,9 @@ class DimConstraints:
|
||||
(arg for arg in solution.args if isinstance(arg, sympy.Eq)),
|
||||
solution,
|
||||
)
|
||||
assert isinstance(
|
||||
solution, sympy.Eq
|
||||
), f"Expected an equality constraint for {s}, got {solution}"
|
||||
assert isinstance(solution, sympy.Eq), (
|
||||
f"Expected an equality constraint for {s}, got {solution}"
|
||||
)
|
||||
symbol, val = solution.args
|
||||
assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
|
||||
# because this is univariate, the solution is a specialization
|
||||
@ -3340,7 +3344,8 @@ class DimConstraints:
|
||||
"max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type, index]
|
||||
}
|
||||
if not _check_same_range(
|
||||
result, name_to_dim[mroot] # type: ignore[index, arg-type]
|
||||
result,
|
||||
name_to_dim[mroot], # type: ignore[index, arg-type]
|
||||
): # ignore if unchanged
|
||||
modified_root_values[mroot] = result # type: ignore[index]
|
||||
break
|
||||
@ -4124,9 +4129,9 @@ class ShapeEnv:
|
||||
if not isinstance(b, SymInt):
|
||||
assert a == b
|
||||
else:
|
||||
assert isinstance(
|
||||
b.node.expr, sympy.Symbol
|
||||
), "constraining non-Symbols NYI"
|
||||
assert isinstance(b.node.expr, sympy.Symbol), (
|
||||
"constraining non-Symbols NYI"
|
||||
)
|
||||
assert b.node.shape_env is self
|
||||
self.replacements[b.node.expr] = sympy.Integer(a)
|
||||
else:
|
||||
@ -4139,9 +4144,9 @@ class ShapeEnv:
|
||||
self.replacements[a.node.expr] = sympy.Integer(b)
|
||||
else:
|
||||
assert a.node.shape_env is b.node.shape_env
|
||||
assert isinstance(
|
||||
b.node.expr, sympy.Symbol
|
||||
), "constraining non-Symbols NYI"
|
||||
assert isinstance(b.node.expr, sympy.Symbol), (
|
||||
"constraining non-Symbols NYI"
|
||||
)
|
||||
new_var = self._find(a.node.expr)
|
||||
self.replacements[b.node.expr] = new_var
|
||||
|
||||
@ -4234,9 +4239,9 @@ class ShapeEnv:
|
||||
|
||||
# If translation validation is enabled, all arguments must have its
|
||||
# own FX node.
|
||||
assert all(
|
||||
a is not None for a in args
|
||||
), f"missing arg in FX graph ({op.__name__}): {args}"
|
||||
assert all(a is not None for a in args), (
|
||||
f"missing arg in FX graph ({op.__name__}): {args}"
|
||||
)
|
||||
node = self.fx_node_cache[node_key] = self.graph.call_function(op, args)
|
||||
self.name_to_node[node.name] = node
|
||||
|
||||
@ -4354,9 +4359,9 @@ class ShapeEnv:
|
||||
source: Source,
|
||||
symbolic_context: SymbolicContext,
|
||||
) -> list[sympy.Expr]:
|
||||
assert all(
|
||||
not is_symbolic(val) for val in tensor_size
|
||||
), f"Expect size to be a plain tuple of ints but got {tensor_size}"
|
||||
assert all(not is_symbolic(val) for val in tensor_size), (
|
||||
f"Expect size to be a plain tuple of ints but got {tensor_size}"
|
||||
)
|
||||
from torch._dynamo.source import TensorProperty, TensorPropertySource
|
||||
|
||||
_assert_symbol_context(symbolic_context)
|
||||
@ -4398,7 +4403,11 @@ class ShapeEnv:
|
||||
source: Source,
|
||||
*,
|
||||
symbolic_context: Optional[SymbolicContext] = None,
|
||||
) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]:
|
||||
) -> tuple[
|
||||
tuple[IntLikeType, ...],
|
||||
tuple[IntLikeType, ...],
|
||||
IntLikeType,
|
||||
]:
|
||||
"""
|
||||
Returns a list of symbolic sizes and strides for the given tensor.
|
||||
We try our best to express stride in terms of the sizes, so as to not
|
||||
@ -4463,9 +4472,9 @@ class ShapeEnv:
|
||||
) -> IntLikeType:
|
||||
assert isinstance(maybe_sym, (int, torch.SymInt))
|
||||
if is_symbolic(maybe_sym):
|
||||
assert (
|
||||
maybe_sym.node.shape_env is not self
|
||||
), "expect the symbol is created from an shape env other than current one."
|
||||
assert maybe_sym.node.shape_env is not self, (
|
||||
"expect the symbol is created from an shape env other than current one."
|
||||
)
|
||||
return maybe_sym.node.require_hint()
|
||||
return maybe_sym
|
||||
|
||||
@ -4481,7 +4490,11 @@ class ShapeEnv:
|
||||
source: Source,
|
||||
*,
|
||||
symbolic_context: Optional[SymbolicContext] = None,
|
||||
) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]:
|
||||
) -> tuple[
|
||||
tuple[IntLikeType, ...],
|
||||
tuple[IntLikeType, ...],
|
||||
IntLikeType,
|
||||
]:
|
||||
dim = len(ex_size)
|
||||
|
||||
# Reimplement the legacy behavior
|
||||
@ -5045,9 +5058,9 @@ class ShapeEnv:
|
||||
sloc,
|
||||
)
|
||||
else:
|
||||
self.var_to_range[
|
||||
sympy_expr
|
||||
] = self._default_unspecified_value_range()
|
||||
self.var_to_range[sympy_expr] = (
|
||||
self._default_unspecified_value_range()
|
||||
)
|
||||
self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc)
|
||||
|
||||
# Small performance optimization: if we have a min-max constraint,
|
||||
@ -5238,9 +5251,9 @@ class ShapeEnv:
|
||||
shape_env = replay_shape_env_events(self.events)
|
||||
self.check_equal(shape_env)
|
||||
|
||||
assert len(placeholders) == len(
|
||||
sources
|
||||
), f"len({placeholders}) != len({sources})"
|
||||
assert len(placeholders) == len(sources), (
|
||||
f"len({placeholders}) != len({sources})"
|
||||
)
|
||||
Tensorlike = (torch.Tensor, FakeTensorMeta)
|
||||
|
||||
def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext:
|
||||
@ -5336,9 +5349,9 @@ class ShapeEnv:
|
||||
symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict(
|
||||
list
|
||||
)
|
||||
symbol_to_constraints: defaultdict[
|
||||
sympy.Symbol, set[Constraint]
|
||||
] = collections.defaultdict(set)
|
||||
symbol_to_constraints: defaultdict[sympy.Symbol, set[Constraint]] = (
|
||||
collections.defaultdict(set)
|
||||
)
|
||||
constraint_violations: list[tuple[bool, str, Callable[[], str]]] = []
|
||||
|
||||
printers: list[_ShapeGuardPrinter] = []
|
||||
@ -6528,7 +6541,7 @@ class ShapeEnv:
|
||||
f"Caused by: {sloc}\n"
|
||||
'For more information, run with TORCH_LOGS="dynamic"\n'
|
||||
"For extended logs when we create symbols, also add "
|
||||
f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n"
|
||||
f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{",".join(map(str, expr.free_symbols))}"\n'
|
||||
"If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n"
|
||||
"For more debugging help, see "
|
||||
"https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n"
|
||||
@ -6662,9 +6675,9 @@ class ShapeEnv:
|
||||
)
|
||||
self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a])
|
||||
tgt_bound = self.bound_sympy(tgt)
|
||||
assert tgt_bound.issubset(
|
||||
src_bound
|
||||
), f"{tgt_bound=} not a subset of {src_bound=}"
|
||||
assert tgt_bound.issubset(src_bound), (
|
||||
f"{tgt_bound=} not a subset of {src_bound=}"
|
||||
)
|
||||
|
||||
# TODO: Should we propagate size-like-ness?
|
||||
#
|
||||
@ -6751,9 +6764,9 @@ class ShapeEnv:
|
||||
for source in self.var_to_sources.get(a, []):
|
||||
if user_tb:
|
||||
self.user_specialization_stacks[source] = user_tb
|
||||
self.framework_specialization_stacks[
|
||||
source
|
||||
] = CapturedTraceback.extract(cpp=True)
|
||||
self.framework_specialization_stacks[source] = (
|
||||
CapturedTraceback.extract(cpp=True)
|
||||
)
|
||||
|
||||
if config.print_specializations:
|
||||
self.log.warning(
|
||||
@ -6820,9 +6833,9 @@ class ShapeEnv:
|
||||
|
||||
free = list(expr.free_symbols)
|
||||
|
||||
assert (
|
||||
len(free) > 0
|
||||
), f"The expression should not be static by this point: {expr}"
|
||||
assert len(free) > 0, (
|
||||
f"The expression should not be static by this point: {expr}"
|
||||
)
|
||||
# In case of really gnarly expression, we don't blow up
|
||||
if len(free) > 5:
|
||||
return
|
||||
|
@ -203,9 +203,7 @@ try:
return _Z3Ops.to_real(result) if cast_result_to_real else result

def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
return z3.If(
self.floor(number) < number, self.floor(number + 1), number
) # type: ignore[return-value]
return z3.If(self.floor(number) < number, self.floor(number + 1), number) # type: ignore[return-value]

def trunc(self, number: z3.ArithRef) -> z3.ArithRef:
return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value]
@ -363,9 +361,9 @@ try:
return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type]
# Adds the Z3 expression corresponding to the first argument
# as a validator input.
assert (
len(args) == 1
), f"expected 1 argument on assertion. Got: {len(args)} "
assert len(args) == 1, (
f"expected 1 argument on assertion. Got: {len(args)} "
)
self.validator.add_source_expr(args[0]) # type: ignore[arg-type]

# Translates SymPy expressions into Z3 expressions.
@ -536,9 +534,9 @@ try:

def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
z3expr = SympyToZ3(self).run(e)
assert isinstance(
z3expr, z3.BoolRef
), f"expected boolean expression. Got: {z3expr}"
assert isinstance(z3expr, z3.BoolRef), (
f"expected boolean expression. Got: {z3expr}"
)
return z3expr

def add_source_expr(self, e: z3.BoolRef) -> None:

@ -449,7 +449,7 @@ class CodeGen:
# This code-path used in Python < 3.9
return origin_typename

return f'{origin_typename}[{",".join(args)}]'
return f"{origin_typename}[{','.join(args)}]"
else:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python 3.9+
@ -573,7 +573,7 @@ class CodeGen:
summary_str = parsed_stack_trace.get_summary_str()
else:
summary_str = ""
body.append(f'\n {dim(f"# {summary_str}")}\n')
body.append(f"\n {dim(f'# {summary_str}')}\n")
elif prev_stacktrace != "":
prev_stacktrace = ""
no_stacktrace_msg = "# No stacktrace found for following nodes"
@ -842,7 +842,7 @@ class _PyTreeCodeGen(CodeGen):
if len(has_annotation) > 0:
fn_definition += "\n " + "".join(has_annotation) + "\n"
fn_definition += f"""
{', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
{", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
return fn_definition

def generate_output(self, output_args):
@ -1877,7 +1877,9 @@ class Graph:
# through `insert_pdb`:
gm.graph.on_generate_code(
lambda current_trans: (
lambda body: insert_pdb(current_trans(body) if current_trans else body)
lambda body: insert_pdb(
current_trans(body) if current_trans else body
)
)
)

@ -1916,7 +1918,7 @@ class Graph:

@contextmanager
def _override_sym_repr(
override: Callable[["torch.types.PySymType"], str]
override: Callable[["torch.types.PySymType"], str],
) -> Iterator[None]:
tmp = CodeGen._sym_repr
try:

@ -324,9 +324,9 @@ def _print_readable(
colored=False,
):
graph = module.graph
assert graph is not None and isinstance(
graph, torch.fx.Graph
), "print_readable must be used on a module with a graph"
assert graph is not None and isinstance(graph, torch.fx.Graph), (
"print_readable must be used on a module with a graph"
)

verbose_python_code = graph.python_code(
root_module="self",
@ -873,9 +873,9 @@ class {module_name}(torch.nn.Module):
for node in self.graph.nodes
if "stack_trace" in node.meta
}
dict_without_graph[
"_graphmodule_graph_node_meta_stack_trace"
] = node_meta_stack_trace
dict_without_graph["_graphmodule_graph_node_meta_stack_trace"] = (
node_meta_stack_trace
)

generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
python_code = self.recompile()

@ -51,7 +51,9 @@ class Interpreter:
method equivalents). We could subclass Interpreter like so::

class NegSigmSwapInterpreter(Interpreter):
def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
def call_function(
self, target: Target, args: Tuple, kwargs: Dict
) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(target, args, kwargs)
@ -405,7 +407,7 @@ class Interpreter:
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}"
f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr
@ -468,14 +470,20 @@ class Transformer(Interpreter):

class NegSigmSwapXformer(Transformer):
def call_function(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self,
target: "Target",
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(target, args, kwargs)

def call_method(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self,
target: "Target",
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
) -> Any:
if target == "neg":
call_self, *args_tail = args

@ -514,9 +514,9 @@ class Node(_NodeBase):
idx (int): The index of the element in ``self.args`` to be inserted before.
arg (Argument): The new argument value to insert into ``args``
"""
assert (
0 <= idx <= len(self.args)
), "insert_args index must be between 0 and len(self.args)"
assert 0 <= idx <= len(self.args), (
"insert_args index must be between 0 and len(self.args)"
)
args_left = self.args[:idx]
args_right = self.args[idx:]

@ -747,13 +747,13 @@ class Node(_NodeBase):

# Check if an impure module.
if self.op == "call_module":
assert (
self.graph.owning_module is not None
), "self.graph.owning_module not set for purity check"
assert self.graph.owning_module is not None, (
"self.graph.owning_module not set for purity check"
)
target_mod = self.graph.owning_module.get_submodule(self.target)
assert (
target_mod is not None
), f"Did not find expected submodule target {self.target}"
assert target_mod is not None, (
f"Did not find expected submodule target {self.target}"
)
return getattr(target_mod, "_is_impure", False)

return False

@ -770,9 +770,9 @@ class _MinimizerBase:
node_name = node.name
if node_name is not None and isinstance(node_name, tuple):
node_name = node_name[0]
assert node_name is not None and isinstance(
node_name, str
), f"minimize: node_name: {node_name}"
assert node_name is not None and isinstance(node_name, str), (
f"minimize: node_name: {node_name}"
)

report.append(f"Add node: {node_name}")

@ -93,9 +93,9 @@ def loop_pass(
predicate (Callable[Object, bool], optional):

"""
assert (n_iter is not None) ^ (
predicate is not None
), "Exactly one of `n_iter`or `predicate` must be specified."
assert (n_iter is not None) ^ (predicate is not None), (
"Exactly one of `n_iter`or `predicate` must be specified."
)

@wraps(base_pass)
def new_pass(source):

@ -397,7 +397,9 @@ def insert_deferred_runtime_asserts(
nn_module_stack=node.meta.get("nn_module_stack"),
),
):
expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type]
expr_to_proxy[sym_expr] = _sympy_interp(
expr_to_proxy, sym_expr
) # type: ignore[arg-type]
# won't try DCE-ing tensor compute here
hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type]
node.replace_all_uses_with(hash_node)

@ -199,9 +199,9 @@ def split_by_tags(
mx = max((c.order for c in upstream_components), default=0)

# Expect the component for `node` has higher order then its upstream components.
assert (
comp.order >= mx
), f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}"
assert comp.order >= mx, (
f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}"
)

# Map a input of `node` to nodes in the component's graph.
def remap_func(x):

@ -36,9 +36,9 @@ def topo_sort(nodes: NodeList) -> NodeList:
if indegree_map[n] == 0:
candidates.put(n)

assert len(nodes) == len(
sorted_nodes
), "topological sorted nodes doesn't have same length as input nodes"
assert len(nodes) == len(sorted_nodes), (
"topological sorted nodes doesn't have same length as input nodes"
)

return sorted_nodes

@ -127,13 +127,13 @@ def fuse_as_graphmodule(
# assumption: nodes are already sorted in topo order

for node in nodes:
assert (
node.graph.owning_module is gm
), f"{node} doesn't belong to passed in graph module {gm._get_name()}"
assert node.graph.owning_module is gm, (
f"{node} doesn't belong to passed in graph module {gm._get_name()}"
)
assert not node._erased, f"{node} has been removed from owning graph"
assert (
node in gm.graph._find_nodes_lookup_table
), f"{node} is not found in graph module {gm._get_name()}"
assert node in gm.graph._find_nodes_lookup_table, (
f"{node} is not found in graph module {gm._get_name()}"
)

# validates partition doesn't introduce dependency circles in the graph
assert validate_partition(nodes), "Invalid partition, found dependency cycles"

@ -96,9 +96,9 @@ class SubgraphMatcher:

for node in pattern.nodes:
if node.op != "output":
assert (
len(node.users) > 0
), "SubgraphMatcher cannot be initialized with an pattern with dead code"
assert len(node.users) > 0, (
"SubgraphMatcher cannot be initialized with an pattern with dead code"
)

# TODO: assert pattern is a connected graph

@ -192,9 +192,9 @@ class SubgraphMatcher:
return non_overlapping_matches

def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
assert not (
isinstance(pn, Node) and isinstance(gn, Node)
), "pn and gn cannot both be Node"
assert not (isinstance(pn, Node) and isinstance(gn, Node)), (
"pn and gn cannot both be Node"
)

if isinstance(pn, Node) and not isinstance(gn, Node):
if pn.op == "placeholder":

@ -18,17 +18,17 @@ def _split_to_graph_and_name_node_map(
if n.op == "output":
assert gm._out_spec is not None
output = tree_unflatten(n.args[0], gm._out_spec)
assert isinstance(
output, tuple
), "Expecting the pattern graph to return a tuple"
assert (
len(output) >= 2
), "Expecting the pattern graph to have at least two outputs"
assert isinstance(output, tuple), (
"Expecting the pattern graph to return a tuple"
)
assert len(output) >= 2, (
"Expecting the pattern graph to have at least two outputs"
)
*out, name_node_map = output
flattened, out_spec = tree_flatten(out)
assert isinstance(
name_node_map, dict
), "Expecting the input graph to have a dict output as the last element"
assert isinstance(name_node_map, dict), (
"Expecting the input graph to have a dict output as the last element"
)
n.args = (flattened,)
orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined]
gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined]
@ -53,12 +53,14 @@ class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
relu = F.relu(conv)
return relu, {"conv": conv, "relu": relu}


def target_graph(x, weight):
conv = F.conv2d(x, weight)
relu = F.relu(conv)
relu *= 2
return relu


pattern_gm = export_for_training(pattern, example_inputs).module()
target_gm = export_for_training(target_graph, example_inputs).module()
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)

@ -654,9 +654,9 @@ class MetaProxy(Proxy):
meta_proxy = arg
break

assert (
meta_proxy is not None
), "No MetaProxy found in arguments, but one is expected."
assert meta_proxy is not None, (
"No MetaProxy found in arguments, but one is expected."
)

proxy = super().__torch_function__(orig_method, types, args, kwargs)
with meta_proxy.fake_mode:
@ -739,14 +739,14 @@ for method in magic_methods:
return tracer.create_proxy("call_function", target, args, kwargs)

impl.__name__ = method
as_magic = f'__{method.strip("_")}__'
as_magic = f"__{method.strip('_')}__"
setattr(Proxy, as_magic, impl)

_scope(method)


def _define_reflectable(orig_method_name):
method_name = f'__r{orig_method_name.strip("_")}__'
method_name = f"__r{orig_method_name.strip('_')}__"

def impl(self, rhs):
target = getattr(operator, orig_method_name)

@ -307,9 +307,9 @@ def _replace_pattern(
elif callable(replacement):
common_replacement_graph = symbolic_trace(replacement).graph
else:
assert (
replacement_callback is not None
), "Must provide either a replacement GraphModule or a replacement callback"
assert replacement_callback is not None, (
"Must provide either a replacement GraphModule or a replacement callback"
)
common_replacement_graph = None

# As we progressively replace nodes, we'll need to keep track of how the match results should change
@ -322,9 +322,9 @@ def _replace_pattern(
match, original_graph, pattern_graph
)
else:
assert (
common_replacement_graph is not None
), "Must provide either a replacement GraphModule or a replacement callback"
assert common_replacement_graph is not None, (
"Must provide either a replacement GraphModule or a replacement callback"
)
replacement_graph = common_replacement_graph
replacement_placeholders = [
n for n in replacement_graph.nodes if n.op == "placeholder"

@ -18,7 +18,15 @@ from torch.nn.modules.utils import (

_builtin_table: Optional[dict[int, str]] = None

_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special) # type: ignore[attr-defined] # noqa: B950
_modules_containing_builtins = (
torch,
torch._C._nn,
torch._C._fft, # type: ignore[attr-defined]
torch._C._linalg, # type: ignore[attr-defined]
torch._C._nested, # type: ignore[attr-defined]
torch._C._sparse, # type: ignore[attr-defined]
torch._C._special, # type: ignore[attr-defined]
)

_builtin_ops = [
# Pairs of (function, op_name)
@ -94,7 +102,10 @@ _builtin_ops = [
(torch.autograd.grad, "aten::grad"),
(torch.autograd.backward, "aten::backward"),
(torch._C._infer_size, "aten::_infer_size"),
(torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"), # type: ignore[attr-defined]
(
torch.nn.functional._no_grad_embedding_renorm_, # type: ignore[attr-defined]
"aten::_no_grad_embedding_renorm_",
),
(torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"),
(torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"),
(torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),

@ -4,9 +4,9 @@ from torch._ops import OpOverload, OpOverloadPacket


def _register_decomposition(op: OpOverload, graph: torch._C.Graph):
assert not isinstance(
op, OpOverloadPacket
), f"Must pass specific op overload, not overload packet, found {op}"
assert not isinstance(op, OpOverloadPacket), (
f"Must pass specific op overload, not overload packet, found {op}"
)
assert isinstance(op, OpOverload)

torch._C._jit_register_decomposition_for_schema(op._schema, graph)

@ -23,13 +23,13 @@ def check_decomposition_has_type_annotations(f):
inspect_empty = inspect._empty # type: ignore[attr-defined]
sig = inspect.signature(f)
for param in sig.parameters.values():
assert (
param.annotation != inspect_empty
), f"No signature on param {param.name} for function {f.name}"
assert param.annotation != inspect_empty, (
f"No signature on param {param.name} for function {f.name}"
)

assert (
sig.return_annotation != inspect_empty
), f"No return annotation for function {f.name}"
assert sig.return_annotation != inspect_empty, (
f"No return annotation for function {f.name}"
)


def signatures_match(decomposition_sig, torch_op_sig):
@ -75,9 +75,9 @@ def register_decomposition(
assert isinstance(aten_op, torch._ops.OpOverload)

# Need unique name for jit function serialization
assert (
f.__name__ not in function_name_set
), f"Duplicated function name {f.__name__}"
assert f.__name__ not in function_name_set, (
f"Duplicated function name {f.__name__}"
)
function_name_set.add(f.__name__)

scripted_func = torch.jit.script(f)

@ -588,9 +588,9 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn):
# recursively scripting them.
for name, sub_concrete_type in concrete_type.get_modules():
orig_value = getattr(nn_module, name)
assert isinstance(
orig_value, Module
), f"Expected Module but got {type(orig_value)}"
assert isinstance(orig_value, Module), (
f"Expected Module but got {type(orig_value)}"
)
module_type = sub_concrete_type.jit_type
if isinstance(module_type, torch._C.InterfaceType):
# use the interface inference rule to compile the module

@ -318,10 +318,10 @@ class ScriptMeta(type):
else:
return infer_methods_to_compile(module)

self.__dict__[
"_actual_script_module"
] = torch.jit._recursive.create_script_module(
self, make_stubs, share_types=not added_methods_in_init
self.__dict__["_actual_script_module"] = (
torch.jit._recursive.create_script_module(
self, make_stubs, share_types=not added_methods_in_init
)
)

# Delete the Python attributes that now shadow the ScriptModule

@ -280,15 +280,15 @@ def max_pool2d(
dilation: list[int],
ceil_mode: bool,
):
assert (
len(kernel_size) == 1 or len(kernel_size) == 2
), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
assert len(kernel_size) == 1 or len(kernel_size) == 2, (
"max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
)
kH = kernel_size[0]
kW = kH if len(kernel_size) == 1 else kernel_size[1]

assert (
len(stride) == 0 or len(stride) == 1 or len(stride) == 2
), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
assert len(stride) == 0 or len(stride) == 1 or len(stride) == 2, (
"max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
)
dH = kH if len(stride) == 0 else stride[0]
if len(stride) == 0:
dW = kW
@ -297,15 +297,15 @@ def max_pool2d(
else:
dW = stride[1]

assert (
len(padding) == 1 or len(padding) == 2
), "max_pool2d: padding must either be a single int, or a tuple of two ints"
assert len(padding) == 1 or len(padding) == 2, (
"max_pool2d: padding must either be a single int, or a tuple of two ints"
)
padH = padding[0]
padW = padH if len(padding) == 1 else padding[1]

assert (
len(dilation) == 1 or len(dilation) == 2
), "max_pool2d: dilation must be either a single int, or a tuple of two ints"
assert len(dilation) == 1 or len(dilation) == 2, (
"max_pool2d: dilation must be either a single int, or a tuple of two ints"
)
dilationH = dilation[0]
dilationW = dilationH if len(dilation) == 1 else dilation[1]

@ -367,17 +367,17 @@ def upsample_nearest2d(
assert 0, "Either output_size or scale_factors must be presented"

if output_size is not None:
assert (
scale_factors is None
), "Must specify exactly one of output_size and scale_factors"
assert scale_factors is None, (
"Must specify exactly one of output_size and scale_factors"
)
assert len(output_size) == 2
out.append(output_size[0])
out.append(output_size[1])

if scale_factors is not None:
assert (
output_size is None
), "Must specify exactly one of output_size and scale_factors"
assert output_size is None, (
"Must specify exactly one of output_size and scale_factors"
)
assert len(scale_factors) == 2
out.append(int(input[2] * scale_factors[0]))
out.append(int(input[3] * scale_factors[1]))
@ -540,9 +540,9 @@ def check_cat_shape_except_dim(
assert first_dims == second_dims, "Tensors must have same number of dimensions"
for dim in range(0, first_dims):
if dim != dimension:
assert (
first[dim] == second[dim]
), "Sizes of tensors must match except in dimension"
assert first[dim] == second[dim], (
"Sizes of tensors must match except in dimension"
)


def cat(tensors: list[list[int]], dim: int):
@ -1088,9 +1088,9 @@ def topk(self: list[int], k: int, dim: int = -1) -> tuple[list[int], list[int]]:
if len(self) == 0:
result: list[int] = []
else:
assert (
k <= self[dim]
), f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
assert k <= self[dim], (
f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
)
result = _copy(self)
result[dim] = k
return result, result

@ -1205,7 +1205,10 @@ def trace_module(

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {"forward": example_forward_input, "weighted_kernel_sum": example_weight}
inputs = {
"forward": example_forward_input,
"weighted_kernel_sum": example_weight,
}
module = torch.jit.trace_module(n, inputs)

"""

@ -309,14 +309,14 @@ defined as ``prod(x[:i])``.""",
operation_args, operation_kwargs = args_and_kwargs[func.__name__]
arg_declarations = [
"\n ".join(
argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines()
argument_declarations.get(a, f"{a.split('__', 1)[0]}: TBD.").splitlines()
)
for a in operation_args
]
kwarg_declarations = [
"\n ".join(
argument_declarations.get(
a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.'
a.split("=", 1)[0], f"{a.split('__', 1)[0]}: TBD."
)
.format(default=a.split("=", 1)[1])
.splitlines()
@ -745,9 +745,9 @@ def _sparse_csr_segment_reduction_helper(
) -> Tensor:
# Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True
# FIXME: when dense dimensions are implemented for CSR tensors
assert (
keepdim
), "reduction operations on CSR tensors with keepdim=False is unsupported"
assert keepdim, (
"reduction operations on CSR tensors with keepdim=False is unsupported"
)
reduce = op.__name__
valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
if reduce not in valid_reductions:
@ -781,9 +781,9 @@ def _sparse_csr_segment_reduction_helper(
)
new_shape = [1, mask_input.size(1)]
else:
assert (
dims[0] == 1
), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
assert dims[0] == 1, (
"Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
)
# all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
# except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
new_crow_indices = torch.cat(
@ -1598,9 +1598,9 @@ def _std_var(
mask: Optional[Tensor],
take_sqrt: Optional[bool],
) -> Tensor:
assert (
unbiased is None or correction_opt is None
), "Only one of unbiased and correction may be given"
assert unbiased is None or correction_opt is None, (
"Only one of unbiased and correction may be given"
)
correction = 1.0
if unbiased is not None:
correction = 1.0 if unbiased else 0.0
@ -1636,7 +1636,11 @@ def _std_var(
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
else:
total = sum(
x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask # type: ignore[possibly-undefined]
x * x.conj(),
dim,
keepdim=keepdim,
dtype=compute_dtype,
mask=inmask, # type: ignore[possibly-undefined]
)
if not keepdim:
count = count.reshape(total.shape)

@ -25,7 +25,7 @@ def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]:

>>> # xdoctest: +SKIP
>>> from torch.masked import MaskedTensor
>>> data = torch.arange(6).reshape(2,3)
>>> data = torch.arange(6).reshape(2, 3)
>>> mask = torch.tensor([[True, False, False], [True, True, False]])
>>> mt = MaskedTensor(data, mask)
>>> is_masked_tensor(mt)

@ -5,6 +5,7 @@ Metal is Apple's API for programming metal GPU (graphics processor unit). Using
performance can be achieved, by running work on the metal GPU(s).
See https://developer.apple.com/documentation/metalperformanceshaders for more details.
"""

from typing import Union

import torch

@ -198,7 +198,7 @@ def snapshot() -> dict[str, Any]:


def attach_out_of_memory_observer(
observer: Callable[[int, int, int, int], None]
observer: Callable[[int, int, int, int], None],
) -> None:
r"""Attach an out-of-memory observer to MTIA memory allocator"""
torch._C._mtia_attachOutOfMemoryObserver(observer)

@ -14,6 +14,7 @@ memory.
Because of the similarity of APIs we do not document most of this package
contents, and we recommend referring to very good docs of the original module.
"""

import multiprocessing
import sys