Compare commits

..

8 Commits

Author SHA1 Message Date
61f59966c7 Update on "[Device Mesh] Add an option to decouple PGs when it comes to device mesh save"
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-12 16:26:43 -08:00
ec3befc028 [Device Mesh] Add an option to decouple PGs when it comes to device mesh save
[ghstack-poisoned]
2025-11-11 15:47:04 -08:00
306aa9c2a4 Update on "[Device Mesh][ez] Clean up unused parameters and duplicate code"
While refactoring the code, I found that we re-init `_flatten_mapping` and still keep `_flatten_mesh_list` in the code, neither of which is needed anymore. Let's remove them.

cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-11 14:49:07 -08:00
fc1469be71 [Device Mesh][ez] Clean up unused parameters and duplicate code
[ghstack-poisoned]
2025-11-11 14:16:34 -08:00
a5f3035aaf More pyrefly local errors (#166976)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166976
Approved by: https://github.com/maggiemoss, https://github.com/Skylion007
2025-11-04 18:51:35 +00:00
1d3f5e19da [cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922)
Fix and regression test for https://github.com/pytorch/pytorch/issues/165801

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165922
Approved by: https://github.com/malfet, https://github.com/atalman, https://github.com/Skylion007, https://github.com/drisspg

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Andrey Talman <atalman@fb.com>
2025-11-04 18:46:43 +00:00
496277a8ff [ROCm][CI] Lower runner check gpu count for distributed jobs (#166961)
This is a PR to temporarily relieve the queueing caused by an MI250 node outage. See this ticket for more information:
https://github.com/pytorch/pytorch/issues/166866

It relaxes the GPU count check to allow distributed jobs to run on 2-GPU runners.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166961
Approved by: https://github.com/jeffdaily
2025-11-04 18:44:21 +00:00
53f75cd5ba Fixed some syntax errors in SECURITY.md file. (#166718)
Fixed some errors in the SECURITY.md file, including PyTorch capitalization problems and grammatical inconsistencies.
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166718
Approved by: https://github.com/mikaylagawarecki
2025-11-04 18:18:38 +00:00
11 changed files with 152 additions and 29 deletions

View File

@@ -129,7 +129,7 @@ function install_129 {
 }
 
 function install_128 {
-    CUDNN_VERSION=9.8.0.87
+    CUDNN_VERSION=9.10.2.21
     echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
     # install CUDA 12.8.1 in the same container
     install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

View File

@@ -272,6 +272,18 @@ def smoke_test_cuda(
         torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
         print(f"Torch cuDNN version: {torch_cudnn_version}")
 
+        torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
+        print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
+        torch_cudnn_runtime_version = tuple(
+            [int(x) for x in torch_cudnn_version.split(".")]
+        )
+        if torch_cudnn_runtime_version != torch_cudnn_compile_version:
+            raise RuntimeError(
+                "cuDNN runtime version doesn't match compile version. "
+                f"Loaded: {torch_cudnn_runtime_version} "
+                f"Expected: {torch_cudnn_compile_version}"
+            )
+
         if sys.platform in ["linux", "linux2"]:
             torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
             print(f"Torch NCCL version: {torch_nccl_version}")

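For reference, a minimal standalone sketch of the same runtime-vs-compile-time check, assuming a CUDA build of torch, that `torch._C._cudnn.getCompileVersion()` returns a (major, minor, patch) tuple (as the comparison in the diff implies), and cuDNN 9's MAJOR*10000 + MINOR*100 + PATCH version encoding:

    import torch

    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        raw = torch.backends.cudnn.version()  # e.g. 91002 for cuDNN 9.10.2
        major, rest = divmod(raw, 10000)
        minor, patch = divmod(rest, 100)
        runtime = (major, minor, patch)
        compiled = tuple(torch._C._cudnn.getCompileVersion())
        if runtime != compiled:
            raise RuntimeError(f"cuDNN mismatch: loaded {runtime}, expected {compiled}")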
View File

@ -97,8 +97,8 @@ jobs:
shell: bash
run: |
ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
if [[ $ngpu -lt 4 ]]; then
echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus.
echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs"
exit 1
fi

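The same guard, sketched in Python for a test harness rather than the workflow shell step (assumes a ROCm or CUDA build of torch; `torch.cuda.device_count()` counts HIP devices on ROCm, whereas the workflow counts gfx agents via rocminfo):

    import sys

    import torch

    MIN_GPUS = 2  # temporarily lowered from 4 while MI250 capacity is reduced

    ngpu = torch.cuda.device_count()
    if ngpu < MIN_GPUS:
        print(
            f"Error: only {ngpu} GPU(s) detected, "
            f"at least {MIN_GPUS} GPUs are needed for distributed jobs"
        )
        sys.exit(1)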
View File

@@ -1,7 +1,7 @@
 # Security Policy
 
 - [**Reporting a Vulnerability**](#reporting-a-vulnerability)
-- [**Using Pytorch Securely**](#using-pytorch-securely)
+- [**Using PyTorch Securely**](#using-pytorch-securely)
   - [Untrusted models](#untrusted-models)
   - [TorchScript models](#torchscript-models)
   - [Untrusted inputs](#untrusted-inputs)
@@ -10,28 +10,28 @@
 - [**CI/CD security principles**](#cicd-security-principles)
 
 ## Reporting Security Issues
 
-Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
+Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
 
 However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
 
 Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
 
-All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
+All reports submitted through the security advisories mechanism will **either be made public or dismissed by the team within 90 days of submission**. If an advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create a [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml), as it is still likely a valid issue within the framework.
 
 Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
 
 https://www.facebook.com/whitehat
 
-## Using Pytorch Securely
-**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
+## Using PyTorch Securely
+**PyTorch models are programs**, so treat their security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the Python code for the model are distributed independently. That said, be careful about where you get the Python code from and who wrote it (preferably check for provenance or checksums; do not blindly run any pip-installed package).
 
 ### Untrusted models
 
 Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
 
 **Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions on [this page](https://developers.google.com/code-sandboxing).
 
-**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
+**Be mindful of risky model formats**. Prefer to share and load weights in the format appropriate for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger attack surface but is more flexible in what it can serialize. See the documentation for more details.
 
 Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out-of-bound reads/writes), and thus we recommend extensive validation of any untrusted inputs.
@@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
 
 ### TorchScript models
 
-TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
+TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models; therefore, they should be used only if you trust the provider of the binary you are about to load.
 
 ### Untrusted inputs during training and prediction
@@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
 
 ### Data privacy
 
-**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
-- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
-- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
+**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
+- Do not feed sensitive data to an untrusted model (even if it runs in a sandboxed environment).
+- If you are considering publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
 
 ### Using distributed features

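The "risky model formats" guidance above lends itself to a short illustration (a sketch; the file names are hypothetical, and the safetensors path assumes the optional `safetensors` package is installed):

    import torch

    # torch.load with weights_only=True restricts unpickling to tensors and
    # primitive containers, shrinking the attack surface for untrusted files.
    state_dict = torch.load("model.pt", weights_only=True)

    # Safetensors is the most restrictive option: raw tensors only, no code
    # execution on load.
    # from safetensors.torch import load_file
    # state_dict = load_file("model.safetensors")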
View File

@@ -10,9 +10,8 @@ from numpy.testing import assert_array_equal
 import torch
 import torch.nn.functional as F
 from torch.distributed._functional_collectives import AsyncCollectiveTensor
-from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.tensor import (
-    DeviceMesh,
     distribute_tensor,
     DTensor,
     Partial,
@@ -554,6 +553,31 @@ class DTensorTest(DTensorTestBase):
         reloaded_st = torch.load(buffer, weights_only=True)
         self.assertEqual(sharded_tensor, reloaded_st)
 
+    @with_comms
+    def test_dtensor_save_load_with_mesh_backend_decouple(self):
+        import io
+
+        # Turn on the gate so that torch.save does not record PG names for the device mesh.
+        DeviceMesh.decouple_backend_at_save = True
+        device_mesh = self.build_device_mesh()
+        placements = [Shard(0)]
+        local_tensor = torch.randn(3, 3)
+        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
+        buffer = io.BytesIO()
+        torch.save(sharded_tensor, buffer)
+        buffer.seek(0)
+        reloaded_st = torch.load(buffer, weights_only=False)
+        self.assertFalse(hasattr(reloaded_st._spec.mesh, "_dim_group_names"))
+        reloaded_st._spec.mesh = device_mesh
+        # We will change this to be not equal in a follow-up PR.
+        self.assertEqual(sharded_tensor, reloaded_st)
+        buffer.seek(0)
+        reloaded_st = torch.load(buffer, weights_only=True)
+        self.assertFalse(hasattr(reloaded_st._spec.mesh, "_dim_group_names"))
+        reloaded_st._spec.mesh = device_mesh
+        self.assertEqual(sharded_tensor, reloaded_st)
+        DeviceMesh.decouple_backend_at_save = False
+
     @skipIfHpu
     @with_comms
     @unittest.skipIf(
@@ -641,6 +665,7 @@ DTensorTestWithLocalTensor = create_local_tensor_test_class(
         # integration
         "test_dtensor_save_load",
         "test_dtensor_save_load_import",
+        "test_dtensor_save_load_with_mesh_backend_decouple",
     ],
 )

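The new test above corresponds roughly to the following end-user flow; a sketch assuming an initialized process group (e.g. two ranks launched via torchrun), with names mirroring the test:

    import io

    import torch
    from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
    from torch.distributed.tensor import DTensor, Shard

    mesh = init_device_mesh("cuda", (2,))  # requires torch.distributed to be initialized
    dt = DTensor.from_local(torch.randn(3, 3), mesh, [Shard(0)])

    DeviceMesh.decouple_backend_at_save = True  # do not embed PG names in the pickle
    buffer = io.BytesIO()
    torch.save(dt, buffer)

    buffer.seek(0)
    loaded = torch.load(buffer, weights_only=False)
    loaded._spec.mesh = mesh  # re-attach a mesh that owns live process groups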
View File

@@ -179,9 +179,6 @@ def aot_stage1_graph_capture(
         )
     )
 
-    print(f"in aot_stage1_graph_capture. maybe_subclass_meta.fw_metadata.static_input_indices:{maybe_subclass_meta.fw_metadata.static_input_indices if maybe_subclass_meta is not None and maybe_subclass_meta.fw_metadata is not None else None}")
-    print(f"in aot_stage1_graph_capture. aot_state.fw_metadata.static_input_indices:{aot_state.fw_metadata.static_input_indices}")
-
     return AOTGraphCapture(
         wrappers=wrappers,
         graph_module=graph,

View File

@@ -498,6 +498,7 @@ def generate_ttir(
         # pyrefly: ignore # missing-attribute
         codegen_fns = backend.get_codegen_implementation(*codegen_args)
         module_map = backend.get_module_map()
+        # pyrefly: ignore[missing-argument,bad-argument-type]
         ttir_module = src.make_ir(options, codegen_fns, module_map, context)
     else:
         codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []

View File

@@ -2318,7 +2318,6 @@ def compile_fx_forward(
         # force the outputs of invoke_subgraph subgraph to follow the
         # original strides
         _recursive_record_user_visible_output_idxs(gm)
 
-    print(f"in compile_fx_foward. static_input_idxs:{get_static_input_idxs(fixed)}")
     return inner_compile(
         gm,
         example_inputs,

View File

@@ -1228,7 +1228,7 @@ def _get_pynvml_handler(device: "Device" = None):
             "nvidia-ml-py does not seem to be installed or it can't be imported."
             # pyrefly: ignore [invalid-inheritance]
         ) from _PYNVML_ERR
-    # pyrefly: ignore [import-error]
+    # pyrefly: ignore [import-error,missing-module-attribute]
     from pynvml import NVMLError_DriverNotLoaded
 
     try:

View File

@@ -828,7 +828,7 @@ def list_gpu_processes(device: "Device" = None) -> str:
         import pynvml  # type: ignore[import]
     except ModuleNotFoundError:
         return "pynvml module not found, please install nvidia-ml-py"
-    # pyrefly: ignore [import-error]
+    # pyrefly: ignore [import-error,missing-module-attribute]
     from pynvml import NVMLError_DriverNotLoaded
 
     try:

View File

@@ -6,7 +6,7 @@ import threading
 import warnings
 from collections.abc import Iterator
 from itertools import zip_longest
-from typing import Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch.distributed import is_available
@@ -173,6 +173,9 @@ else:
             >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
         """
 
+        # Flag to make torch.save store a device mesh without backend (PG name) info. This is a
+        # temporary variable; we will remove it once we fully deprecate saving a device mesh with PG names.
+        decouple_backend_at_save = False
 
         _device_type: str
         _rank_map: torch.Tensor
         _mesh_dim_names: Optional[tuple[str, ...]]
@@ -255,14 +258,13 @@
             )
 
             # private field to pre-generate DeviceMesh's hash
-            self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
             self._flatten_rank_map = tuple(self._rank_map.tolist())
             self._thread_id = None
 
             # Initialize instance-specific flatten mapping
             self._flatten_mapping = {}
 
             # Skip process group initialization if xla device or init backend is False
             # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
-            self._thread_id = None
             if device_type != "xla":
@@ -293,11 +295,6 @@
                 rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
             )
-            # private field to pre-generate DeviceMesh's hash
-            self._flatten_rank_map = tuple(self._rank_map.tolist())
-
-            # Initialize instance-specific flatten mapping
-            self._flatten_mapping = {}
 
         @property
         def device_type(self) -> str:
             """Returns the device type of the mesh."""
@@ -1239,6 +1236,97 @@
             res_mesh._dim_group_names = concat_dim_group_name
             return res_mesh
 
+        def __getstate__(self):
+            """
+            Returns the state of the DeviceMesh as a dictionary for serialization,
+            which contains all necessary information to reconstruct the DeviceMesh.
+            """
+            state: dict[str, Any] = {
+                "device_type": self._device_type,
+                "rank_map": self._rank_map,
+                "layout": self._layout,
+                "mesh_dim_names": self._mesh_dim_names,
+                "thread_id": self._thread_id,
+                "coordinate_on_dim": getattr(self, "_coordinate_on_dim", None),
+            }
+
+            # Serialize root_mesh if it exists. To avoid infinite recursion
+            # (root -> child -> root), only serialize it if this is not the root.
+            if self._root_mesh is not None:
+                state["root_mesh"] = self._root_mesh.__getstate__()
+            else:
+                state["root_mesh"] = None
+
+            # Serialize flatten_mapping
+            flatten_mapping: dict[str, Any] = {}
+            for mesh_name, mesh in self._flatten_mapping.items():
+                flatten_mapping[mesh_name] = mesh.__getstate__()
+            state["flatten_mapping"] = flatten_mapping
+
+            if not self.decouple_backend_at_save and hasattr(self, "_dim_group_names"):
+                logger.warning(
+                    "Saving a device mesh with PG names via torch.save is deprecated "
+                    "and will be removed in PT 2.11. Users are welcome to use "
+                    "Distributed Checkpoint (DCP) or to re-create PGs in the same "
+                    "order as in the original device mesh."
+                )
+                state["dim_group_names"] = self._dim_group_names
+            return state
+
+        def __setstate__(self, state):
+            """
+            Restores the DeviceMesh state from a state dictionary.
+            """
+            required_keys = {
+                "device_type",
+                "rank_map",
+                "layout",
+                "mesh_dim_names",
+                "thread_id",
+                "coordinate_on_dim",
+                "root_mesh",
+                "flatten_mapping",
+            }
+            missing_keys = required_keys - state.keys()
+            if missing_keys:
+                raise ValueError(f"state_dict is missing required keys: {missing_keys}")
+
+            # Restore basic attributes
+            self._device_type = state["device_type"]
+            self._rank_map = state["rank_map"]
+            self._layout = state["layout"]
+            self._mesh_dim_names = state["mesh_dim_names"]
+            self._thread_id = state["thread_id"]
+            if state.get("coordinate_on_dim") is not None:
+                self._coordinate_on_dim = state["coordinate_on_dim"]
+
+            # Restore root_mesh if it exists
+            if state.get("root_mesh") is not None:
+                # Create a new DeviceMesh for the root mesh
+                root_mesh = DeviceMesh.__new__(DeviceMesh)
+                root_mesh.__setstate__(state["root_mesh"])
+                self._root_mesh = root_mesh
+            else:
+                self._root_mesh = None
+
+            # Re-initialize internal bookkeeping
+            self._flatten_rank_map = tuple(self._rank_map.tolist())
+
+            # Restore flatten_mapping
+            self._flatten_mapping = {}
+            if state.get("flatten_mapping"):
+                for mesh_name, mesh_state in state["flatten_mapping"].items():
+                    flatten_mesh = DeviceMesh.__new__(DeviceMesh)
+                    flatten_mesh.__setstate__(mesh_state)
+                    self._flatten_mapping[mesh_name] = flatten_mesh
+
+            # We don't recommend loading from saved PG names, because users would need
+            # to create process groups in exactly the same order as when the device
+            # mesh was saved. This is implicit and error-prone, and we will remove this
+            # behavior soon. Instead, we recommend explicitly creating PGs and setting
+            # them on the loaded mesh.
+            if state.get("dim_group_names"):
+                self._dim_group_names = state["dim_group_names"]
+
     def _normalize_backend_override(
         backend_override: dict[
             Union[int, str],
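A sketch of what the new serialization hooks produce, assuming a mesh created under an initialized process group (e.g. under torchrun with 4 ranks); the key set mirrors `__getstate__` above:

    from torch.distributed.device_mesh import init_device_mesh

    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    state = mesh.__getstate__()

    # Layout and rank information is always present; "dim_group_names" is added
    # only when decouple_backend_at_save is False (with a deprecation warning).
    print(sorted(state.keys()))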