Compare commits

..

6 Commits

Author SHA1 Message Date
3587038f9b Automated submodule update: kineto 2025-11-04 11:45:03 -08:00
52ea135f77 [BE] Delete Python-3.9 stdlib definitions from torch.package (#166768)
Also simplify `_get_stdlib_modules` so that it just asserts the Python version is at least 3.10 and returns `sys.stdlib_module_names`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166768
Approved by: https://github.com/cyyever, https://github.com/atalman
2025-11-04 19:33:14 +00:00
a5f3035aaf More pyrefly local errors (#166976)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166976
Approved by: https://github.com/maggiemoss, https://github.com/Skylion007
2025-11-04 18:51:35 +00:00
1d3f5e19da [cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922)
Fix and regression test for https://github.com/pytorch/pytorch/issues/165801

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165922
Approved by: https://github.com/malfet, https://github.com/atalman, https://github.com/Skylion007, https://github.com/drisspg

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Andrey Talman <atalman@fb.com>
2025-11-04 18:46:43 +00:00
496277a8ff [ROCm][CI] Lower runner check gpu count for distributed jobs (#166961)
This PR temporarily relieves the queueing caused by an MI250 node outage. See this issue for more information:
https://github.com/pytorch/pytorch/issues/166866

It relaxes the GPU count check to allow distributed jobs to run on 2-GPU runners

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166961
Approved by: https://github.com/jeffdaily
2025-11-04 18:44:21 +00:00
53f75cd5ba Fixed some syntax errors in SECURITY.md file. (#166718)
Fixed several errors in the SECURITY.md file, including inconsistent capitalization of "PyTorch", some grammatical inconsistencies, etc.
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166718
Approved by: https://github.com/mikaylagawarecki
2025-11-04 18:18:38 +00:00
11 changed files with 32 additions and 247 deletions

View File

@ -129,7 +129,7 @@ function install_129 {
} }
function install_128 { function install_128 {
CUDNN_VERSION=9.8.0.87 CUDNN_VERSION=9.10.2.21
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
# install CUDA 12.8.1 in the same container # install CUDA 12.8.1 in the same container
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

View File

@ -272,6 +272,18 @@ def smoke_test_cuda(
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version()) torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
print(f"Torch cuDNN version: {torch_cudnn_version}") print(f"Torch cuDNN version: {torch_cudnn_version}")
torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
torch_cudnn_runtime_version = tuple(
[int(x) for x in torch_cudnn_version.split(".")]
)
if torch_cudnn_runtime_version != torch_cudnn_compile_version:
raise RuntimeError(
"cuDNN runtime version doesn't match comple version. "
f"Loaded: {torch_cudnn_runtime_version} "
f"Expected: {torch_cudnn_compile_version}"
)
if sys.platform in ["linux", "linux2"]: if sys.platform in ["linux", "linux2"]:
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
print(f"Torch nccl; version: {torch_nccl_version}") print(f"Torch nccl; version: {torch_nccl_version}")

View File

@ -97,8 +97,8 @@ jobs:
shell: bash shell: bash
run: | run: |
ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
if [[ $ngpu -lt 4 ]]; then if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus.
echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs" echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs"
exit 1 exit 1
fi fi

View File

@ -1,7 +1,7 @@
# Security Policy # Security Policy
- [**Reporting a Vulnerability**](#reporting-a-vulnerability) - [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using Pytorch Securely**](#using-pytorch-securely) - [**Using PyTorch Securely**](#using-pytorch-securely)
- [Untrusted models](#untrusted-models) - [Untrusted models](#untrusted-models)
- [TorchScript models](#torchscript-models) - [TorchScript models](#torchscript-models)
- [Untrusted inputs](#untrusted-inputs) - [Untrusted inputs](#untrusted-inputs)
@ -10,28 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles) - [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues ## Reporting Security Issues
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch. Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem. However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework. All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If an advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create a [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported: Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat https://www.facebook.com/whitehat
## Using Pytorch Securely ## Using PyTorch Securely
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package). **PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
### Untrusted models ### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources]. Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing). **Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details. **Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs. Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
### TorchScript models ### TorchScript models
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load. TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
### Untrusted inputs during training and prediction ### Untrusted inputs during training and prediction
@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
### Data privacy ### Data privacy
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and: **Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment) - Do not feed sensitive data to an untrusted model (even if it runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits). - If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
### Using distributed features ### Using distributed features

View File

@ -179,9 +179,6 @@ def aot_stage1_graph_capture(
) )
) )
print(f"in aot_stage1_graph_capture. maybe_subclass_meta.fw_metadata.static_input_indices:{maybe_subclass_meta.fw_metadata.static_input_indices if maybe_subclass_meta is not None and maybe_subclass_meta.fw_metadata is not None else None}")
print(f"in aot_stage1_graph_capture. aot_state.fw_metadata.static_input_indices:{aot_state.fw_metadata.static_input_indices}")
return AOTGraphCapture( return AOTGraphCapture(
wrappers=wrappers, wrappers=wrappers,
graph_module=graph, graph_module=graph,

View File

@ -498,6 +498,7 @@ def generate_ttir(
# pyrefly: ignore # missing-attribute # pyrefly: ignore # missing-attribute
codegen_fns = backend.get_codegen_implementation(*codegen_args) codegen_fns = backend.get_codegen_implementation(*codegen_args)
module_map = backend.get_module_map() module_map = backend.get_module_map()
# pyrefly: ignore[missing-argument,bad-argument-type]
ttir_module = src.make_ir(options, codegen_fns, module_map, context) ttir_module = src.make_ir(options, codegen_fns, module_map, context)
else: else:
codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []

View File

@ -2318,7 +2318,7 @@ def compile_fx_forward(
# force the outputs of invoke_subgraph subgraph to follow the # force the outputs of invoke_subgraph subgraph to follow the
# original strides # original strides
_recursive_record_user_visible_output_idxs(gm) _recursive_record_user_visible_output_idxs(gm)
print(f"in compile_fx_foward. static_input_idxs:{get_static_input_idxs(fixed)}")
return inner_compile( return inner_compile(
gm, gm,
example_inputs, example_inputs,

View File

@ -1228,7 +1228,7 @@ def _get_pynvml_handler(device: "Device" = None):
"nvidia-ml-py does not seem to be installed or it can't be imported." "nvidia-ml-py does not seem to be installed or it can't be imported."
# pyrefly: ignore [invalid-inheritance] # pyrefly: ignore [invalid-inheritance]
) from _PYNVML_ERR ) from _PYNVML_ERR
# pyrefly: ignore [import-error] # pyrefly: ignore [import-error,missing-module-attribute]
from pynvml import NVMLError_DriverNotLoaded from pynvml import NVMLError_DriverNotLoaded
try: try:

View File

@ -828,7 +828,7 @@ def list_gpu_processes(device: "Device" = None) -> str:
import pynvml # type: ignore[import] import pynvml # type: ignore[import]
except ModuleNotFoundError: except ModuleNotFoundError:
return "pynvml module not found, please install nvidia-ml-py" return "pynvml module not found, please install nvidia-ml-py"
# pyrefly: ignore [import-error] # pyrefly: ignore [import-error,missing-module-attribute]
from pynvml import NVMLError_DriverNotLoaded from pynvml import NVMLError_DriverNotLoaded
try: try:

View File

@ -17,230 +17,5 @@ def is_stdlib_module(module: str) -> bool:
def _get_stdlib_modules(): def _get_stdlib_modules():
if sys.version_info.major == 3: # noqa: UP036 assert sys.version_info >= (3, 10)
if sys.version_info.minor == 9: return sys.stdlib_module_names
return stdlib3_9
if sys.version_info.minor >= 10: # noqa: YTT204
return sys.stdlib_module_names # type: ignore[attr-defined]
elif sys.version_info.major > 3: # noqa: UP036
return sys.stdlib_module_names # type: ignore[attr-defined]
raise RuntimeError(f"Unsupported Python version: {sys.version_info}")
stdlib3_9 = {
"_thread",
"abc",
"aifc",
"argparse",
"array",
"ast",
"asynchat",
"asyncio",
"asyncore",
"atexit",
"audioop",
"base64",
"bdb",
"binascii",
"binhex",
"bisect",
"builtins",
"bz2",
"cProfile",
"calendar",
"cgi",
"cgitb",
"chunk",
"cmath",
"cmd",
"code",
"codecs",
"codeop",
"collections",
"colorsys",
"compileall",
"concurrent",
"configparser",
"contextlib",
"contextvars",
"copy",
"copyreg",
"crypt",
"csv",
"ctypes",
"curses",
"dataclasses",
"datetime",
"dbm",
"decimal",
"difflib",
"dis",
"distutils",
"doctest",
"email",
"encodings",
"ensurepip",
"enum",
"errno",
"faulthandler",
"fcntl",
"filecmp",
"fileinput",
"fnmatch",
"formatter",
"fractions",
"ftplib",
"functools",
"gc",
"getopt",
"getpass",
"gettext",
"glob",
"graphlib",
"grp",
"gzip",
"hashlib",
"heapq",
"hmac",
"html",
"http",
"imaplib",
"imghdr",
"imp",
"importlib",
"inspect",
"io",
"ipaddress",
"itertools",
"json",
"keyword",
"lib2to3",
"linecache",
"locale",
"logging",
"lzma",
"mailbox",
"mailcap",
"marshal",
"math",
"mimetypes",
"mmap",
"modulefinder",
"msilib",
"msvcrt",
"multiprocessing",
"netrc",
"nis",
"nntplib",
"ntpath",
"numbers",
"operator",
"optparse",
"os",
"ossaudiodev",
"parser",
"pathlib",
"pdb",
"pickle",
"pickletools",
"pipes",
"pkgutil",
"platform",
"plistlib",
"poplib",
"posix",
"posixpath",
"pprint",
"profile",
"pstats",
"pty",
"pwd",
"py_compile",
"pyclbr",
"pydoc",
"queue",
"quopri",
"random",
"re",
"readline",
"reprlib",
"resource",
"rlcompleter",
"runpy",
"sched",
"secrets",
"select",
"selectors",
"shelve",
"shlex",
"shutil",
"signal",
"site",
"smtpd",
"smtplib",
"sndhdr",
"socket",
"socketserver",
"spwd",
"sqlite3",
"sre",
"sre_compile",
"sre_constants",
"sre_parse",
"ssl",
"stat",
"statistics",
"string",
"stringprep",
"struct",
"subprocess",
"sunau",
"symbol",
"symtable",
"sys",
"sysconfig",
"syslog",
"tabnanny",
"tarfile",
"telnetlib",
"tempfile",
"termios",
"test",
"textwrap",
"threading",
"time",
"timeit",
"tkinter",
"token",
"tokenize",
"trace",
"traceback",
"tracemalloc",
"tty",
"turtle",
"turtledemo",
"types",
"typing",
"unicodedata",
"unittest",
"urllib",
"uu",
"uuid",
"venv",
"warnings",
"wave",
"weakref",
"webbrowser",
"winreg",
"winsound",
"wsgiref",
"xdrlib",
"xml",
"xmlrpc",
"zipapp",
"zipfile",
"zipimport",
"zlib",
"zoneinfo",
}