mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 16:44:58 +08:00
Compare commits
1 Commits
cpp-docs-d
...
ciflow/ind
| Author | SHA1 | Date | |
|---|---|---|---|
| 935acff3ce |
@ -1,11 +1,15 @@
|
||||
sphinx==7.2.6
|
||||
sphinx==5.3.0
|
||||
#Description: This is used to generate PyTorch docs
|
||||
#Pinned versions: 7.2.6
|
||||
#Pinned versions: 5.3.0
|
||||
|
||||
pytorch_sphinx_theme2==0.2.0
|
||||
#Description: This is needed to generate PyTorch docs
|
||||
#Pinned versions: 0.2.0
|
||||
standard-imghdr==3.13.0; python_version >= "3.13"
|
||||
#Description: This is needed by Sphinx, so it needs to be added here.
|
||||
# The reasons are as follows:
|
||||
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
|
||||
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
|
||||
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
|
||||
|
||||
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
|
||||
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
|
||||
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
|
||||
# something related to Docker setup. We can investigate this later.
|
||||
@ -32,17 +36,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
|
||||
#Description: This is used to generate PyTorch docs
|
||||
#Pinned versions: 2.13.0
|
||||
|
||||
breathe==4.36.0
|
||||
breathe==4.34.0
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 4.36.0
|
||||
#Pinned versions: 4.34.0
|
||||
|
||||
exhale==0.3.7
|
||||
exhale==0.2.3
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 0.3.7
|
||||
#Pinned versions: 0.2.3
|
||||
|
||||
docutils==0.20
|
||||
docutils==0.16
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 0.20
|
||||
#Pinned versions: 0.16
|
||||
|
||||
bs4==0.0.1
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
@ -52,13 +56,13 @@ IPython==8.12.0
|
||||
#Description: This is used to generate PyTorch functorch docs
|
||||
#Pinned versions: 8.12.0
|
||||
|
||||
myst-nb==1.3.0
|
||||
myst-nb==0.17.2
|
||||
#Description: This is used to generate PyTorch functorch and torch.compile docs.
|
||||
#Pinned versions: 1.3.0
|
||||
#Pinned versions: 0.17.2
|
||||
|
||||
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
|
||||
python-etcd==0.4.5
|
||||
sphinx-copybutton==0.5.0
|
||||
sphinx-design==0.6.1
|
||||
sphinx-design==0.4.0
|
||||
sphinxcontrib-mermaid==1.0.0
|
||||
myst-parser==4.0.1
|
||||
myst-parser==0.18.1
|
||||
|
||||
@ -89,41 +89,23 @@ if [ "$is_main_doc" = true ]; then
|
||||
|
||||
make coverage
|
||||
# Now we have the coverage report, we need to make sure it is empty.
|
||||
# Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row
|
||||
# showing the undocumented count in the third column.
|
||||
# Example: | TOTAL | 99.83% | 2 |
|
||||
# Count the number of lines in the file and turn that number into a variable
|
||||
# $lines. The `cut -f1 ...` is to only parse the number, not the filename
|
||||
# Skip the report header by subtracting 2: the header will be output even if
|
||||
# there are no undocumented items.
|
||||
#
|
||||
# Also: see docs/source/conf.py for "coverage_ignore*" items, which should
|
||||
# be documented then removed from there.
|
||||
|
||||
# Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table
|
||||
# The table format is: | Module | Coverage | Undocumented |
|
||||
# Extract the third column (undocumented count) from the TOTAL row
|
||||
undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')
|
||||
|
||||
if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
|
||||
lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
|
||||
undocumented=$((lines - 2))
|
||||
if [ $undocumented -lt 0 ]; then
|
||||
echo coverage output not found
|
||||
exit 1
|
||||
elif [ "$undocumented" -gt 0 ]; then
|
||||
set +x # Disable command echoing for cleaner output
|
||||
echo ""
|
||||
echo "====================="
|
||||
echo "UNDOCUMENTED OBJECTS:"
|
||||
echo "====================="
|
||||
echo ""
|
||||
# Find the line number of the TOTAL row and print only what comes after it
|
||||
total_line=$(grep -n "| TOTAL" build/coverage/python.txt | cut -d: -f1)
|
||||
if [ -n "$total_line" ]; then
|
||||
# Print only the detailed list (skip the statistics table)
|
||||
tail -n +$((total_line + 2)) build/coverage/python.txt
|
||||
else
|
||||
# Fallback to showing entire file if TOTAL line not found
|
||||
cat build/coverage/python.txt
|
||||
fi
|
||||
echo ""
|
||||
elif [ $undocumented -gt 0 ]; then
|
||||
echo undocumented objects found:
|
||||
cat build/coverage/python.txt
|
||||
echo "Make sure you've updated relevant .rsts in docs/source!"
|
||||
echo "You can reproduce locally by running 'cd docs && make coverage && tail -n +\$((grep -n \"| TOTAL\" build/coverage/python.txt | cut -d: -f1) + 2)) build/coverage/python.txt'"
|
||||
set -x # Re-enable command echoing
|
||||
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
|
||||
20
SECURITY.md
20
SECURITY.md
@ -1,7 +1,7 @@
|
||||
# Security Policy
|
||||
|
||||
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
|
||||
- [**Using PyTorch Securely**](#using-pytorch-securely)
|
||||
- [**Using Pytorch Securely**](#using-pytorch-securely)
|
||||
- [Untrusted models](#untrusted-models)
|
||||
- [TorchScript models](#torchscript-models)
|
||||
- [Untrusted inputs](#untrusted-inputs)
|
||||
@ -10,28 +10,28 @@
|
||||
- [**CI/CD security principles**](#cicd-security-principles)
|
||||
## Reporting Security Issues
|
||||
|
||||
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
|
||||
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
|
||||
|
||||
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
|
||||
|
||||
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
|
||||
|
||||
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
|
||||
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
|
||||
|
||||
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
|
||||
|
||||
https://www.facebook.com/whitehat
|
||||
|
||||
|
||||
## Using PyTorch Securely
|
||||
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
|
||||
## Using Pytorch Securely
|
||||
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
|
||||
|
||||
### Untrusted models
|
||||
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
|
||||
|
||||
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
|
||||
|
||||
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
|
||||
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
|
||||
|
||||
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
|
||||
|
||||
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
|
||||
|
||||
### TorchScript models
|
||||
|
||||
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
|
||||
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
|
||||
|
||||
### Untrusted inputs during training and prediction
|
||||
|
||||
@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
|
||||
|
||||
### Data privacy
|
||||
|
||||
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
|
||||
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
|
||||
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
|
||||
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
|
||||
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
|
||||
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
|
||||
|
||||
### Using distributed features
|
||||
|
||||
|
||||
@ -206,43 +206,19 @@ templates_path = [
|
||||
os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"),
|
||||
]
|
||||
# TODO: document these and remove them from here.
|
||||
# Fixes the duplicated
|
||||
autosummary_filename_map = {
|
||||
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
|
||||
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
|
||||
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
|
||||
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
|
||||
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
|
||||
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
|
||||
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
|
||||
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
|
||||
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
|
||||
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
|
||||
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
|
||||
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
|
||||
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
|
||||
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
|
||||
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
|
||||
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
|
||||
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
|
||||
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
|
||||
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
|
||||
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
|
||||
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
|
||||
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
|
||||
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
|
||||
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
|
||||
"torch.mtia.stream": "torch.mtia.stream_function",
|
||||
"torch.mtia.Stream": "torch.mtia.Stream_class",
|
||||
"torch.cpu.stream": "torch.cpu.stream_function",
|
||||
"torch.cpu.Stream": "torch.cpu.Stream_class",
|
||||
"torch.cuda.stream": "torch.cuda.stream_function",
|
||||
"torch.cuda.Stream": "torch.cuda.Stream_class",
|
||||
"torch.xpu.stream": "torch.xpu.stream_function",
|
||||
"torch.xpu.Stream": "torch.xpu.Stream_class",
|
||||
}
|
||||
|
||||
coverage_ignore_functions = [
|
||||
# torch
|
||||
"typename",
|
||||
# torch.cuda._sanitizer
|
||||
"zip_arguments",
|
||||
"zip_by_key",
|
||||
# torch.distributed.autograd
|
||||
"is_available",
|
||||
# torch.distributed.checkpoint.state_dict
|
||||
"gc_context",
|
||||
# torch.distributed.elastic.events
|
||||
"record_rdzv_event",
|
||||
# torch.distributed.elastic.metrics
|
||||
"initialize_metrics",
|
||||
# torch.distributed.elastic.rendezvous.registry
|
||||
@ -3219,11 +3195,6 @@ autodoc_type_aliases = {
|
||||
# Enable overriding of function signatures in the first line of the docstring.
|
||||
autodoc_docstring_signature = True
|
||||
|
||||
# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
|
||||
autodoc_default_options = {
|
||||
"exclude-members": "from_bytes, to_bytes",
|
||||
}
|
||||
|
||||
# -- katex javascript in header
|
||||
#
|
||||
# def setup(app):
|
||||
|
||||
@ -253,6 +253,7 @@ regular full-precision tensor.
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
view
|
||||
as_strided
|
||||
|
||||
@ -3317,7 +3317,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg]
|
||||
assert isinstance(obj, SetVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "add", [v], {})
|
||||
obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
|
||||
|
||||
def SET_UPDATE(self, inst: Instruction) -> None:
|
||||
v = self.pop()
|
||||
@ -3326,7 +3326,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg]
|
||||
assert isinstance(obj, SetVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "update", [v], {})
|
||||
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
|
||||
|
||||
def LIST_APPEND(self, inst: Instruction) -> None:
|
||||
v = self.pop()
|
||||
@ -3634,7 +3634,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg].realize()
|
||||
assert isinstance(obj, ConstDictVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "update", [v], {})
|
||||
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
|
||||
|
||||
DICT_UPDATE = DICT_MERGE
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
"""
|
||||
Dictionary-related variable tracking classes for PyTorch Dynamo.
|
||||
|
||||
@ -26,7 +24,7 @@ import inspect
|
||||
import operator
|
||||
import types
|
||||
from collections.abc import Hashable as py_Hashable
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch._subclasses.fake_tensor import is_fake
|
||||
|
||||
@ -59,11 +57,13 @@ if TYPE_CHECKING:
|
||||
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
|
||||
|
||||
|
||||
def was_instancecheck_override(obj):
|
||||
def was_instancecheck_override(obj: Any) -> bool:
|
||||
return type(obj).__dict__.get("__instancecheck__", False)
|
||||
|
||||
|
||||
def raise_unhashable(arg, tx=None):
|
||||
def raise_unhashable(
|
||||
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
|
||||
) -> None:
|
||||
if tx is None:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
@ -75,7 +75,7 @@ def raise_unhashable(arg, tx=None):
|
||||
)
|
||||
|
||||
|
||||
def is_hashable(x):
|
||||
def is_hashable(x: VariableTracker) -> bool:
|
||||
# NB - performing isinstance check on a LazVT realizes the VT, accidentally
|
||||
# inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at
|
||||
# the underlying value without realizing the VT. Consider updating the
|
||||
@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
|
||||
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
|
||||
"""
|
||||
|
||||
def __init__(self, vt) -> None:
|
||||
def __init__(self, vt: VariableTracker) -> None:
|
||||
# We specialize SymNodes
|
||||
vt = specialize_symnode(vt)
|
||||
# TODO Temporarily remove to figure out what keys are we breaking on
|
||||
@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
|
||||
self.vt = vt
|
||||
|
||||
@property
|
||||
def underlying_value(self):
|
||||
def underlying_value(self) -> Any:
|
||||
if (
|
||||
isinstance(self.vt, variables.LazyVariableTracker)
|
||||
and not self.vt.is_realized()
|
||||
@ -178,7 +178,8 @@ class ConstDictVariable(VariableTracker):
|
||||
elif isinstance(self.vt, variables.FrozenDataClassVariable):
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
fields_values = {
|
||||
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
|
||||
k: Hashable(v).underlying_value
|
||||
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
|
||||
}
|
||||
return variables.FrozenDataClassVariable.HashWrapper(
|
||||
self.vt.python_type(), fields_values
|
||||
@ -187,16 +188,16 @@ class ConstDictVariable(VariableTracker):
|
||||
# The re module in Python 3.13+ has a dictionary (_cache2) with
|
||||
# an object as key (`class _ZeroSentinel(int): ...`):
|
||||
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
|
||||
return self.vt.value
|
||||
return self.vt.value # type: ignore[attr-defined,union-attr]
|
||||
else:
|
||||
x = self.vt.as_python_constant()
|
||||
return x
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.underlying_value)
|
||||
|
||||
@staticmethod
|
||||
def _eq_impl(a, b):
|
||||
def _eq_impl(a: Any, b: Any) -> bool:
|
||||
# TODO: Put this in utils and share it between variables/builtin.py and here
|
||||
type_a, type_b = type(a), type(b)
|
||||
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
|
||||
@ -212,7 +213,7 @@ class ConstDictVariable(VariableTracker):
|
||||
else:
|
||||
return a == b
|
||||
|
||||
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
|
||||
type(other)
|
||||
@ -226,8 +227,8 @@ class ConstDictVariable(VariableTracker):
|
||||
def __init__(
|
||||
self,
|
||||
items: dict[VariableTracker, VariableTracker],
|
||||
user_cls=dict,
|
||||
**kwargs,
|
||||
user_cls: type = dict,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# .clone() pass these arguments in kwargs but they're recreated a few
|
||||
# lines below
|
||||
@ -247,18 +248,22 @@ class ConstDictVariable(VariableTracker):
|
||||
for x, v in items.items()
|
||||
)
|
||||
|
||||
def make_hashable(key):
|
||||
def make_hashable(
|
||||
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
|
||||
) -> "ConstDictVariable._HashableTracker":
|
||||
return key if isinstance(key, Hashable) else Hashable(key)
|
||||
|
||||
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
|
||||
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
|
||||
# need to reconstruct everything if the dictionary is an intermediate value
|
||||
# or if a pop/delitem was executed
|
||||
self.should_reconstruct_all = not is_from_local_source(self.source)
|
||||
self.should_reconstruct_all = (
|
||||
not is_from_local_source(self.source) if self.source else True
|
||||
)
|
||||
self.original_items = items.copy()
|
||||
self.user_cls = user_cls
|
||||
|
||||
def _get_dict_cls_from_user_cls(self, user_cls):
|
||||
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
|
||||
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
|
||||
|
||||
# avoid executing user code if user_cls is a dict subclass
|
||||
@ -277,10 +282,10 @@ class ConstDictVariable(VariableTracker):
|
||||
dict_cls = dict
|
||||
return dict_cls
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> dict[Any, Any]:
|
||||
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
return (
|
||||
"{"
|
||||
+ ", ".join(
|
||||
@ -289,20 +294,20 @@ class ConstDictVariable(VariableTracker):
|
||||
+ "}"
|
||||
)
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> dict[Any, Any]:
|
||||
return {
|
||||
k.vt.as_python_constant(): v.as_python_constant()
|
||||
for k, v in self.items.items()
|
||||
}
|
||||
|
||||
def keys_as_python_constant(self):
|
||||
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
|
||||
self.install_dict_keys_match_guard()
|
||||
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return self.user_cls
|
||||
|
||||
def __contains__(self, vt) -> bool:
|
||||
def __contains__(self, vt: VariableTracker) -> bool:
|
||||
assert isinstance(vt, VariableTracker)
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
return (
|
||||
@ -322,13 +327,15 @@ class ConstDictVariable(VariableTracker):
|
||||
for key, value in self.items.items()
|
||||
)
|
||||
|
||||
def is_new_item(self, value, other):
|
||||
def is_new_item(
|
||||
self, value: Optional[VariableTracker], other: VariableTracker
|
||||
) -> bool:
|
||||
# compare the id of the realized values if both values are not lazy VTs
|
||||
if value and value.is_realized() and other.is_realized():
|
||||
return id(value.realize()) != id(other.realize())
|
||||
return id(value) != id(other)
|
||||
|
||||
def reconstruct_kvs_into_new_dict(self, codegen):
|
||||
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
|
||||
# Build a dictionary that contains the keys and values.
|
||||
num_args = 0
|
||||
for key, value in self.items.items():
|
||||
@ -340,7 +347,7 @@ class ConstDictVariable(VariableTracker):
|
||||
num_args += 1
|
||||
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
if self.user_cls is collections.OrderedDict:
|
||||
# emit `OrderedDict(constructed_dict)`
|
||||
codegen.add_push_null(
|
||||
@ -358,19 +365,21 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
def getitem_const_raise_exception_if_absent(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
):
|
||||
) -> VariableTracker:
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
raise_observed_exception(KeyError, tx)
|
||||
return self.items[key]
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
msg = f"Dictionary key {arg.value} not found during tracing"
|
||||
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
|
||||
unimplemented_v2(
|
||||
gb_type="key not found in dict",
|
||||
context=f"Key {arg.value}",
|
||||
context=f"Key {arg.value}", # type: ignore[attr-defined]
|
||||
explanation=msg,
|
||||
hints=[
|
||||
"Check if the key exists in the dictionary before accessing it.",
|
||||
@ -379,13 +388,13 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
return self.items[key]
|
||||
|
||||
def maybe_getitem_const(self, arg: VariableTracker):
|
||||
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
return None
|
||||
return self.items[key]
|
||||
|
||||
def realize_key_vt(self, arg: VariableTracker):
|
||||
def realize_key_vt(self, arg: VariableTracker) -> None:
|
||||
# Realize the LazyVT on a particular index
|
||||
assert arg in self
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
@ -394,11 +403,13 @@ class ConstDictVariable(VariableTracker):
|
||||
if isinstance(original_key_vt, variables.LazyVariableTracker):
|
||||
original_key_vt.realize()
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
if self.source:
|
||||
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
# Key guarding - These are the cases to consider
|
||||
# 1) The dict has been mutated. In this case, we would have already
|
||||
# inserted a DICT_KEYS_MATCH guard, so we can skip.
|
||||
@ -439,11 +450,11 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
|
||||
# we have to insert guards when a dict method is accessed. For this to
|
||||
# be simple, we are conservative and overguard. We skip guard only for
|
||||
@ -462,7 +473,7 @@ class ConstDictVariable(VariableTracker):
|
||||
tx, *args, **kwargs
|
||||
)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.update(temp_dict_vt.items)
|
||||
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "__getitem__":
|
||||
# Key guarding - Nothing to do. LazyVT for value will take care.
|
||||
@ -526,7 +537,7 @@ class ConstDictVariable(VariableTracker):
|
||||
return ConstantVariable.create(len(self.items))
|
||||
elif name == "__setitem__" and self.is_mutable():
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
if kwargs or len(args) != 2:
|
||||
@ -550,7 +561,7 @@ class ConstDictVariable(VariableTracker):
|
||||
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
if args[0] not in self:
|
||||
self.install_dict_contains_guard(tx, args)
|
||||
@ -565,7 +576,7 @@ class ConstDictVariable(VariableTracker):
|
||||
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
if args[0] not in self:
|
||||
# missing item, return the default value. Install no DICT_CONTAINS guard.
|
||||
@ -599,7 +610,7 @@ class ConstDictVariable(VariableTracker):
|
||||
last = v.value
|
||||
else:
|
||||
raise_args_mismatch(tx, name)
|
||||
k, v = self.items.popitem(last=last)
|
||||
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
|
||||
else:
|
||||
k, v = self.items.popitem()
|
||||
|
||||
@ -632,17 +643,17 @@ class ConstDictVariable(VariableTracker):
|
||||
# NB - Guard on all the keys of the other dict to ensure
|
||||
# correctness.
|
||||
args[0].install_dict_keys_match_guard()
|
||||
dict_vt = args[0]
|
||||
dict_vt: ConstDictVariable = args[0]
|
||||
else:
|
||||
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
|
||||
self.items.update(dict_vt.items)
|
||||
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
|
||||
self.items.update(dict_vt.items) # type: ignore[attr-defined]
|
||||
if has_kwargs:
|
||||
# Handle kwargs
|
||||
kwargs = {
|
||||
kwargs_hashable = {
|
||||
Hashable(ConstantVariable.create(k)): v
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
self.items.update(kwargs)
|
||||
self.items.update(kwargs_hashable)
|
||||
return ConstantVariable.create(None)
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
@ -656,7 +667,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_contains_guard(tx, args)
|
||||
contains = args[0] in self
|
||||
@ -671,7 +682,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
if kwargs or len(args) > 2:
|
||||
@ -707,7 +718,7 @@ class ConstDictVariable(VariableTracker):
|
||||
and "last" in kwargs
|
||||
and isinstance(kwargs["last"], ConstantVariable)
|
||||
):
|
||||
last = kwargs.get("last").value
|
||||
last = kwargs.get("last").value # type: ignore[union-attr]
|
||||
|
||||
key = Hashable(args[0])
|
||||
self.items.move_to_end(key, last=last)
|
||||
@ -723,7 +734,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
elif name == "__ne__":
|
||||
return ConstantVariable.create(
|
||||
not self.call_method(tx, "__eq__", args, kwargs).value
|
||||
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
|
||||
)
|
||||
elif name == "__or__":
|
||||
if len(args) != 1:
|
||||
@ -750,14 +761,14 @@ class ConstDictVariable(VariableTracker):
|
||||
if not istype(
|
||||
other, (ConstDictVariable, variables.UserDefinedDictVariable)
|
||||
):
|
||||
msg = (
|
||||
err_msg = (
|
||||
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
|
||||
f"and '{other.python_type().__name__}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
raise_observed_exception(TypeError, tx, args=[err_msg])
|
||||
|
||||
# OrderedDict overloads __ror__
|
||||
ts = {self.user_cls, other.user_cls}
|
||||
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
|
||||
user_cls = (
|
||||
collections.OrderedDict
|
||||
if any(issubclass(t, collections.OrderedDict) for t in ts)
|
||||
@ -774,8 +785,8 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
# NB - Guard on all the keys of the other dict to ensure
|
||||
# correctness.
|
||||
args[0].install_dict_keys_match_guard()
|
||||
new_dict_vt.items.update(args[0].items)
|
||||
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
|
||||
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
|
||||
return new_dict_vt
|
||||
elif name == "__ior__":
|
||||
self.call_method(tx, "update", args, kwargs)
|
||||
@ -789,11 +800,13 @@ class ConstDictVariable(VariableTracker):
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
self.install_dict_keys_match_guard()
|
||||
return [x.vt for x in self.items.keys()]
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
# dict not allow setting arbitrary attributes. OrderedDict and
|
||||
# defaultdict allow arbitrary setattr, but not deletion of default attrs
|
||||
if any(
|
||||
@ -816,25 +829,25 @@ class ConstDictVariable(VariableTracker):
|
||||
],
|
||||
)
|
||||
|
||||
def clone(self, **kwargs):
|
||||
def clone(self, **kwargs: Any) -> VariableTracker:
|
||||
self.install_dict_keys_match_guard()
|
||||
return super().clone(**kwargs)
|
||||
|
||||
|
||||
class MappingProxyVariable(VariableTracker):
|
||||
# proxies to the original dict_vt
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(dv_dict, ConstDictVariable)
|
||||
self.dv_dict = dv_dict
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return types.MappingProxyType
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
return self.dv_dict.unpack_var_sequence(tx)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# load types.MappingProxyType
|
||||
if self.source:
|
||||
msg = (
|
||||
@ -863,11 +876,11 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if self.source and tx.output.side_effects.has_existing_dict_mutation():
|
||||
msg = (
|
||||
"A dict has been modified while we have an existing mappingproxy object. "
|
||||
@ -892,7 +905,7 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if self.python_type() is types.MappingProxyType:
|
||||
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
@ -900,33 +913,42 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
class NNModuleHooksDictVariable(ConstDictVariable):
|
||||
# Special class to avoid adding any guards on the nn module hook ids.
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DefaultDictVariable(ConstDictVariable):
|
||||
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
items: dict[VariableTracker, VariableTracker],
|
||||
user_cls: type,
|
||||
default_factory: Optional[VariableTracker] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, user_cls, **kwargs)
|
||||
assert user_cls is collections.defaultdict
|
||||
self.default_factory = default_factory
|
||||
|
||||
def is_python_constant(self):
|
||||
def is_python_constant(self) -> bool:
|
||||
# Return false for unsupported defaults. This ensures that a bad handler
|
||||
# path is not taken in BuiltinVariable for getitem.
|
||||
if self.default_factory not in [list, tuple, dict] and not self.items:
|
||||
return False
|
||||
return super().is_python_constant()
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
assert self.default_factory is not None
|
||||
return (
|
||||
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_supported_arg(arg):
|
||||
def is_supported_arg(arg: VariableTracker) -> bool:
|
||||
if isinstance(arg, variables.BuiltinVariable):
|
||||
return arg.fn in (list, tuple, dict, set)
|
||||
else:
|
||||
@ -934,11 +956,11 @@ class DefaultDictVariable(ConstDictVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__getitem__":
|
||||
if len(args) != 1:
|
||||
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
|
||||
@ -951,13 +973,13 @@ class DefaultDictVariable(ConstDictVariable):
|
||||
else:
|
||||
default_var = self.default_factory.call_function(tx, [], {})
|
||||
super().call_method(
|
||||
tx, "__setitem__", (args[0], default_var), kwargs
|
||||
tx, "__setitem__", [args[0], default_var], kwargs
|
||||
)
|
||||
return default_var
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# emit `defaultdict(default_factory, new_dict)`
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
@ -983,40 +1005,48 @@ class SetVariable(ConstDictVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# pyrefly: ignore[bad-assignment]
|
||||
items = dict.fromkeys(items, SetVariable._default_value())
|
||||
# pyrefly: ignore[bad-argument-type]
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "set()"
|
||||
else:
|
||||
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
|
||||
return set(self.items.keys())
|
||||
|
||||
@staticmethod
|
||||
def _default_value():
|
||||
def _default_value() -> VariableTracker:
|
||||
# Variable to fill in he keys of the dictionary
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> Any:
|
||||
return {k.vt.as_proxy() for k in self.set_items}
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return set
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return {k.vt.as_python_constant() for k in self.set_items}
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
|
||||
|
||||
def _fast_set_method(self, tx, fn, args, kwargs):
|
||||
def _fast_set_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
fn: Any,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
try:
|
||||
res = fn(
|
||||
*[x.as_python_constant() for x in [self, *args]],
|
||||
@ -1026,15 +1056,16 @@ class SetVariable(ConstDictVariable):
|
||||
raise_observed_exception(
|
||||
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
|
||||
)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
return VariableTracker.build(tx, res)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
# We forward the calls to the dictionary model
|
||||
from ..utils import check_constant_args
|
||||
|
||||
@ -1054,10 +1085,10 @@ class SetVariable(ConstDictVariable):
|
||||
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
|
||||
|
||||
if name == "__init__":
|
||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
|
||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.clear()
|
||||
self.items.update(temp_set_vt.items)
|
||||
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "add":
|
||||
if kwargs or len(args) != 1:
|
||||
@ -1068,7 +1099,7 @@ class SetVariable(ConstDictVariable):
|
||||
f"{len(args)} args and {len(kwargs)} kwargs",
|
||||
)
|
||||
name = "__setitem__"
|
||||
args = (args[0], SetVariable._default_value())
|
||||
args = [args[0], SetVariable._default_value()]
|
||||
elif name == "pop":
|
||||
if kwargs or args:
|
||||
raise_args_mismatch(
|
||||
@ -1079,12 +1110,14 @@ class SetVariable(ConstDictVariable):
|
||||
)
|
||||
# Choose an item at random and pop it via the Dict.pop method
|
||||
try:
|
||||
result = self.set_items.pop().vt
|
||||
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
|
||||
except KeyError as e:
|
||||
raise_observed_exception(
|
||||
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
|
||||
)
|
||||
super().call_method(tx, name, (result,), kwargs)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
super().call_method(tx, name, [result], kwargs)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
return result
|
||||
elif name == "isdisjoint":
|
||||
if kwargs or len(args) != 1:
|
||||
@ -1206,6 +1239,7 @@ class SetVariable(ConstDictVariable):
|
||||
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
assert m is not None
|
||||
return self.call_method(tx, m, args, kwargs)
|
||||
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
@ -1219,29 +1253,34 @@ class SetVariable(ConstDictVariable):
|
||||
"__ixor__": "symmetric_difference_update",
|
||||
"__isub__": "difference_update",
|
||||
}.get(name)
|
||||
assert m is not None
|
||||
self.call_method(tx, m, args, kwargs)
|
||||
return self
|
||||
elif name == "__eq__":
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
return ConstantVariable.create(False)
|
||||
r = self.call_method(tx, "symmetric_difference", args, kwargs)
|
||||
return ConstantVariable.create(len(r.set_items) == 0)
|
||||
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
raise RuntimeError("Illegal to getitem on a set")
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
super().install_dict_contains_guard(tx, args)
|
||||
|
||||
|
||||
@ -1249,27 +1288,27 @@ class FrozensetVariable(SetVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "frozenset()"
|
||||
else:
|
||||
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
|
||||
return self.items.keys()
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return frozenset
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return frozenset({k.vt.as_python_constant() for k in self.set_items})
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
@ -1282,11 +1321,11 @@ class FrozensetVariable(SetVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
|
||||
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
|
||||
elif name == "__init__":
|
||||
@ -1305,7 +1344,7 @@ class FrozensetVariable(SetVariable):
|
||||
"symmetric_difference",
|
||||
):
|
||||
r = super().call_method(tx, name, args, kwargs)
|
||||
return FrozensetVariable(r.items)
|
||||
return FrozensetVariable(r.items) # type: ignore[attr-defined]
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
@ -1313,11 +1352,11 @@ class DictKeySetVariable(SetVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "dict_keys([])"
|
||||
else:
|
||||
@ -1327,33 +1366,35 @@ class DictKeySetVariable(SetVariable):
|
||||
+ "])"
|
||||
)
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> Any:
|
||||
return self.items
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_keys
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return dict.fromkeys(
|
||||
{k.vt.as_python_constant() for k in self.set_items}, None
|
||||
).keys()
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
|
||||
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
@ -1368,42 +1409,47 @@ class DictViewVariable(VariableTracker):
|
||||
|
||||
kv: Optional[str] = None
|
||||
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert self.kv in ("keys", "values", "items")
|
||||
assert isinstance(dv_dict, ConstDictVariable)
|
||||
self.dv_dict = dv_dict
|
||||
|
||||
@property
|
||||
def view_items(self):
|
||||
def view_items(self) -> Any:
|
||||
assert self.kv is not None
|
||||
return getattr(self.dv_dict.items, self.kv)()
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
# Implement in the subclasses
|
||||
raise NotImplementedError
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
return self.view_items_vt
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
assert self.kv is not None
|
||||
codegen(self.dv_dict)
|
||||
codegen.load_method(self.kv)
|
||||
codegen.call_method(0)
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
assert self.kv is not None
|
||||
if name in self.python_type().__dict__:
|
||||
return ConstantVariable.create(True)
|
||||
return ConstantVariable.create(False)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__len__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
elif name == "__iter__":
|
||||
@ -1417,24 +1463,24 @@ class DictKeysVariable(DictViewVariable):
|
||||
kv = "keys"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set[VariableTracker]:
|
||||
return set(self.view_items)
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
return [x.vt for x in self.view_items]
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_keys
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__contains__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
elif name in (
|
||||
@ -1449,13 +1495,13 @@ class DictKeysVariable(DictViewVariable):
|
||||
):
|
||||
# These methods always returns a set
|
||||
m = getattr(self.set_items, name)
|
||||
r = m(args[0].set_items)
|
||||
r = m(args[0].set_items) # type: ignore[attr-defined]
|
||||
return SetVariable(r)
|
||||
if name in cmp_name_to_op_mapping:
|
||||
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
@ -1465,10 +1511,10 @@ class DictValuesVariable(DictViewVariable):
|
||||
kv = "values"
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
return list(self.view_items)
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_values
|
||||
|
||||
|
||||
@ -1476,14 +1522,20 @@ class DictItemsVariable(DictViewVariable):
|
||||
kv = "items"
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_items
|
||||
|
||||
def call_method(self, tx, name, args, kwargs):
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
# TODO(guilhermeleobas): This should actually check if args[0]
|
||||
# implements the mapping protocol.
|
||||
if name == "__eq__":
|
||||
|
||||
Reference in New Issue
Block a user