Compare commits

...

9 Commits

Author SHA1 Message Date
74681719c5 Update on "[dynamo] unimplemented -> unimplemented_v2 for the rest of variables/misc.py"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-04 19:36:07 -08:00
fefe33b058 Update on "[dynamo] unimplemented -> unimplemented_v2 for the rest of variables/misc.py"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-04 15:42:05 -08:00
8726fe583e Update on "[dynamo] unimplemented -> unimplemented_v2 for the rest of variables/misc.py"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-04 12:41:38 -08:00
34db96eb24 [dynamo] unimplemented -> unimplemented_v2 for the rest of variables/misc.py
[ghstack-poisoned]
2025-11-04 12:38:14 -08:00
52ea135f77 [BE] Delete Python-3.9 stdlib definitions from torch.package (#166768)
And simplify the entire function to just assert and return

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166768
Approved by: https://github.com/cyyever, https://github.com/atalman
2025-11-04 19:33:14 +00:00
a5f3035aaf More pyrefly local errors (#166976)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166976
Approved by: https://github.com/maggiemoss, https://github.com/Skylion007
2025-11-04 18:51:35 +00:00
1d3f5e19da [cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922)
Fix and regression test for https://github.com/pytorch/pytorch/issues/165801

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165922
Approved by: https://github.com/malfet, https://github.com/atalman, https://github.com/Skylion007, https://github.com/drisspg

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Andrey Talman <atalman@fb.com>
2025-11-04 18:46:43 +00:00
496277a8ff [ROCm][CI] Lower runner check gpu count for distributed jobs (#166961)
This is a PR to temporarily relieve the queueing that is caused by an mi250 node outage. See this ticket for more information:
https://github.com/pytorch/pytorch/issues/166866

It relaxes the GPU count check to allow distributed jobs to run on 2-GPU runners

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166961
Approved by: https://github.com/jeffdaily
2025-11-04 18:44:21 +00:00
53f75cd5ba Fixed some syntax errors in SECURITY.md file. (#166718)
Fixed some syntax errors in SECURITY.md file including PyTorch's capitalization problems, some grammatical inconsistencies, etc
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166718
Approved by: https://github.com/mikaylagawarecki
2025-11-04 18:18:38 +00:00
11 changed files with 255 additions and 271 deletions

View File

@ -129,7 +129,7 @@ function install_129 {
}
function install_128 {
CUDNN_VERSION=9.8.0.87
CUDNN_VERSION=9.10.2.21
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
# install CUDA 12.8.1 in the same container
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

View File

@ -272,6 +272,18 @@ def smoke_test_cuda(
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
print(f"Torch cuDNN version: {torch_cudnn_version}")
torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
torch_cudnn_runtime_version = tuple(
[int(x) for x in torch_cudnn_version.split(".")]
)
if torch_cudnn_runtime_version != torch_cudnn_compile_version:
raise RuntimeError(
"cuDNN runtime version doesn't match comple version. "
f"Loaded: {torch_cudnn_runtime_version} "
f"Expected: {torch_cudnn_compile_version}"
)
if sys.platform in ["linux", "linux2"]:
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
print(f"Torch nccl; version: {torch_nccl_version}")

View File

@ -97,8 +97,8 @@ jobs:
shell: bash
run: |
ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
if [[ $ngpu -lt 4 ]]; then
echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus.
echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs"
exit 1
fi

View File

@ -1,7 +1,7 @@
# Security Policy
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using Pytorch Securely**](#using-pytorch-securely)
- [**Using PyTorch Securely**](#using-pytorch-securely)
- [Untrusted models](#untrusted-models)
- [TorchScript models](#torchscript-models)
- [Untrusted inputs](#untrusted-inputs)
@ -10,28 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat
## Using Pytorch Securely
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
## Using PyTorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
### TorchScript models
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
### Untrusted inputs during training and prediction
@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
### Data privacy
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
### Using distributed features

View File

@ -67,7 +67,7 @@ class IgnoreLogsTests(torch._dynamo.test_case.TestCase):
self.assertEqual(len(counters["graph_break"]), 0)
else:
self.assertIn("moo", printed_output)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertGreater(len(counters["graph_break"]), 0)
class ReorderLogsTests(torch._dynamo.test_case.TestCase):

View File

@ -2937,5 +2937,127 @@
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0288": [
{
"Gb_type": "unsupported method call on `typing` variable",
"Context": "typing variable: {self.value}, method name: {name}, args: {args}, kwargs: {kwargs}",
"Explanation": "`torch.compile` does not support method call `{name}` on `typing` variable f{self.value}.",
"Hints": [
"Avoid calling the {name} method on {self.value}.",
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0289": [
{
"Gb_type": "attempted to trace numpy.* function as a method",
"Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs}",
"Explanation": "Tracing numpy.* functions as methods is not supported.",
"Hints": [
"This graph break may be difficult to debug. Please report an issue to PyTorch for assistance."
]
}
],
"GB0290": [
{
"Gb_type": "logging.Logger method not supported for non-export cases",
"Context": "method: {self.value}.{name}, args: {args}, kwargs: {kwargs}",
"Explanation": "logging.Logger methods are not supported for non-export cases.",
"Hints": [
"Add the logging method to `torch._dynamo.config.ignore_logger_methods."
]
}
],
"GB0291": [
{
"Gb_type": "constant-like method call with unsupported return type",
"Context": "{self._error_prefix}.{name}(*{args}, **{kwargs}) returned {result}",
"Explanation": "Attempted to call {self._error_prefix}.{name}, got unsupported return value {result}.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0292": [
{
"Gb_type": "attempted to trace numpy function with config.trace_numpy=False",
"Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs}",
"Explanation": "Attempted to trace numpy function {self.value} while `torch._dynamo.config.trace_numpy` was set to False.",
"Hints": [
"Set `torch._dynamo.config.trace_numpy` to True to trace numpy functions."
]
}
],
"GB0293": [
{
"Gb_type": "attempted to trace numpy function unsupported by PyTorch",
"Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
"Explanation": "Can't find numpy numpy function {self.value} in torch._numpy.",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
],
"GB0294": [
{
"Gb_type": "cannot reconstruct NullVariable in Python < 3.11",
"Context": "",
"Explanation": "Attempted to generate PUSH_NULL instruction in Python < 3.11; where this instruction does not exist.",
"Hints": [
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
]
}
],
"GB0295": [
{
"Gb_type": "attempted to reorder a debugging function that can't actually be reordered",
"Context": "fn: {self.value}, args: {args}, kwargs: {kwargs}",
"Explanation": "`torch.compile` can only reorder functions where the arguments are Tensors, constants, or string formatters.",
"Hints": [
"Avoid calling the logging function {self.value} with args that are not supported."
]
}
],
"GB0296": [
{
"Gb_type": "random.Random() with improper arguments",
"Context": "args: {args}, kwargs: {kwargs}",
"Explanation": "random.Random() with > 1 arg or with kwargs is not supported.",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
],
"GB0297": [
{
"Gb_type": "attempted to trace torch._numpy.random function with config.use_numpy_random_stream=True",
"Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
"Explanation": "Attempted to trace {self.value} when `torch._dynamo.config.use_numpy_random_stream` is set to True.",
"Hints": [
"Set `torch._dynamo.config.use_numpy_random_stream` to False.",
"Avoid calling {self.value}."
]
}
],
"GB0298": [
{
"Gb_type": "constant-like method call with non-constant args",
"Context": "{self._error_prefix}.{name}(*{args}, **{kwargs})",
"Explanation": "Attempted to call {self._error_prefix}.{name} with non-constant args.",
"Hints": [
"Ensure that the args to the method call are constant (int, str, etc.)."
]
}
],
"GB0299": [
{
"Gb_type": "numpy function that produces a const collection type encountered non-const arguments",
"Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
"Explanation": "numpy function {self.value} that produces a const collection type (e.g. np.dtype, np.iinfo/np.finfo) received arguments that are not constant.",
"Hints": [
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
]
}
]
}

View File

@ -39,7 +39,7 @@ from ..bytecode_transformation import (
create_instruction,
)
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import raise_observed_exception, unimplemented, unimplemented_v2
from ..exc import raise_observed_exception, unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import unpatched_nn_module_init
from ..source import (
@ -1382,7 +1382,15 @@ class TypingVariable(VariableTracker):
if name == "__getitem__" and len(args) == 1:
new_typing = self.value[args[0].as_python_constant()]
return TypingVariable(new_typing)
unimplemented("unsupported method call on typing variable")
unimplemented_v2(
gb_type="unsupported method call on `typing` variable",
context=f"typing variable: {self.value}, method name: {name}, args: {args}, kwargs: {kwargs}",
explanation=f"`torch.compile` does not support method call `{name}` on `typing` variable f{self.value}.",
hints=[
f"Avoid calling the {name} method on {self.value}.",
*graph_break_hints.SUPPORTABLE,
],
)
def var_getattr(self, tx: "InstructionTranslator", name: str):
from .builder import SourcelessBuilder, VariableBuilder
@ -1493,16 +1501,28 @@ class NumpyVariable(VariableTracker):
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if not config.trace_numpy:
unimplemented(f"numpy.{self.value}()")
unimplemented_v2(
gb_type="attempted to trace numpy function with config.trace_numpy=False",
context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}",
explanation=f"Attempted to trace numpy function {self.value} "
"while `torch._dynamo.config.trace_numpy` was set to False.",
hints=[
"Set `torch._dynamo.config.trace_numpy` to True to trace numpy functions.",
],
)
from ..utils import numpy_to_tensor_wrapper
from .tensor import NumpyNdarrayVariable
func = get_np_to_tnp_map().get(self.value)
if func is None:
unimplemented(
f"Can't find numpy function {self.value} in torch._numpy. "
" Please file an issue to request support for this function."
unimplemented_v2(
gb_type="attempted to trace numpy function unsupported by PyTorch",
context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
explanation=f"Can't find numpy numpy function {self.value} in torch._numpy.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
# We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo)
@ -1516,20 +1536,32 @@ class NumpyVariable(VariableTracker):
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
)
except NotImplementedError:
unimplemented(
f"{self.value.__name__} with non-const args: {args} {kwargs}"
except AsPythonConstantNotImplementedError:
unimplemented_v2(
gb_type="numpy function that produces a const collection type encountered non-const arguments",
context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
explanation=f"numpy function {self.value} that produces a const collection type "
"(e.g. np.dtype, np.iinfo/np.finfo) "
"received arguments that are not constant.",
hints=[
*graph_break_hints.USER_ERROR,
],
)
else:
if (
func.__module__ == "torch._numpy.random"
and config.use_numpy_random_stream
):
msg = f"delegate '{func.__qualname__}' to NumPy itself via "
msg += (
f"config.use_numpy_random_stream={config.use_numpy_random_stream}"
unimplemented_v2(
gb_type="attempted to trace torch._numpy.random function with config.use_numpy_random_stream=True",
context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})",
explanation=f"Attempted to trace {self.value} when `torch._dynamo.config.use_numpy_random_stream` "
"is set to True.",
hints=[
"Set `torch._dynamo.config.use_numpy_random_stream` to False.",
f"Avoid calling {self.value}.",
],
)
unimplemented(msg)
args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs)
@ -1559,7 +1591,14 @@ class NumpyVariable(VariableTracker):
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
unimplemented("numpy")
unimplemented_v2(
gb_type="attempted to trace numpy.* function as a method",
context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}",
explanation="Tracing numpy.* functions as methods is not supported.",
hints=[
*graph_break_hints.DIFFICULT,
],
)
def as_python_constant(self):
return self.value
@ -1584,7 +1623,15 @@ class NullVariable(VariableTracker):
def reconstruct(self, codegen: "PyCodegen"):
if sys.version_info < (3, 11):
unimplemented("cannot reconstruct NullVariable in < Python 3.11")
unimplemented_v2(
gb_type="cannot reconstruct NullVariable in Python < 3.11",
context="",
explanation="Attempted to generate PUSH_NULL instruction in Python < 3.11; "
"where this instruction does not exist.",
hints=[
*graph_break_hints.DYNAMO_BUG,
],
)
codegen.append_output(create_instruction("PUSH_NULL"))
@ -1665,9 +1712,14 @@ class DebuggingVariable(VariableTracker):
return
if not self.can_reorder_logs(self.value, args, kwargs):
unimplemented(
f"Reordering debugging function {self.value} "
f"with inputs {args} {kwargs} is not yet implemented."
unimplemented_v2(
gb_type="attempted to reorder a debugging function that can't actually be reordered",
context=f"fn: {self.value}, args: {args}, kwargs: {kwargs}",
explanation="`torch.compile` can only reorder functions where the arguments "
"are Tensors, constants, or string formatters.",
hints=[
f"Avoid calling the logging function {self.value} with args that are not supported.",
],
)
tx.debug_locals.append((self, list(args)))
@ -1719,10 +1771,13 @@ class LoggingLoggerVariable(VariableTracker):
function = getattr(method, "__func__", None)
if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods):
return variables.ConstantVariable.create(None)
unimplemented(
"Logger not supported for non-export cases. "
"To avoid graph breaks caused by logger in compile-mode, it is recommended to"
" disable logging by adding logging methods to config.ignore_logger_methods"
unimplemented_v2(
gb_type="logging.Logger method not supported for non-export cases",
context=f"method: {self.value}.{name}, args: {args}, kwargs: {kwargs}",
explanation="logging.Logger methods are not supported for non-export cases.",
hints=[
"Add the logging method to `torch._dynamo.config.ignore_logger_methods.",
],
)
@ -1759,7 +1814,14 @@ class ConstantLikeVariable(VariableTracker):
cargs = [x.as_python_constant() for x in args]
ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
except NotImplementedError:
unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})")
unimplemented_v2(
gb_type="constant-like method call with non-constant args",
context=f"{self._error_prefix}.{name}(*{args}, **{kwargs})",
explanation=f"Attempted to call {self._error_prefix}.{name} with non-constant args.",
hints=[
"Ensure that the args to the method call are constant (int, str, etc.).",
],
)
result = getattr(self.value, name)(*cargs, **ckwargs)
@ -1768,7 +1830,14 @@ class ConstantLikeVariable(VariableTracker):
if isinstance(result, re.Match):
return ConstantRegexMatchVariable(result)
unimplemented(f"{self._error_prefix}.{name}() -> {result}")
unimplemented_v2(
gb_type="constant-like method call with unsupported return type",
context=f"{self._error_prefix}.{name}(*{args}, **{kwargs}) returned {result}",
explanation=f"Attempted to call {self._error_prefix}.{name}, got unsupported return value {result}.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
result = getattr(self.value, name)
@ -1831,10 +1900,15 @@ class RandomClassVariable(VariableTracker):
super().__init__(**kwargs)
def call_function(self, tx: "InstructionTranslator", args, kwargs):
if len(args) > 1:
unimplemented("random.Random() with > 1 arg")
elif kwargs:
unimplemented("random.Random() with kwargs")
if len(args) > 1 or kwargs:
unimplemented_v2(
gb_type="random.Random() with improper arguments",
context=f"args: {args}, kwargs: {kwargs}",
explanation="random.Random() with > 1 arg or with kwargs is not supported.",
hints=[
*graph_break_hints.USER_ERROR,
],
)
seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0]
return RandomVariable(
seed=seed, mutation_type=variables.base.ValueMutationNew()

View File

@ -498,6 +498,7 @@ def generate_ttir(
# pyrefly: ignore # missing-attribute
codegen_fns = backend.get_codegen_implementation(*codegen_args)
module_map = backend.get_module_map()
# pyrefly: ignore[missing-argument,bad-argument-type]
ttir_module = src.make_ir(options, codegen_fns, module_map, context)
else:
codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []

View File

@ -1228,7 +1228,7 @@ def _get_pynvml_handler(device: "Device" = None):
"nvidia-ml-py does not seem to be installed or it can't be imported."
# pyrefly: ignore [invalid-inheritance]
) from _PYNVML_ERR
# pyrefly: ignore [import-error]
# pyrefly: ignore [import-error,missing-module-attribute]
from pynvml import NVMLError_DriverNotLoaded
try:

View File

@ -828,7 +828,7 @@ def list_gpu_processes(device: "Device" = None) -> str:
import pynvml # type: ignore[import]
except ModuleNotFoundError:
return "pynvml module not found, please install nvidia-ml-py"
# pyrefly: ignore [import-error]
# pyrefly: ignore [import-error,missing-module-attribute]
from pynvml import NVMLError_DriverNotLoaded
try:

View File

@ -17,230 +17,5 @@ def is_stdlib_module(module: str) -> bool:
def _get_stdlib_modules():
if sys.version_info.major == 3: # noqa: UP036
if sys.version_info.minor == 9:
return stdlib3_9
if sys.version_info.minor >= 10: # noqa: YTT204
return sys.stdlib_module_names # type: ignore[attr-defined]
elif sys.version_info.major > 3: # noqa: UP036
return sys.stdlib_module_names # type: ignore[attr-defined]
raise RuntimeError(f"Unsupported Python version: {sys.version_info}")
stdlib3_9 = {
"_thread",
"abc",
"aifc",
"argparse",
"array",
"ast",
"asynchat",
"asyncio",
"asyncore",
"atexit",
"audioop",
"base64",
"bdb",
"binascii",
"binhex",
"bisect",
"builtins",
"bz2",
"cProfile",
"calendar",
"cgi",
"cgitb",
"chunk",
"cmath",
"cmd",
"code",
"codecs",
"codeop",
"collections",
"colorsys",
"compileall",
"concurrent",
"configparser",
"contextlib",
"contextvars",
"copy",
"copyreg",
"crypt",
"csv",
"ctypes",
"curses",
"dataclasses",
"datetime",
"dbm",
"decimal",
"difflib",
"dis",
"distutils",
"doctest",
"email",
"encodings",
"ensurepip",
"enum",
"errno",
"faulthandler",
"fcntl",
"filecmp",
"fileinput",
"fnmatch",
"formatter",
"fractions",
"ftplib",
"functools",
"gc",
"getopt",
"getpass",
"gettext",
"glob",
"graphlib",
"grp",
"gzip",
"hashlib",
"heapq",
"hmac",
"html",
"http",
"imaplib",
"imghdr",
"imp",
"importlib",
"inspect",
"io",
"ipaddress",
"itertools",
"json",
"keyword",
"lib2to3",
"linecache",
"locale",
"logging",
"lzma",
"mailbox",
"mailcap",
"marshal",
"math",
"mimetypes",
"mmap",
"modulefinder",
"msilib",
"msvcrt",
"multiprocessing",
"netrc",
"nis",
"nntplib",
"ntpath",
"numbers",
"operator",
"optparse",
"os",
"ossaudiodev",
"parser",
"pathlib",
"pdb",
"pickle",
"pickletools",
"pipes",
"pkgutil",
"platform",
"plistlib",
"poplib",
"posix",
"posixpath",
"pprint",
"profile",
"pstats",
"pty",
"pwd",
"py_compile",
"pyclbr",
"pydoc",
"queue",
"quopri",
"random",
"re",
"readline",
"reprlib",
"resource",
"rlcompleter",
"runpy",
"sched",
"secrets",
"select",
"selectors",
"shelve",
"shlex",
"shutil",
"signal",
"site",
"smtpd",
"smtplib",
"sndhdr",
"socket",
"socketserver",
"spwd",
"sqlite3",
"sre",
"sre_compile",
"sre_constants",
"sre_parse",
"ssl",
"stat",
"statistics",
"string",
"stringprep",
"struct",
"subprocess",
"sunau",
"symbol",
"symtable",
"sys",
"sysconfig",
"syslog",
"tabnanny",
"tarfile",
"telnetlib",
"tempfile",
"termios",
"test",
"textwrap",
"threading",
"time",
"timeit",
"tkinter",
"token",
"tokenize",
"trace",
"traceback",
"tracemalloc",
"tty",
"turtle",
"turtledemo",
"types",
"typing",
"unicodedata",
"unittest",
"urllib",
"uu",
"uuid",
"venv",
"warnings",
"wave",
"weakref",
"webbrowser",
"winreg",
"winsound",
"wsgiref",
"xdrlib",
"xml",
"xmlrpc",
"zipapp",
"zipfile",
"zipimport",
"zlib",
"zoneinfo",
}
assert sys.version_info >= (3, 10)
return sys.stdlib_module_names