Compare commits

..

1 Commits

Author SHA1 Message Date
935acff3ce Type dicts.py 2025-11-04 15:44:11 -08:00
7 changed files with 260 additions and 250 deletions

View File

@ -1,11 +1,15 @@
sphinx==7.2.6
sphinx==5.3.0
#Description: This is used to generate PyTorch docs
#Pinned versions: 7.2.6
#Pinned versions: 5.3.0
pytorch_sphinx_theme2==0.2.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.2.0
standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13 (https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought is that it is probably
# something related to the Docker setup. We can investigate this later.
@ -32,17 +36,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0
breathe==4.36.0
breathe==4.34.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.36.0
#Pinned versions: 4.34.0
exhale==0.3.7
exhale==0.2.3
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.3.7
#Pinned versions: 0.2.3
docutils==0.20
docutils==0.16
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.20
#Pinned versions: 0.16
bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs
@ -52,13 +56,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0
myst-nb==1.3.0
myst-nb==0.17.2
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 1.3.0
#Pinned versions: 0.17.2
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.6.1
sphinx-design==0.4.0
sphinxcontrib-mermaid==1.0.0
myst-parser==4.0.1
myst-parser==0.18.1

View File

@ -89,41 +89,23 @@ if [ "$is_main_doc" = true ]; then
make coverage
# Now we have the coverage report, we need to make sure it is empty.
# Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row
# showing the undocumented count in the third column.
# Example: | TOTAL | 99.83% | 2 |
# Count the number of lines in the file and turn that number into a variable
# $lines. The `cut -f1 ...` is to only parse the number, not the filename
# Skip the report header by subtracting 2: the header will be output even if
# there are no undocumented items.
#
# Also: see docs/source/conf.py for "coverage_ignore*" items, which should
# be documented then removed from there.
# Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table
# The table format is: | Module | Coverage | Undocumented |
# Extract the third column (undocumented count) from the TOTAL row
undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')
if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
undocumented=$((lines - 2))
if [ $undocumented -lt 0 ]; then
echo coverage output not found
exit 1
elif [ "$undocumented" -gt 0 ]; then
set +x # Disable command echoing for cleaner output
echo ""
echo "====================="
echo "UNDOCUMENTED OBJECTS:"
echo "====================="
echo ""
# Find the line number of the TOTAL row and print only what comes after it
total_line=$(grep -n "| TOTAL" build/coverage/python.txt | cut -d: -f1)
if [ -n "$total_line" ]; then
# Print only the detailed list (skip the statistics table)
tail -n +$((total_line + 2)) build/coverage/python.txt
else
# Fallback to showing entire file if TOTAL line not found
cat build/coverage/python.txt
fi
echo ""
elif [ $undocumented -gt 0 ]; then
echo undocumented objects found:
cat build/coverage/python.txt
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && tail -n +\$((grep -n \"| TOTAL\" build/coverage/python.txt | cut -d: -f1) + 2)) build/coverage/python.txt'"
set -x # Re-enable command echoing
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
exit 1
fi
else

View File

@ -1,7 +1,7 @@
# Security Policy
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using PyTorch Securely**](#using-pytorch-securely)
- [**Using PyTorch Securely**](#using-pytorch-securely)
- [Untrusted models](#untrusted-models)
- [TorchScript models](#torchscript-models)
- [Untrusted inputs](#untrusted-inputs)
@ -10,28 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If an advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create a [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If an advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create a [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat
## Using PyTorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferably check for provenance or checksums, do not run any pip-installed package).
## Using PyTorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferably check for provenance or checksums, do not run any pip-installed package).
### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
### TorchScript models
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
### Untrusted inputs during training and prediction
@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
### Data privacy
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if it runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if it runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
### Using distributed features

View File

@ -206,43 +206,19 @@ templates_path = [
os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"),
]
# TODO: document these and remove them from here.
# Fixes the duplicated
autosummary_filename_map = {
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
"torch.mtia.stream": "torch.mtia.stream_function",
"torch.mtia.Stream": "torch.mtia.Stream_class",
"torch.cpu.stream": "torch.cpu.stream_function",
"torch.cpu.Stream": "torch.cpu.Stream_class",
"torch.cuda.stream": "torch.cuda.stream_function",
"torch.cuda.Stream": "torch.cuda.Stream_class",
"torch.xpu.stream": "torch.xpu.stream_function",
"torch.xpu.Stream": "torch.xpu.Stream_class",
}
coverage_ignore_functions = [
# torch
"typename",
# torch.cuda._sanitizer
"zip_arguments",
"zip_by_key",
# torch.distributed.autograd
"is_available",
# torch.distributed.checkpoint.state_dict
"gc_context",
# torch.distributed.elastic.events
"record_rdzv_event",
# torch.distributed.elastic.metrics
"initialize_metrics",
# torch.distributed.elastic.rendezvous.registry
@ -3219,11 +3195,6 @@ autodoc_type_aliases = {
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
autodoc_default_options = {
"exclude-members": "from_bytes, to_bytes",
}
# -- katex javascript in header
#
# def setup(app):

View File

@ -253,6 +253,7 @@ regular full-precision tensor.
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
view
as_strided

View File

@ -3317,7 +3317,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "add", [v], {})
obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
def SET_UPDATE(self, inst: Instruction) -> None:
v = self.pop()
@ -3326,7 +3326,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {})
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
def LIST_APPEND(self, inst: Instruction) -> None:
v = self.pop()
@ -3634,7 +3634,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg].realize()
assert isinstance(obj, ConstDictVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {})
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
DICT_UPDATE = DICT_MERGE

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
Dictionary-related variable tracking classes for PyTorch Dynamo.
@ -26,7 +24,7 @@ import inspect
import operator
import types
from collections.abc import Hashable as py_Hashable
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING, Union
from torch._subclasses.fake_tensor import is_fake
@ -59,11 +57,13 @@ if TYPE_CHECKING:
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def was_instancecheck_override(obj):
def was_instancecheck_override(obj: Any) -> bool:
return type(obj).__dict__.get("__instancecheck__", False)
def raise_unhashable(arg, tx=None):
def raise_unhashable(
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
) -> None:
if tx is None:
from torch._dynamo.symbolic_convert import InstructionTranslator
@ -75,7 +75,7 @@ def raise_unhashable(arg, tx=None):
)
def is_hashable(x):
def is_hashable(x: VariableTracker) -> bool:
# NB - performing isinstance check on a LazyVT realizes the VT, accidentally
# inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at
# the underlying value without realizing the VT. Consider updating the
@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
"""
def __init__(self, vt) -> None:
def __init__(self, vt: VariableTracker) -> None:
# We specialize SymNodes
vt = specialize_symnode(vt)
# TODO Temporarily remove to figure out what keys are we breaking on
@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
self.vt = vt
@property
def underlying_value(self):
def underlying_value(self) -> Any:
if (
isinstance(self.vt, variables.LazyVariableTracker)
and not self.vt.is_realized()
@ -178,7 +178,8 @@ class ConstDictVariable(VariableTracker):
elif isinstance(self.vt, variables.FrozenDataClassVariable):
Hashable = ConstDictVariable._HashableTracker
fields_values = {
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
k: Hashable(v).underlying_value
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
}
return variables.FrozenDataClassVariable.HashWrapper(
self.vt.python_type(), fields_values
@ -187,16 +188,16 @@ class ConstDictVariable(VariableTracker):
# The re module in Python 3.13+ has a dictionary (_cache2) with
# an object as key (`class _ZeroSentinel(int): ...`):
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
return self.vt.value
return self.vt.value # type: ignore[attr-defined,union-attr]
else:
x = self.vt.as_python_constant()
return x
def __hash__(self):
def __hash__(self) -> int:
return hash(self.underlying_value)
@staticmethod
def _eq_impl(a, b):
def _eq_impl(a: Any, b: Any) -> bool:
# TODO: Put this in utils and share it between variables/builtin.py and here
type_a, type_b = type(a), type(b)
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
@ -212,7 +213,7 @@ class ConstDictVariable(VariableTracker):
else:
return a == b
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
def __eq__(self, other: object) -> bool:
Hashable = ConstDictVariable._HashableTracker
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
type(other)
@ -226,8 +227,8 @@ class ConstDictVariable(VariableTracker):
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls=dict,
**kwargs,
user_cls: type = dict,
**kwargs: Any,
) -> None:
# .clone() pass these arguments in kwargs but they're recreated a few
# lines below
@ -247,18 +248,22 @@ class ConstDictVariable(VariableTracker):
for x, v in items.items()
)
def make_hashable(key):
def make_hashable(
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
) -> "ConstDictVariable._HashableTracker":
return key if isinstance(key, Hashable) else Hashable(key)
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
# need to reconstruct everything if the dictionary is an intermediate value
# or if a pop/delitem was executed
self.should_reconstruct_all = not is_from_local_source(self.source)
self.should_reconstruct_all = (
not is_from_local_source(self.source) if self.source else True
)
self.original_items = items.copy()
self.user_cls = user_cls
def _get_dict_cls_from_user_cls(self, user_cls):
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
# avoid executing user code if user_cls is a dict subclass
@ -277,10 +282,10 @@ class ConstDictVariable(VariableTracker):
dict_cls = dict
return dict_cls
def as_proxy(self):
def as_proxy(self) -> dict[Any, Any]:
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
def debug_repr(self):
def debug_repr(self) -> str:
return (
"{"
+ ", ".join(
@ -289,20 +294,20 @@ class ConstDictVariable(VariableTracker):
+ "}"
)
def as_python_constant(self):
def as_python_constant(self) -> dict[Any, Any]:
return {
k.vt.as_python_constant(): v.as_python_constant()
for k, v in self.items.items()
}
def keys_as_python_constant(self):
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
self.install_dict_keys_match_guard()
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
def python_type(self):
def python_type(self) -> type:
return self.user_cls
def __contains__(self, vt) -> bool:
def __contains__(self, vt: VariableTracker) -> bool:
assert isinstance(vt, VariableTracker)
Hashable = ConstDictVariable._HashableTracker
return (
@ -322,13 +327,15 @@ class ConstDictVariable(VariableTracker):
for key, value in self.items.items()
)
def is_new_item(self, value, other):
def is_new_item(
self, value: Optional[VariableTracker], other: VariableTracker
) -> bool:
# compare the id of the realized values if both values are not lazy VTs
if value and value.is_realized() and other.is_realized():
return id(value.realize()) != id(other.realize())
return id(value) != id(other)
def reconstruct_kvs_into_new_dict(self, codegen):
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
# Build a dictionary that contains the keys and values.
num_args = 0
for key, value in self.items.items():
@ -340,7 +347,7 @@ class ConstDictVariable(VariableTracker):
num_args += 1
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
if self.user_cls is collections.OrderedDict:
# emit `OrderedDict(constructed_dict)`
codegen.add_push_null(
@ -358,19 +365,21 @@ class ConstDictVariable(VariableTracker):
def getitem_const_raise_exception_if_absent(
self, tx: "InstructionTranslator", arg: VariableTracker
):
) -> VariableTracker:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
raise_observed_exception(KeyError, tx)
return self.items[key]
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
msg = f"Dictionary key {arg.value} not found during tracing"
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
unimplemented_v2(
gb_type="key not found in dict",
context=f"Key {arg.value}",
context=f"Key {arg.value}", # type: ignore[attr-defined]
explanation=msg,
hints=[
"Check if the key exists in the dictionary before accessing it.",
@ -379,13 +388,13 @@ class ConstDictVariable(VariableTracker):
)
return self.items[key]
def maybe_getitem_const(self, arg: VariableTracker):
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
return None
return self.items[key]
def realize_key_vt(self, arg: VariableTracker):
def realize_key_vt(self, arg: VariableTracker) -> None:
# Realize the LazyVT on a particular index
assert arg in self
key = ConstDictVariable._HashableTracker(arg)
@ -394,11 +403,13 @@ class ConstDictVariable(VariableTracker):
if isinstance(original_key_vt, variables.LazyVariableTracker):
original_key_vt.realize()
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
if self.source:
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
# Key guarding - These are the cases to consider
# 1) The dict has been mutated. In this case, we would have already
# inserted a DICT_KEYS_MATCH guard, so we can skip.
@ -439,11 +450,11 @@ class ConstDictVariable(VariableTracker):
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
# we have to insert guards when a dict method is accessed. For this to
# be simple, we are conservative and overguard. We skip guard only for
@ -462,7 +473,7 @@ class ConstDictVariable(VariableTracker):
tx, *args, **kwargs
)
tx.output.side_effects.mutation(self)
self.items.update(temp_dict_vt.items)
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
return ConstantVariable.create(None)
elif name == "__getitem__":
# Key guarding - Nothing to do. LazyVT for value will take care.
@ -526,7 +537,7 @@ class ConstDictVariable(VariableTracker):
return ConstantVariable.create(len(self.items))
elif name == "__setitem__" and self.is_mutable():
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_keys_match_guard()
if kwargs or len(args) != 2:
@ -550,7 +561,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
if args[0] not in self:
self.install_dict_contains_guard(tx, args)
@ -565,7 +576,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
if args[0] not in self:
# missing item, return the default value. Install no DICT_CONTAINS guard.
@ -599,7 +610,7 @@ class ConstDictVariable(VariableTracker):
last = v.value
else:
raise_args_mismatch(tx, name)
k, v = self.items.popitem(last=last)
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
else:
k, v = self.items.popitem()
@ -632,17 +643,17 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard()
dict_vt = args[0]
dict_vt: ConstDictVariable = args[0]
else:
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
self.items.update(dict_vt.items)
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
self.items.update(dict_vt.items) # type: ignore[attr-defined]
if has_kwargs:
# Handle kwargs
kwargs = {
kwargs_hashable = {
Hashable(ConstantVariable.create(k)): v
for k, v in kwargs.items()
}
self.items.update(kwargs)
self.items.update(kwargs_hashable)
return ConstantVariable.create(None)
else:
return super().call_method(tx, name, args, kwargs)
@ -656,7 +667,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_contains_guard(tx, args)
contains = args[0] in self
@ -671,7 +682,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0])
raise_unhashable(args[0], tx)
self.install_dict_keys_match_guard()
if kwargs or len(args) > 2:
@ -707,7 +718,7 @@ class ConstDictVariable(VariableTracker):
and "last" in kwargs
and isinstance(kwargs["last"], ConstantVariable)
):
last = kwargs.get("last").value
last = kwargs.get("last").value # type: ignore[union-attr]
key = Hashable(args[0])
self.items.move_to_end(key, last=last)
@ -723,7 +734,7 @@ class ConstDictVariable(VariableTracker):
)
elif name == "__ne__":
return ConstantVariable.create(
not self.call_method(tx, "__eq__", args, kwargs).value
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
)
elif name == "__or__":
if len(args) != 1:
@ -750,14 +761,14 @@ class ConstDictVariable(VariableTracker):
if not istype(
other, (ConstDictVariable, variables.UserDefinedDictVariable)
):
msg = (
err_msg = (
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
f"and '{other.python_type().__name__}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
raise_observed_exception(TypeError, tx, args=[err_msg])
# OrderedDict overloads __ror__
ts = {self.user_cls, other.user_cls}
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
user_cls = (
collections.OrderedDict
if any(issubclass(t, collections.OrderedDict) for t in ts)
@ -774,8 +785,8 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard()
new_dict_vt.items.update(args[0].items)
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
return new_dict_vt
elif name == "__ior__":
self.call_method(tx, "update", args, kwargs)
@ -789,11 +800,13 @@ class ConstDictVariable(VariableTracker):
else:
return super().call_method(tx, name, args, kwargs)
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
self.install_dict_keys_match_guard()
return [x.vt for x in self.items.keys()]
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
# dict does not allow setting arbitrary attributes. OrderedDict and
# defaultdict allow arbitrary setattr, but not deletion of default attrs
if any(
@ -816,25 +829,25 @@ class ConstDictVariable(VariableTracker):
],
)
def clone(self, **kwargs):
def clone(self, **kwargs: Any) -> VariableTracker:
self.install_dict_keys_match_guard()
return super().clone(**kwargs)
class MappingProxyVariable(VariableTracker):
# proxies to the original dict_vt
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
super().__init__(**kwargs)
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
def python_type(self):
def python_type(self) -> type:
return types.MappingProxyType
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return self.dv_dict.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
# load types.MappingProxyType
if self.source:
msg = (
@ -863,11 +876,11 @@ class MappingProxyVariable(VariableTracker):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if self.source and tx.output.side_effects.has_existing_dict_mutation():
msg = (
"A dict has been modified while we have an existing mappingproxy object. "
@ -892,7 +905,7 @@ class MappingProxyVariable(VariableTracker):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is types.MappingProxyType:
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
return super().call_obj_hasattr(tx, name)
@ -900,33 +913,42 @@ class MappingProxyVariable(VariableTracker):
class NNModuleHooksDictVariable(ConstDictVariable):
# Special class to avoid adding any guards on the nn module hook ids.
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
pass
class DefaultDictVariable(ConstDictVariable):
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls: type,
default_factory: Optional[VariableTracker] = None,
**kwargs: Any,
) -> None:
super().__init__(items, user_cls, **kwargs)
assert user_cls is collections.defaultdict
self.default_factory = default_factory
def is_python_constant(self):
def is_python_constant(self) -> bool:
# Return false for unsupported defaults. This ensures that a bad handler
# path is not taken in BuiltinVariable for getitem.
# NOTE(review): the annotated __init__ declares default_factory as
# Optional[VariableTracker], while this membership test compares it against
# the builtin types themselves -- confirm the comparison can ever match a
# tracker (is_supported_arg inspects `arg.fn` instead).
if self.default_factory not in [list, tuple, dict] and not self.items:
return False
return super().is_python_constant()
# Debug string of the form defaultdict(<factory debug repr>, <parent debug repr>).
# Asserts that a default_factory is set before formatting it.
def debug_repr(self):
def debug_repr(self) -> str:
assert self.default_factory is not None
return (
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
)
@staticmethod
def is_supported_arg(arg):
def is_supported_arg(arg: VariableTracker) -> bool:
if isinstance(arg, variables.BuiltinVariable):
return arg.fn in (list, tuple, dict, set)
else:
@ -934,11 +956,11 @@ class DefaultDictVariable(ConstDictVariable):
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__getitem__":
if len(args) != 1:
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
@ -951,13 +973,13 @@ class DefaultDictVariable(ConstDictVariable):
else:
default_var = self.default_factory.call_function(tx, [], {})
super().call_method(
tx, "__setitem__", (args[0], default_var), kwargs
tx, "__setitem__", [args[0], default_var], kwargs
)
return default_var
else:
return super().call_method(tx, name, args, kwargs)
def reconstruct(self, codegen):
def reconstruct(self, codegen: "PyCodegen") -> None:
# emit `defaultdict(default_factory, new_dict)`
codegen.add_push_null(
lambda: codegen.extend_output(
@ -983,40 +1005,48 @@ class SetVariable(ConstDictVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
# pyrefly: ignore[bad-assignment]
items = dict.fromkeys(items, SetVariable._default_value())
# pyrefly: ignore[bad-argument-type]
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "set()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
# The keys of self.items, copied into a new set on every access.
@property
def set_items(self):
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
return set(self.items.keys())
@staticmethod
def _default_value():
def _default_value() -> VariableTracker:
# Placeholder value paired with every key of the backing dictionary
# (a constant None); __init__ passes it to dict.fromkeys.
return ConstantVariable.create(None)
def as_proxy(self):
def as_proxy(self) -> Any:
return {k.vt.as_proxy() for k in self.set_items}
def python_type(self):
def python_type(self) -> type:
return set
def as_python_constant(self):
def as_python_constant(self) -> Any:
return {k.vt.as_python_constant() for k in self.set_items}
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach([x.vt for x in self.set_items])
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
def _fast_set_method(self, tx, fn, args, kwargs):
def _fast_set_method(
self,
tx: "InstructionTranslator",
fn: Any,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
try:
res = fn(
*[x.as_python_constant() for x in [self, *args]],
@ -1026,15 +1056,16 @@ class SetVariable(ConstDictVariable):
raise_observed_exception(
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
)
# pyrefly: ignore[unbound-name]
return VariableTracker.build(tx, res)
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
# We forward the calls to the dictionary model
from ..utils import check_constant_args
@ -1054,10 +1085,10 @@ class SetVariable(ConstDictVariable):
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
if name == "__init__":
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
tx.output.side_effects.mutation(self)
self.items.clear()
self.items.update(temp_set_vt.items)
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
return ConstantVariable.create(None)
elif name == "add":
if kwargs or len(args) != 1:
@ -1068,7 +1099,7 @@ class SetVariable(ConstDictVariable):
f"{len(args)} args and {len(kwargs)} kwargs",
)
name = "__setitem__"
args = (args[0], SetVariable._default_value())
args = [args[0], SetVariable._default_value()]
elif name == "pop":
if kwargs or args:
raise_args_mismatch(
@ -1079,12 +1110,14 @@ class SetVariable(ConstDictVariable):
)
# Choose an item at random and pop it via the Dict.pop method
try:
result = self.set_items.pop().vt
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
except KeyError as e:
raise_observed_exception(
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
)
super().call_method(tx, name, (result,), kwargs)
# pyrefly: ignore[unbound-name]
super().call_method(tx, name, [result], kwargs)
# pyrefly: ignore[unbound-name]
return result
elif name == "isdisjoint":
if kwargs or len(args) != 1:
@ -1206,6 +1239,7 @@ class SetVariable(ConstDictVariable):
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
assert m is not None
return self.call_method(tx, m, args, kwargs)
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
@ -1219,29 +1253,34 @@ class SetVariable(ConstDictVariable):
"__ixor__": "symmetric_difference_update",
"__isub__": "difference_update",
}.get(name)
assert m is not None
self.call_method(tx, m, args, kwargs)
return self
elif name == "__eq__":
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(False)
r = self.call_method(tx, "symmetric_difference", args, kwargs)
return ConstantVariable.create(len(r.set_items) == 0)
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
elif name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
)
return super().call_method(tx, name, args, kwargs)
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
raise RuntimeError("Illegal to getitem on a set")
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
super().install_dict_contains_guard(tx, args)
@ -1249,27 +1288,27 @@ class FrozensetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "frozenset()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self):
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
# NOTE(review): this returns the keys view of self.items rather than a set
# copy (SetVariable.set_items wraps the keys in set()), yet the return
# annotation is set[...] -- confirm the annotation or the missing copy.
return self.items.keys()
def python_type(self):
def python_type(self) -> type:
return frozenset
def as_python_constant(self):
def as_python_constant(self) -> Any:
return frozenset({k.vt.as_python_constant() for k in self.set_items})
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach([x.vt for x in self.set_items])
codegen.add_push_null(
lambda: codegen.extend_output(
@ -1282,11 +1321,11 @@ class FrozensetVariable(SetVariable):
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
elif name == "__init__":
@ -1305,7 +1344,7 @@ class FrozensetVariable(SetVariable):
"symmetric_difference",
):
r = super().call_method(tx, name, args, kwargs)
return FrozensetVariable(r.items)
return FrozensetVariable(r.items) # type: ignore[attr-defined]
return super().call_method(tx, name, args, kwargs)
@ -1313,11 +1352,11 @@ class DictKeySetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
if not self.items:
return "dict_keys([])"
else:
@ -1327,33 +1366,35 @@ class DictKeySetVariable(SetVariable):
+ "])"
)
def install_dict_keys_match_guard(self):
def install_dict_keys_match_guard(self) -> None:
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(self, tx, args):
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
# Already EQUALS_MATCH guarded
pass
@property
def set_items(self):
def set_items(self) -> Any:
# Exposes self.items itself, without copying. Presumably this is the backing
# dict, so iterating it yields the key trackers -- verify against
# ConstDictVariable.
return self.items
def python_type(self):
def python_type(self) -> type:
return dict_keys
def as_python_constant(self):
def as_python_constant(self) -> Any:
# Converts every key tracker to its Python constant, then goes through
# dict.fromkeys(...).keys() so the result is an actual dict keys view.
return dict.fromkeys(
{k.vt.as_python_constant() for k in self.set_items}, None
).keys()
# Rejects the listed mutating set methods with RuntimeError; every other
# method name is forwarded unchanged to the parent call_method.
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> "VariableTracker":
) -> VariableTracker:
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
return super().call_method(tx, name, args, kwargs)
@ -1368,42 +1409,47 @@ class DictViewVariable(VariableTracker):
kv: Optional[str] = None
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
super().__init__(**kwargs)
assert self.kv in ("keys", "values", "items")
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
@property
def view_items(self):
def view_items(self) -> Any:
assert self.kv is not None
# Looks up the method named by self.kv ("keys", "values" or "items", per the
# assertion in __init__) on the wrapped dict's items and calls it.
return getattr(self.dv_dict.items, self.kv)()
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
# Implement in the subclasses
raise NotImplementedError
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return self.view_items_vt
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
assert self.kv is not None
codegen(self.dv_dict)
codegen.load_method(self.kv)
codegen.call_method(0)
# hasattr() on a dict view: a constant True or False depending on whether
# `name` is a key of the __dict__ of the type returned by python_type().
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
assert self.kv is not None
if name in self.python_type().__dict__:
return ConstantVariable.create(True)
return ConstantVariable.create(False)
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__len__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name == "__iter__":
@ -1417,24 +1463,24 @@ class DictKeysVariable(DictViewVariable):
kv = "keys"
@property
def set_items(self):
def set_items(self) -> set[VariableTracker]:
return set(self.view_items)
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
return [x.vt for x in self.view_items]
def python_type(self):
def python_type(self) -> type:
return dict_keys
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__contains__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name in (
@ -1449,13 +1495,13 @@ class DictKeysVariable(DictViewVariable):
):
# These methods always return a set
m = getattr(self.set_items, name)
r = m(args[0].set_items)
r = m(args[0].set_items) # type: ignore[attr-defined]
return SetVariable(r)
if name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
)
return super().call_method(tx, name, args, kwargs)
@ -1465,10 +1511,10 @@ class DictValuesVariable(DictViewVariable):
kv = "values"
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
return list(self.view_items)
def python_type(self):
def python_type(self) -> type:
return dict_values
@ -1476,14 +1522,20 @@ class DictItemsVariable(DictViewVariable):
kv = "items"
@property
def view_items_vt(self):
def view_items_vt(self) -> list[VariableTracker]:
# Returns an iterable of the unpacked items
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
def python_type(self):
def python_type(self) -> type:
return dict_items
def call_method(self, tx, name, args, kwargs):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
# TODO(guilhermeleobas): This should actually check if args[0]
# implements the mapping protocol.
if name == "__eq__":