Compare commits

...

15 Commits

Author SHA1 Message Date
ad7db3617e [inductor, 3.14] catch pickle.PicklingError exceptions (#167383)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167383
Approved by: https://github.com/aorenste
ghstack dependencies: #167382
2025-11-10 20:52:04 +00:00
5320ca3725 [inductor, 3.14] fix itertools.product pickle error in test_cpu_repro (#167382)
`inductor/test_cpu_cpp_wrapper` was failing because it attempted to pickle `itertools.product`, which is no longer picklable in 3.14. We work around this by eagerly generating a list.
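
A minimal sketch of the underlying issue (an assumed repro, not the PR's test code): itertools iterators such as `itertools.product` lost pickle support in Python 3.14, so anything that ends up pickled has to be materialized first.

```python
import itertools
import pickle

lazy = itertools.product([False, True], [1, 2])
try:
    # Raises on Python 3.14, where pickling of itertools objects was removed;
    # earlier versions pickled the iterator without complaint.
    pickle.dumps(lazy)
except (TypeError, pickle.PicklingError):
    pass

# The workaround applied here: eagerly generate a plain list of tuples, which always pickles.
eager = list(itertools.product([False, True], [1, 2]))
pickle.dumps(eager)
```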

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167382
Approved by: https://github.com/atalman, https://github.com/malfet
2025-11-10 20:52:04 +00:00
3e4faca130 [torch.export] Refactor placeholder_naming_pass to reduce CCN (#166600)
Summary: Reduced the CCN of the placeholder_naming_pass method from 37 to 28

Test Plan: Existing tests

Differential Revision: D85820388

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166600
Approved by: https://github.com/angelayi
2025-11-10 20:44:18 +00:00
0c2f206ded Typo fix - baddbmm_strategy (#166963)
This function is registered via a decorator rather than called directly. For clarity, add the "b" for "batch" to the function name.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166963
Approved by: https://github.com/janeyx99
2025-11-10 20:35:42 +00:00
6cf21fa331 Fix -ffunction-sections, -fdata-sections not being added on aarch64. (#166407)
Preferred solution to #166380

Changes:

- Moved summary print to bottom of CMakeLists.txt
- Fixed the issue that `add_compile_options` must be called before targets are defined; opted instead for `append_cxx_flag_if_supported` and the new `append_c_flag_if_supported`.
- Added extra verbosity so it is visible when the linker script is added.

(Unfortunately, the linker script has to be added per-target rather than globally due to ninja/cmake dependency tracking.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166407
Approved by: https://github.com/Aidyn-A, https://github.com/atalman
2025-11-10 20:32:08 +00:00
cdc8460f2c Use c7i.2xlarge for H100 build (#167466)
The build runner may be oversized for what is necessary, so reduce its size to optimize costs. The default workflow runner is linux.c7i.2xlarge, so we simply remove the runner definition from the workflow and let it use the default.

Relates to pytorch/test-infra#7175.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167466
Approved by: https://github.com/seemethere
2025-11-10 20:20:54 +00:00
86130aa2ca Fix flaky memory profiler test [2] (#167268)
Fixes #167037

Move the module definition outside of the unit test so that when the test runs multiple times, the module is not re-compiled.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167268
Approved by: https://github.com/angelayi
2025-11-10 19:51:38 +00:00
9491830c79 move subgraph_has_impure_ops from node.is_impure into const_fold to unblock production (#167443)
Summary:
https://github.com/pytorch/pytorch/pull/166609 updates `node.is_impure` to consider a submodule as impure if the submodule contains an impure node. This in turn changes the behavior of `graph.eliminate_dead_code()`, which does not eliminate nodes with side effects; see the [pytorch documentation](https://docs.pytorch.org/docs/stable/fx.html#torch.fx.Graph.eliminate_dead_code):
> Remove all dead code from the graph, based on each node’s number of users, and whether the nodes have any side effects.

While it is correct that a submodule containing side-effectful ops is itself side-effectful and should not be dead-code eliminated, some customers rely on dead code elimination removing submodules that contain impure ops, which was the behavior before the #166609 fix.

Due to production environment constraints, we have to revert https://github.com/pytorch/pytorch/pull/166609 and move the side-effectful submodule check logic to `const_fold.py`, which will correctly **not** const-fold a submodule that contains impure ops.

NOTE: other call sites that use `node.is_impure()` to make decisions are still incorrectly eliminating side-effectful submodules, but we can't safely change that today.

## This pr
- Move `_subgraph_has_impure_ops` into `fx/experimental/const_fold.py`; check for and prevent const-folding of an impure submodule.
- Add a note in `node.is_impure` to highlight the incorrect behavior and provide context in case people go looking in the future.

Test Plan: run test_fx_const_fold and all tests pass

Differential Revision: D86641994

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167443
Approved by: https://github.com/jfix71
2025-11-10 19:29:54 +00:00
04a85b4c21 [compile-on-one-rank] Step 1: DeviceId (#166680)
Add a "--virtual-local-rank" mode to torchrun. When used instead of passing the
local rank in LOCAL_RANK it uses a LOCAL_RANK of "0" and adjusts
CUDA_VISIBLE_DEVICES to reflect the desired GPU index.
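
A condensed sketch of the per-worker environment adjustment (simplified from the LocalElasticAgent change further down in this compare; the helper name is illustrative and error handling is trimmed):

```python
import os

def virtual_local_rank_env(local_rank: int) -> dict[str, str]:
    # Every worker is told it is local rank 0 ...
    env = {"LOCAL_RANK": "0"}
    # ... and CUDA_VISIBLE_DEVICES is narrowed so the assigned GPU shows up as device 0.
    parent = os.getenv("CUDA_VISIBLE_DEVICES")
    if parent is not None:
        # Map through any existing restriction: local rank i gets the i-th visible GPU.
        env["CUDA_VISIBLE_DEVICES"] = parent.split(",")[local_rank].strip()
    else:
        env["CUDA_VISIBLE_DEVICES"] = str(local_rank)
    return env

# With no prior restriction, local rank 3 gets {"LOCAL_RANK": "0", "CUDA_VISIBLE_DEVICES": "3"}.
```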

Testing:
(tweaked run_train.sh to use `--log-dir`)
```
export NGPU=8
export CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml"
with-proxy ./run_train.sh --model.name compiler_toolkit.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4
```

And then comparing ranks:

Without --virtual-local-rank, there are a lot of differences like:
```
 [rank#]:        mul_1: "f32[8, 512, 256]" = torch.ops.aten.mul.Tensor(mul, view_9);  mul = None
-[rank#]:        _to_copy_3: "bf16[8, 512, 256]" = torch.ops.aten._to_copy.default(mul_1, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0));  mul_1 = None
+[rank#]:        _to_copy_3: "bf16[8, 512, 256]" = torch.ops.aten._to_copy.default(mul_1, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=1));  mul_1 = None
 [rank#]:        detach: "f32[8, 512, 1]" = torch.ops.aten.detach.default(rsqrt);  rsqrt = None
```

With --virtual-local-rank, those differences go away.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166680
Approved by: https://github.com/ezyang
2025-11-10 18:47:31 +00:00
a4437d76f0 Add some labeler rules that used to be in the autolabel bot (#167330)
See https://github.com/pytorch/test-infra/pull/7446 for the paths

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167330
Approved by: https://github.com/huydhn
2025-11-10 18:38:42 +00:00
3ea829a337 Fix torch.cond HOP device in inductor (#167354)
Fixes #166918

The output may not be on the same device as the predicate.

```
python test/inductor/test_control_flow.py -k test_output_on_different_device
```
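
A condensed version of the failing pattern (adapted from the regression test added to test_control_flow.py in this compare): the predicate lives on CPU while both branches build their output on the GPU, so the cond output must take its device from the branch outputs rather than from the predicate.

```python
import torch

class FactoryBranches(torch.nn.Module):
    def forward(self, pred):
        # pred is a CPU boolean tensor; both branches create CUDA tensors.
        out = torch.cond(
            pred,
            lambda: torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).to("cuda"),
            lambda: torch.zeros(5, dtype=torch.float32).to("cuda"),
        )
        return out + 1

if torch.cuda.is_available():
    compiled = torch.compile(FactoryBranches())
    result = compiled(torch.tensor(True))  # predicate on CPU, result on cuda
```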

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167354
Approved by: https://github.com/ydwu4, https://github.com/zou3519
2025-11-10 18:19:38 +00:00
3966b5ad05 [BE] Fix out-of-bounds index_put in test_mps.py (#167444)
Discovered while enabling assertions on out-of-bounds accesses. Without the fix, the test fails with:
```
ERROR: test_sdpa_mask_fp16_L6_S17_NH23_HS121 (__main__.TestSDPA.test_sdpa_mask_fp16_L6_S17_NH23_HS121)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/malfet/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3334, in wrapper
    method(*args, **kwargs)
    ~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/malfet/git/pytorch/pytorch/build/../test/test_mps.py", line 9494, in test_sdpa_mask_fp16_L6_S17_NH23_HS121
    self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
    ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/malfet/git/pytorch/pytorch/build/../test/test_mps.py", line 9478, in _test_sdpa_mask
    y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False)
                                           ~~~~~^^
torch.AcceleratorError: index out of range

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167444
Approved by: https://github.com/Skylion007, https://github.com/manuelcandales
2025-11-10 18:19:28 +00:00
f6a79b2a4a [inductor] Wrap pallas_call in jax.jit (#167441)
My understanding is this is needed for performance.
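
A hedged sketch of the resulting pattern (it mirrors the codegen change further down rather than the exact emitted code): the `pl.pallas_call` is built inside a `jax.jit`-compiled wrapper, with the output shape and dtype passed as static arguments, so repeated calls reuse the compiled executable.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(in_ref, out_ref):
    # Element-wise body; refs are indexed directly, as in the emitted kernels.
    out_ref[...] = in_ref[...] + in_ref[...]

@functools.partial(jax.jit, static_argnums=(0, 1))
def add_jit_wrapper(out_shape, out_dtype, *refs):
    out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)
    return pl.pallas_call(
        add_kernel,
        out_shape=out_spec,
        interpret=True,  # interpret mode for CPU; the real codegen emits False on GPU
        grid=(1,),
    )(*refs)

x = jnp.ones((32,), dtype=jnp.float32)
y = add_jit_wrapper(tuple(x.shape), jnp.float32, x)
```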

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167441
Approved by: https://github.com/oulgen
2025-11-10 17:29:56 +00:00
2fcf41dd8e Add the ruff rule and skip everything for now (#167360)
Part of https://github.com/pytorch/pytorch/issues/164878
We can start narrowing the skips and removing them as PRs keep landing.

This PR just sets up the scaffolding; the fix will be in a follow-up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167360
Approved by: https://github.com/janeyx99
2025-11-10 17:10:15 +00:00
31ccd8f13e [AOTI] Fix a mixed-device bug for scatter_add (#167341)
Summary: Fix https://github.com/pytorch/pytorch/issues/166841. AOTI incorrectly generates a call to aoti_torch_cuda_scatter_reduce_two_out even though the op should actually run on CPU. Fixed by using the correct device when calling _generate_scatter_fallback in the wrapper codegen.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167341
Approved by: https://github.com/yushangdi
2025-11-10 16:59:44 +00:00
32 changed files with 602 additions and 235 deletions

.github/labeler.yml vendored
View File

@ -165,3 +165,16 @@
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm
"ciflow/mps":
- aten/src/ATen/mps/**
- aten/src/ATen/native/mps/**
- torch/_inductor/codegen/mps.py
- test/test_mps.py
- test/inductor/test_mps_basic.py
"ciflow/h100-symm-mem":
- torch/csrc/distributed/c10d/symm_mem/**
- torch/distributed/_symmetric_memory/**
- test/distributed/**/*mem*
- test/distributed/**/*mem*/**

View File

@ -37,7 +37,6 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: "linux.c7i.12xlarge"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '9.0'

View File

@ -41,7 +41,6 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '9.0'

View File

@ -736,6 +736,44 @@ if(NOT DEFINED USE_BLAS)
set(USE_BLAS ON)
endif()
# Prioritized Text Linker Optimization
if(USE_PRIORITIZED_TEXT_FOR_LD)
set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
execute_process(
COMMAND ${Python_EXECUTABLE}
${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py
--filein "${LINKER_SCRIPT_FILE_IN}"
--fout "${LINKER_SCRIPT_FILE_OUT}"
RESULT_VARIABLE _gen_result
OUTPUT_VARIABLE _gen_output
ERROR_VARIABLE _gen_error
)
if(NOT _gen_result EQUAL 0)
message(FATAL_ERROR
"Failed to generate linker script:\n${_gen_output}\n${_gen_error}")
endif()
append_cxx_flag_if_supported("-ffunction-sections" CMAKE_CXX_FLAGS)
append_cxx_flag_if_supported("-fdata-sections" CMAKE_CXX_FLAGS)
append_c_flag_if_supported("-ffunction-sections" CMAKE_C_FLAGS)
append_c_flag_if_supported("-fdata-sections" CMAKE_C_FLAGS)
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
else()
if(LINUX AND CPU_AARCH64)
message(WARNING [[
It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
]])
endif()
endif()
# Build libtorch mobile library, which contains ATen/TH ops and native support
# for TorchScript model, but doesn't contain not-yet-unified caffe2 ops;
if(INTERN_BUILD_MOBILE)
@ -1402,9 +1440,6 @@ if(BUILD_JNI)
add_subdirectory(android/pytorch_android)
endif()
include(cmake/Summary.cmake)
caffe2_print_configuration_summary()
# Parse custom debug info
if(DEFINED USE_CUSTOM_DEBINFO)
string(REPLACE ";" " " SOURCE_FILES "${USE_CUSTOM_DEBINFO}")
@ -1444,56 +1479,5 @@ if(BUILD_BUNDLE_PTXAS AND USE_CUDA)
DESTINATION "${CMAKE_INSTALL_BINDIR}")
endif()
if(USE_PRIORITIZED_TEXT_FOR_LD)
add_compile_options(
$<$<COMPILE_LANGUAGE:C,CXX>:-ffunction-sections>
$<$<COMPILE_LANGUAGE:C,CXX>:-fdata-sections>
)
set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
add_custom_command(
OUTPUT "${LINKER_SCRIPT_FILE_OUT}"
COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py --filein "${LINKER_SCRIPT_FILE_IN}" --fout "${LINKER_SCRIPT_FILE_OUT}"
DEPENDS ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py "${LINKER_SCRIPT_FILE_IN}"
COMMENT "Generating prioritized text linker files"
VERBATIM
)
add_custom_target(generate_linker_script DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
if(BUILD_PYTHON)
set(LINKER_OPT_TARGETS torch_python)
endif()
if(NOT BUILD_LIBTORCHLESS)
list(APPEND LINKER_OPT_TARGETS torch_cpu c10)
if(USE_CUDA)
list(APPEND LINKER_OPT_TARGETS torch_cuda c10_cuda)
endif()
if(USE_XPU)
list(APPEND LINKER_OPT_TARGETS torch_xpu c10_xpu)
endif()
if(USE_ROCM)
list(APPEND LINKER_OPT_TARGETS torch_hip c10_hip)
endif()
endif()
foreach(tgt IN LISTS LINKER_OPT_TARGETS)
if(TARGET ${tgt})
add_dependencies("${tgt}" generate_linker_script)
target_link_options_if_supported(${tgt} "-T,${LINKER_SCRIPT_FILE_OUT}")
set_property(TARGET ${tgt} APPEND PROPERTY LINK_DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
else()
message(WARNING "Requested target '${tgt}' for linker script optimization was not found.")
endif()
endforeach()
else()
if(LINUX AND CPU_AARCH64)
message(WARNING [[
It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
]])
endif()
endif()
include(cmake/Summary.cmake)
caffe2_print_configuration_summary()

View File

@ -478,6 +478,7 @@ function(torch_update_find_cuda_flags)
endfunction()
include(CheckCXXCompilerFlag)
include(CheckCCompilerFlag)
include(CheckLinkerFlag)
##############################################################################
@ -501,6 +502,24 @@ function(append_cxx_flag_if_supported flag outputvar)
endif()
endfunction()
function(append_c_flag_if_supported flag outputvar)
string(TOUPPER "HAS${flag}" _FLAG_NAME)
string(REGEX REPLACE "[=-]" "_" _FLAG_NAME "${_FLAG_NAME}")
# GCC silences unknown -Wno-XXX flags, so test the corresponding -WXXX.
if(CMAKE_C_COMPILER_ID STREQUAL "GNU")
string(REGEX REPLACE "^Wno-" "W" new_flag "${flag}")
else()
set(new_flag "${flag}")
endif()
check_c_compiler_flag("${new_flag}" ${_FLAG_NAME})
if(${_FLAG_NAME})
string(APPEND ${outputvar} " ${flag}")
set(${outputvar} "${${outputvar}}" PARENT_SCOPE)
endif()
endfunction()
function(target_compile_options_if_supported target flag)
set(_compile_options "")
append_cxx_flag_if_supported("${flag}" _compile_options)

View File

@ -260,6 +260,7 @@ select = [
"TRY401", # verbose-log-message
"UP",
"YTT",
"S101",
]
[tool.ruff.lint.pyupgrade]
@ -339,6 +340,39 @@ keep-runtime-typing = true
"tools/linter/**" = [
"LOG015" # please fix
]
"benchmarks/**" = [
"S101"
]
"test/**" = [
"S101"
]
"torchgen/**" = [
"S101"
]
"torch/**" = [
"S101"
]
"tools/**" = [
"S101"
]
"setup.py" = [
"S101"
]
"functorch/**" = [
"S101"
]
"docs/**" = [
"S101"
]
"android/**" = [
"S101"
]
".github/**" = [
"S101"
]
".ci/**" = [
"S101"
]
[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"

View File

@ -0,0 +1,44 @@
# Owner(s): ["oncall: r2p"]
# This is a helper script for
# test_run.py::ElasticLaunchTest::test_virtual_local_rank. It prints out the
# generated inductor output for a simple function.
import os
from unittest.mock import patch
import torch
import torch.distributed as dist
from torch._inductor import codecache
@torch.compile
def myfn(x: torch.Tensor) -> torch.Tensor:
return x + x
dist.init_process_group(backend="nccl")
local_rank = int(os.environ.get("LOCAL_RANK", "cuda:0"))
torch.cuda.set_device(local_rank)
def print_output_code(original_fn):
def wrapper(msg, *args, **kwargs):
# Check if this is the "Output code:" message
if args and "Output code:" in msg:
print(args[0])
return wrapper
x = torch.rand(2, 2, device="cuda")
with patch.object(
codecache.output_code_log,
"debug",
side_effect=print_output_code(codecache.output_code_log.debug),
):
y = myfn(x)
dist.destroy_process_group()

View File

@ -16,7 +16,7 @@ import sys
import tempfile
import uuid
from contextlib import closing, redirect_stderr, redirect_stdout
from unittest import mock
from unittest import mock, skipIf
from unittest.mock import MagicMock, Mock, patch
import torch.distributed.run as launch
@ -28,6 +28,7 @@ from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_CUDA,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
@ -677,6 +678,96 @@ class ElasticLaunchTest(TestCase):
for i in range(nproc_per_node):
self.assertTrue(f"[rank{i}]: creating " in captured_out.getvalue())
@skip_but_pass_in_sandcastle_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
@skipIf(not TEST_CUDA, "requires CUDA")
def test_virtual_local_rank(self):
"""
Test that virtual-local-rank ensures consistent device IDs across ranks.
Without it, ranks may compile to different devices, leading to different code.
"""
run_id = str(uuid.uuid4().int)
nnodes = 1
nproc_per_node = 2
# Helper function to run and capture output
def run_test(use_virtual_local_rank):
args = [
f"--nnodes={nnodes}",
f"--nproc-per-node={nproc_per_node}",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
"--redirect=3",
"--tee=3",
]
if use_virtual_local_rank:
args.append("--virtual-local-rank")
args.append(path("script_deviceid.py"))
captured_out = io.StringIO()
captured_err = io.StringIO()
with redirect_stdout(captured_out), redirect_stderr(captured_err):
launch.main(args)
return captured_out.getvalue()
def split_ranks(output):
default0 = []
default1 = []
for line in output.splitlines():
if "cuda:" not in line:
continue
if line.startswith("[default0]:"):
default0.append(line[11:])
elif line.startswith("[default1]:"):
default1.append(line[11:])
return default0, default1
# First, run WITHOUT virtual-local-rank - outputs should differ
output = run_test(use_virtual_local_rank=False)
rank0, rank1 = split_ranks(output)
# Verify we actually captured compiled code from both ranks
self.assertGreater(
len(rank0), 0, "Expected to capture compiled code from rank 0"
)
self.assertGreater(
len(rank1), 0, "Expected to capture compiled code from rank 1"
)
# Without virtual-local-rank, the ranks should have DIFFERENT compiled code
# because they see different device IDs (cuda:0 vs cuda:1)
self.assertNotEqual(
rank0,
rank1,
"Expected different compiled code without --virtual-local-rank",
)
# Now run WITH virtual-local-rank - outputs should be identical
output = run_test(use_virtual_local_rank=True)
rank0, rank1 = split_ranks(output)
# Verify we actually captured compiled code from both ranks
self.assertGreater(
len(rank0),
0,
"Expected to capture compiled code from rank 0 with --virtual-local-rank",
)
self.assertGreater(
len(rank1),
0,
"Expected to capture compiled code from rank 1 with --virtual-local-rank",
)
# With virtual-local-rank, both ranks should have IDENTICAL compiled code
# because they both see cuda:0 during compilation
self.assertEqual(
rank0, rank1, "Expected identical compiled code with --virtual-local-rank"
)
if __name__ == "__main__":
run_tests()

View File

@ -7522,6 +7522,38 @@ class AOTInductorTestsTemplate:
eager_outputs = model(*example_inputs)
torch.testing.assert_close(eager_outputs, compiled_outputs)
@requires_gpu
def test_mixed_device_1(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("Mixed-device test requires GPU")
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
# Buffers are on CPU
self.register_buffer(
"index", torch.tensor([1, 4, 1, 7], device="cpu", dtype=torch.int64)
)
self.register_buffer(
"src", torch.ones(4, device="cpu", dtype=torch.int64)
)
def forward(self, matrix, vector):
# Inputs are on CUDA
# 1. Operation on CPU tensors
z = torch.zeros((vector.shape[0],), device="cpu", dtype=torch.int64)
scatter_result = z.scatter_add(0, self.index, self.src)
# 2. Move result to CUDA and continue on CUDA
v = vector + scatter_result.to(vector.dtype).to(GPU_TYPE)
return torch.matmul(matrix, v)
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, device=self.device),
)
self.check_model(Model(), example_inputs, move_model_to_device=False)
class AOTInductorLoggingTest(LoggingTestCase):
@make_logging_test(dynamic=logging.DEBUG)

View File

@ -218,6 +218,7 @@ def check_model(
dynamic_shapes=None,
atol=None,
rtol=None,
move_model_to_device=True,
):
with (
torch.no_grad(),
@ -229,7 +230,7 @@ def check_model(
),
):
torch.manual_seed(0)
if not isinstance(model, types.FunctionType):
if not isinstance(model, types.FunctionType) and move_model_to_device:
model = model.to(self.device)
# For non mixed device inputs with default "cpu",set the device manually.

View File

@ -20,9 +20,11 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
from torch.testing._internal.triton_utils import requires_gpu
def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1):
def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1, device=None):
result = []
device = inputs[0].device
if len(inputs) != 0:
device = inputs[0].device
assert device
# iterate over the cartesian product of predicate values
for values in itertools.product(*([possible_values] * num_to_prepend)):
prepended = [torch.tensor(v, device=device) for v in values]
@ -30,8 +32,8 @@ def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1):
return result
def prepend_predicates(inputs, num_predicates=1):
return _prepend_product_of_values(inputs, [False, True], num_predicates)
def prepend_predicates(inputs, num_predicates=1, device=None):
return _prepend_product_of_values(inputs, [False, True], num_predicates, device)
def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)):
@ -308,7 +310,9 @@ class CondTests(TestCase):
torch._dynamo.mark_dynamic(inp, 0)
for inputs in input_sets:
for inputs_with_predicates in prepend_predicates(inputs, num_predicates):
for inputs_with_predicates in prepend_predicates(
inputs, num_predicates, device=device
):
cloned_inputs = [inp.clone() for inp in inputs_with_predicates]
result = model(*inputs_with_predicates)
result_compiled = compiled_model(*inputs_with_predicates)
@ -768,6 +772,26 @@ class CondTests(TestCase):
dynamic=dynamic,
)
@requires_gpu
def test_output_on_different_device(self):
class FactoryBranches(torch.nn.Module):
def forward(self, pred):
tensor = torch.cond(
pred,
lambda: torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32).to(
GPU_TYPE
),
lambda: torch.zeros(5, dtype=torch.float32).to(GPU_TYPE),
)
return tensor + 1
self._run_test(
model=FactoryBranches(),
inputs=(),
device="cpu", # device for predicate
dynamic=True,
)
class WhileLoopModels:
class Simple(torch.nn.Module):

View File

@ -726,8 +726,7 @@ class CPUReproTests(TestCase):
seq_len,
)
@parametrize(
"unbatched, input_size, hidden_size, num_layers, bidirectional, bias, empty_state, batch_first, batch_size, seq_len",
_test_lstm_packed_change_input_sizes_cpu_params = list(
itertools.product(
*[
[False],
@ -741,7 +740,12 @@ class CPUReproTests(TestCase):
[2],
[3],
]
),
)
)
@parametrize(
"unbatched, input_size, hidden_size, num_layers, bidirectional, bias, empty_state, batch_first, batch_size, seq_len",
_test_lstm_packed_change_input_sizes_cpu_params,
)
def test_lstm_packed_change_input_sizes_cpu(
self,

View File

@ -1,5 +1,6 @@
# Owner(s): ["oncall: pt2"]
import functools
import re
import sys
import unittest
@ -230,6 +231,33 @@ class PallasTestsMixin:
self.assertIn("import jax.numpy as jnp", code)
self.assertIn("from jax.experimental import pallas as pl", code)
def test_jax_jit_wrapper_is_emitted(self):
"""Ensure generated Pallas code wraps pl.pallas_call in jax.jit."""
key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"
@torch.compile(backend="inductor", options={key: "pallas"})
def pallas_fn(a, b):
return a + b
_, (code,) = run_and_get_code(
pallas_fn,
torch.randn(32, device=self.DEVICE),
torch.randn(32, device=self.DEVICE),
)
kernel_match = re.search(r"def (pallas_[A-Za-z0-9_]+)_kernel", code)
self.assertIsNotNone(kernel_match)
kernel_name = kernel_match.group(1)
wrapper_name = f"{kernel_name}_jit_wrapper"
self.assertIn(wrapper_name, code)
start = code.index(f"def {wrapper_name}")
end = code.index(f"def {kernel_name}_main", start)
wrapper_block = code[start:end]
self.assertIn("jax.jit", code)
self.assertNotIn("torch.", wrapper_block)
def test_2d_tensor(self):
"""Test with 2D tensors (though current implementation flattens)."""

View File

@ -7455,6 +7455,34 @@ class TestCudaDeviceParametrized(TestCase):
class TestFXMemoryProfiler(TestCase):
"""Tests for memory profiler augmentation with original stack traces."""
class MLPModule(nn.Module):
def __init__(self, device):
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(10, 16, bias=True, device=device)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 10, bias=True, device=device)
def forward(self, x):
a = self.net1(x)
b = self.relu(a)
c = self.net2(b)
return c
class MLPModule2(nn.Module):
def __init__(self, device):
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(10, 16, bias=True, device=device)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 10, bias=True, device=device)
def forward(self, x):
d = self.net1(x)
e = self.relu(d)
f = self.net2(e)
return f
def collect_frames(
self, augmented_snapshot, collect_device_traces=True, collect_segments=True
):
@ -7490,99 +7518,64 @@ class TestFXMemoryProfiler(TestCase):
def test_fx_memory_profiler_augmentation(self):
"""Test that memory snapshots are augmented with FX debug information."""
# Create a simple model
class MLPModule(nn.Module):
def __init__(self, device):
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(10, 16, bias=True, device=device)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 10, bias=True, device=device)
def forward(self, x):
a = self.net1(x)
b = self.relu(a)
c = self.net2(b)
return c
device = "cuda"
mod = MLPModule(device)
with tempfile.TemporaryDirectory() as tmpdir:
# reset cache to start fresh
torch.cuda.memory.empty_cache()
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
augmented_snapshot = torch.cuda.memory._snapshot(
augment_with_fx_traces=True
)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
torch.cuda.empty_cache()
mod = self.MLPModule(device)
# reset cache to start fresh
torch.cuda.memory.empty_cache()
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
augmented_snapshot = torch.cuda.memory._snapshot(augment_with_fx_traces=True)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
torch.cuda.empty_cache()
fx_frames = self.collect_frames(augmented_snapshot)
self.assertGreater(len(fx_frames), 2)
fx_frames = self.collect_frames(augmented_snapshot)
self.assertGreater(len(fx_frames), 2)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name
self.assertIn("fx_node_op", frame)
self.assertIn("fx_node_name", frame)
self.assertIn("fx_node_target", frame)
self.assertIn("fx_original_trace", frame)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name
self.assertIn("fx_node_op", frame)
self.assertIn("fx_node_name", frame)
self.assertIn("fx_node_target", frame)
self.assertIn("fx_original_trace", frame)
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
fx_node_name = frame["fx_node_name"]
if fx_node_name == "addmm":
self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
elif fx_node_name == "addmm_1":
self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
elif fx_node_name == "relu":
self.assertIn("b = self.relu(a)", frame["fx_original_trace"])
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
fx_node_name = frame["fx_node_name"]
if fx_node_name == "addmm":
self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
elif fx_node_name == "addmm_1":
self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
elif fx_node_name == "relu":
self.assertIn("b = self.relu(a)", frame["fx_original_trace"])
# Test that when we have two graphs with the same src_code, they're not hashed
# to the same metadata
class MLPModule2(nn.Module):
def __init__(self, device):
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(10, 16, bias=True, device=device)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 10, bias=True, device=device)
mod = self.MLPModule2(device)
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
augmented_snapshot = torch.cuda.memory._snapshot(augment_with_fx_traces=True)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
def forward(self, x):
d = self.net1(x)
e = self.relu(d)
f = self.net2(e)
return f
# avoid collecting segments from previous run for unit test purpose
fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
self.assertGreater(len(fx_frames), 0)
mod = MLPModule2(device)
with tempfile.TemporaryDirectory() as tmpdir:
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
augmented_snapshot = torch.cuda.memory._snapshot(
augment_with_fx_traces=True
)
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name
self.assertIn("fx_node_op", frame)
self.assertIn("fx_node_name", frame)
self.assertIn("fx_node_target", frame)
self.assertIn("fx_original_trace", frame)
# avoid collecting segments from previous run for unit test purpose
fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
self.assertGreater(len(fx_frames), 0)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name
self.assertIn("fx_node_op", frame)
self.assertIn("fx_node_name", frame)
self.assertIn("fx_node_target", frame)
self.assertIn("fx_original_trace", frame)
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
fx_node_name = frame["fx_node_name"]
if fx_node_name == "addmm":
self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
elif fx_node_name == "addmm_1":
self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
elif fx_node_name == "relu":
self.assertIn("e = self.relu(d)", frame["fx_original_trace"])
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
fx_node_name = frame["fx_node_name"]
if fx_node_name == "addmm":
self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
elif fx_node_name == "addmm_1":
self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
elif fx_node_name == "relu":
self.assertIn("e = self.relu(d)", frame["fx_original_trace"])
instantiate_parametrized_tests(TestCuda)

View File

@ -9465,7 +9465,7 @@ class TestSDPA(TestCaseMPS):
torch.manual_seed(1729)
causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps'))
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
i = 42
i = 42 if S > 42 else S // 2
q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")

View File

@ -4506,7 +4506,7 @@ class TestSerialization(TestCase, SerializationMixin):
exc = pickle.PicklingError if sys.version_info >= (3, 14) else AttributeError
with self.assertRaisesRegex(
exc,
"Can't (get|pickle) local object (<function |')WeakValueDictionary.__init__.<locals>.remove"
r"Can't (get|pickle) local object (<function |')WeakValueDictionary\.__init__\.<locals>\.remove"
):
with skip_data(), BytesIOContext() as f:
torch.save(ft, f)

View File

@ -974,6 +974,41 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
subgraph.recompile()
def _assign_new_node_names(
gm: torch.fx.GraphModule,
name_map: dict[str, str],
custom_meta: dict[str, Any],
) -> None:
"""
Assign new names to all nodes, in the graph module, from name map.
"""
for node in gm.graph.nodes:
if node.op == "placeholder":
assert node.name in name_map
node.name = node.target = name_map[node.name]
if node.name in custom_meta:
if node.meta.get("custom") is None:
node.meta["custom"] = {}
else:
# Assert if any existing key has different value
for k, v in node.meta["custom"].items():
if (
k in custom_meta[node.name]
and v != custom_meta[node.name][k]
):
raise AssertionError(
f"Mismatch in custom metadata for key {k}. Value in "
f"node.meta is {v} and value in custom_meta is {custom_meta[node.name][k]}."
)
node.meta["custom"].update(custom_meta[node.name])
# if the constant obj is an input, we also need to update meta["val"]
# because this is created before the placeholder naming pass
if isinstance(node.meta["val"], CustomObjArgument):
node.meta["val"].name = node.name
elif node.name in name_map:
node.name = name_map[node.name]
def placeholder_naming_pass(
gm: torch.fx.GraphModule,
export_graph_signature: "ExportGraphSignature",
@ -1091,31 +1126,7 @@ def placeholder_naming_pass(
)
# assign new node names
for node in gm.graph.nodes:
if node.op == "placeholder":
assert node.name in name_map
node.name = node.target = name_map[node.name]
if node.name in custom_meta:
if node.meta.get("custom") is None:
node.meta["custom"] = {}
else:
# Assert if any existing key has different value
for k, v in node.meta["custom"].items():
if (
k in custom_meta[node.name]
and v != custom_meta[node.name][k]
):
raise AssertionError(
f"Mismatch in custom metadata for key {k}. Value in "
f"node.meta is {v} and value in custom_meta is {custom_meta[node.name][k]}."
)
node.meta["custom"].update(custom_meta[node.name])
# if the constant obj is an input, we also need to update meta["val"]
# because this is created before the placeholder naming pass
if isinstance(node.meta["val"], CustomObjArgument):
node.meta["val"].name = node.name
elif node.name in name_map:
node.name = name_map[node.name]
_assign_new_node_names(gm, name_map, custom_meta)
# propagate names to higher order op subgraphs
_name_hoo_subgraph_placeholders(gm)

View File

@ -624,7 +624,7 @@ class FxGraphCachePickler(pickle.Pickler):
try:
self.dump(obj)
return self._stream.getvalue()
except (TypeError, AttributeError) as e:
except (TypeError, AttributeError, pickle.PicklingError) as e:
# Some configs options may not pickle.
log.warning("Failed to pickle cache key", exc_info=True)
raise BypassFxGraphCache("Failed to pickle cache key") from e

View File

@ -221,7 +221,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
"""
)
self.add_device_include(self.device)
for device in V.graph.device_types:
if device != "meta":
self.add_device_include(device)
if V.graph.aot_mode:
if config.aot_inductor.dynamic_linkage:
@ -1423,11 +1425,13 @@ class CppWrapperCpu(PythonWrapperCodegen):
src_is_tensor,
reduce,
kwargs,
device,
):
reduce = self._get_scatter_reduce_enum(reduce)
# call the ABI shim function instead of the ATen one
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
self.add_device_include(device)
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)
# TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
inputs_wrapped = [str(x) for x in inputs]

View File

@ -708,11 +708,14 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
src_is_tensor,
reduce,
kwargs,
device,
):
reduce = self._get_scatter_reduce_enum(reduce)
# call the ABI shim function instead of the ATen one
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
self.add_device_include(device)
cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)
# TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()

View File

@ -287,6 +287,7 @@ class PallasKernel(SIMDKernel):
code = IndentedBuffer()
code.splice(
"""
import functools
import torch
import jax
import jax.numpy as jnp
@ -301,6 +302,9 @@ class PallasKernel(SIMDKernel):
kernel_params = [a.name for a in arg_defs]
kernel_name = name or "<KERNEL_NAME>"
interpret_literal = (
"True" if V.graph.get_current_device_or_throw().type == "cpu" else "False"
)
code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):")
with code.indent():
# Emit compute (CSE) and store lines; they reference *_ptr[...] directly
@ -309,16 +313,22 @@ class PallasKernel(SIMDKernel):
for line in self.stores._lines:
code.writeline(str(line))
jit_wrapper_name = f"{kernel_name}_jit_wrapper"
code.writeline("@functools.partial(jax.jit, static_argnums=(0, 1))")
code.writeline(f"def {jit_wrapper_name}(out_shape, out_dtype, *kernel_refs):")
with code.indent():
code.writeline("out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)")
code.writeline("return pl.pallas_call(")
code.writeline(f" {kernel_name}_kernel,")
code.writeline(" out_shape=out_spec,")
code.writeline(f" interpret={interpret_literal},")
code.writeline(" grid=(1,),")
code.writeline(")(*kernel_refs)")
# Host entry: convert torch tensors <-> jax, call pallas_call and copy back
main_name = f"{kernel_name}_main"
code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):")
with code.indent():
# Determine interpret statically based on codegen device
interpret_literal = (
"True"
if V.graph.get_current_device_or_throw().type == "cpu"
else "False"
)
# Identify inputs (in_ptr*) and output (out_ptr*)
input_params = [
p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr"))
@ -337,9 +347,9 @@ class PallasKernel(SIMDKernel):
for inp in input_params:
code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")
# Get output spec from PyTorch tensor
code.writeline("# Prepare output spec from PyTorch tensor")
code.writeline("# Map PyTorch dtype to JAX dtype string")
# Get output metadata from PyTorch tensor
code.writeline("# Prepare output metadata from PyTorch tensor")
code.writeline("# Map PyTorch dtype to JAX dtype")
code.writeline("_torch_dtype_to_jax = {")
code.writeline(
" torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,"
@ -349,21 +359,14 @@ class PallasKernel(SIMDKernel):
)
code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,")
code.writeline("}")
code.writeline(
f"out_spec = jax.ShapeDtypeStruct({output_param}.shape, _torch_dtype_to_jax[{output_param}.dtype])"
)
code.writeline(f"out_shape = tuple({output_param}.shape)")
code.writeline(f"out_dtype = _torch_dtype_to_jax[{output_param}.dtype]")
# Call pallas
# Pass interpret=True on CPU, False otherwise (single call, no duplication)
code.writeline("compiled = pl.pallas_call(")
code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),")
code.writeline(" out_shape=out_spec,")
code.writeline(f" interpret={interpret_literal},")
code.writeline(" grid=(1,),")
code.writeline(")")
jax_input_args = ", ".join([f"{inp}_jax" for inp in input_params])
code.writeline(f"res = compiled({jax_input_args})")
call_args = ["out_shape", "out_dtype"] + [
f"{inp}_jax" for inp in input_params
]
call_arg_str = ", ".join(call_args)
code.writeline(f"res = {jit_wrapper_name}({call_arg_str})")
# Copy result back
code.writeline("# Copy result back into the provided torch output tensor")

View File

@ -971,6 +971,7 @@ class ScatterFallbackLine(WrapperLine):
else:
(x, index) = (t.codegen_reference() for t in node.inputs)
src = node.constant_args[1]
device = d.type if (d := node.get_device()) else V.graph.device_type
self.wrapper._generate_scatter_fallback(
x,
[x, node.constant_args[0], index, src],
@ -979,6 +980,7 @@ class ScatterFallbackLine(WrapperLine):
node.src_is_tensor,
node.kwargs["reduce"],
node.codegen_kwargs(),
device,
)
def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
@ -1632,6 +1634,7 @@ class PythonWrapperCodegen(CodeGen):
src_is_tensor,
reduce,
kwargs,
device,
):
line = f"{python_kernel_name}({','.join(map(str, inputs))}"
if python_kernel_name.startswith("aten.scatter_reduce"):

View File

@ -468,6 +468,8 @@ class _SerializedFxCompile(FxCompile):
fake_mode = _current_fake_mode()
fake_tensor_mode = _FakeTensorModeSerializer(fake_mode)
from pickle import PicklingError
try:
input = _WireProtocolInput(
gm,
@ -483,7 +485,7 @@ class _SerializedFxCompile(FxCompile):
fake_tensor_mode,
).serialize()
return (input, constants)
except (AttributeError, BypassFxGraphCache):
except (AttributeError, BypassFxGraphCache, PicklingError):
# For example: AttributeError: Can't pickle local object
# 'make_opaque_unary_fn.<locals>.OpaqueUnaryFn'

View File

@ -8845,7 +8845,9 @@ class Conditional(ExternKernel):
outputs = [
MultiOutput(
FixedLayout(
device=device,
device=output.get_device()
if output.get_device() is not None
else device, # type: ignore[arg-type]
dtype=output.get_dtype(),
size=[Conditional._maybe_expr(sz) for sz in merged_output.size()],
stride=[

View File

@ -48,7 +48,8 @@ logger = get_logger(__name__)
@dataclass
class WorkerSpec:
"""Blueprint information about a particular type of worker.
"""
Blueprint information about a particular type of worker.
For a given role, there must only exist a single worker spec.
Worker spec is expected to be homogeneous across all nodes (machine),
@ -79,6 +80,10 @@ class WorkerSpec:
that match _any_ of the filter strings.
duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines
that match _any_ of the filter strings.
virtual_local_rank: Enable virtual local rank mode for workers (defaults to False).
When enabled, LOCAL_RANK is set to 0 for all workers and
CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its
assigned GPU at device index 0.
"""
role: str
@ -97,6 +102,7 @@ class WorkerSpec:
numa_options: Optional[NumaOptions] = None
duplicate_stdout_filters: Optional[list[str]] = None
duplicate_stderr_filters: Optional[list[str]] = None
virtual_local_rank: bool = False
def __post_init__(self):
assert self.local_world_size > 0

View File

@ -303,7 +303,6 @@ class LocalElasticAgent(SimpleElasticAgent):
for worker in worker_group.workers:
local_rank = worker.local_rank
worker_env = {
"LOCAL_RANK": str(local_rank),
"RANK": str(worker.global_rank),
"GROUP_RANK": str(worker_group.group_rank),
"ROLE_RANK": str(worker.role_rank),
@ -322,6 +321,7 @@ class LocalElasticAgent(SimpleElasticAgent):
"TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1)
),
}
self._set_local_rank_env(worker_env, local_rank, spec)
if "OMP_NUM_THREADS" in os.environ:
worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
@ -362,6 +362,46 @@ class LocalElasticAgent(SimpleElasticAgent):
return self._pcontext.pids()
def _set_local_rank_env(
self, worker_env: dict[str, str | None], local_rank: int, spec: WorkerSpec
) -> None:
# Set CUDA_VISIBLE_DEVICES and LOCAL_RANK based on virtual_local_rank mode.
# Virtual mode: Each worker sees only its assigned GPU as device 0, LOCAL_RANK=0
# Traditional mode: Workers see all GPUs, LOCAL_RANK matches actual local rank
if spec.virtual_local_rank:
# Set LOCAL_RANK=0 and use CUDA_VISIBLE_DEVICES to control the actual GPU access.
worker_env["LOCAL_RANK"] = "0"
# Map local_rank through existing CUDA_VISIBLE_DEVICES
# HIP uses CUDA_VISIBLE_DEVICES as a compatibility hack:
# https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#cuda-visible-devices
parent_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
if parent_visible_devices is not None:
# Parse comma-separated list of GPU IDs
available_gpus = parent_visible_devices.split(",")
if local_rank >= len(available_gpus):
raise ValueError(
f"local_rank {local_rank} exceeds available GPUs in "
f"CUDA_VISIBLE_DEVICES={parent_visible_devices}"
)
visible_gpu = available_gpus[local_rank].strip()
else:
# No restriction, use local_rank directly
visible_gpu = str(local_rank)
worker_env["CUDA_VISIBLE_DEVICES"] = visible_gpu
return
# In traditional mode, don't override CUDA_VISIBLE_DEVICES
# (inherit from parent environment)
worker_env["LOCAL_RANK"] = str(local_rank)
if "CUDA_VISIBLE_DEVICES" in os.environ:
worker_env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
if self._worker_watchdog is not None:
self._worker_watchdog.stop()

View File

@ -75,6 +75,10 @@ class LaunchConfig:
that match _any_ of the filter strings.
duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines
that match _any_ of the filter strings.
virtual_local_rank: Enable virtual local rank mode for workers (defaults to False).
When enabled, LOCAL_RANK is set to 0 for all workers and
CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its
assigned GPU at device index 0.
.. note::
@ -104,6 +108,7 @@ class LaunchConfig:
signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT"
duplicate_stdout_filters: Optional[list[str]] = None
duplicate_stderr_filters: Optional[list[str]] = None
virtual_local_rank: bool = False
def __post_init__(self):
default_timeout = 900
@ -288,6 +293,7 @@ def launch_agent(
numa_options=config.numa_options,
duplicate_stdout_filters=config.duplicate_stdout_filters,
duplicate_stderr_filters=config.duplicate_stderr_filters,
virtual_local_rank=config.virtual_local_rank,
)
agent = LocalElasticAgent(

View File

@ -688,6 +688,15 @@ def get_args_parser() -> ArgumentParser:
"Common additional signals: SIGUSR1,SIGUSR2 (used in SLURM environments).",
)
parser.add_argument(
"--virtual-local-rank",
"--virtual_local_rank",
action=check_env,
help="Enable virtual local rank mode for workers. When enabled, LOCAL_RANK is set to 0 "
"for all workers and CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its "
"assigned GPU at device index 0.",
)
#
# Positional arguments.
#
@ -907,6 +916,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
signals_to_handle=args.signals_to_handle,
duplicate_stdout_filters=args.duplicate_stdout_filters,
duplicate_stderr_filters=args.duplicate_stderr_filters,
virtual_local_rank=args.virtual_local_rank,
)
with_python = not args.no_python

View File

@ -256,7 +256,7 @@ def bmm_strategy(op_schema: OpSchema) -> OpStrategy:
@register_op_strategy(aten.baddbmm.default)
def baddmm_strategy(op_schema: OpSchema) -> OpStrategy:
def baddbmm_strategy(op_schema: OpSchema) -> OpStrategy:
mesh = op_schema.get_mesh_from_args()
return _addmm_like_strategy("bmk,bkn->bmn", mesh, op_schema)

View File

@ -177,6 +177,26 @@ def split_const_subgraphs(
else:
mod_traced = module
def _subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool:
"""
Return True if a GraphModule type subgraph contains any impure op, else False.
"""
assert isinstance(module, torch.fx.GraphModule), (
"caller should only pass GraphModule to subgraph_has_impure_ops check"
)
for node in module.graph.nodes:
if node.op == "call_function" and node.is_impure():
return True
if (
# pyrefly: ignore [invalid-argument]
node.op == "call_module"
# pyrefly: ignore [not-callable]
and (submodule := module.get_submodule(node.target))
and isinstance(submodule, torch.fx.GraphModule)
):
return _subgraph_has_impure_ops(submodule)
return False
# Build up a list of const_nodes, defined as nodes that are themselves
# get_attrs, or have all get_attr or other constant node inputs.
const_nodes: set[torch.fx.Node] = set()
@ -206,6 +226,17 @@ def split_const_subgraphs(
if isinstance(node.kwargs.get("fill_value", None), sympy.Expr):
continue
# Skip folding submodules that have impure ops
if (
# pyrefly: ignore [invalid-argument]
node.op == "call_module"
# pyrefly: ignore [not-callable]
and (target_mod := mod_traced.get_submodule(node.target))
and isinstance(target_mod, torch.fx.GraphModule)
and _subgraph_has_impure_ops(target_mod)
):
continue
# Must be a constant foldable node at this point.
const_nodes.add(node)
if node.op != "get_attr":

View File

@ -754,26 +754,6 @@ class Node(_NodeBase):
return self.target in _side_effectful_functions
def subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool:
"""
Return True if a GraphModule type subgraph contains any impure op, else False.
"""
assert isinstance(module, torch.fx.GraphModule), (
"caller should only pass GraphModule to subgraph_has_impure_ops check"
)
for node in module.graph.nodes:
if node.op == "call_function" and node.is_impure(impure_random):
return True
if (
# pyrefly: ignore [invalid-argument]
node.op == "call_module"
# pyrefly: ignore [not-callable]
and (submodule := module.get_submodule(node.target))
and isinstance(submodule, torch.fx.GraphModule)
):
return subgraph_has_impure_ops(submodule)
return False
# Check if an impure module.
if self.op == "call_module":
assert self.graph.owning_module is not None, (
@ -783,10 +763,11 @@ class Node(_NodeBase):
assert target_mod is not None, (
f"Did not find expected submodule target {self.target}"
)
if isinstance(target_mod, torch.fx.GraphModule):
return subgraph_has_impure_ops(target_mod)
else:
return getattr(target_mod, "_is_impure", False)
# NOTE: here we can end up considering GraphModule submodules pure,
# even if they contain impure ops. It may not be safe to change
# because this function is used by graph.eliminate_dead_code,
# and some users depend on current elimination behavior.
return getattr(target_mod, "_is_impure", False)
return False

View File

@ -1711,7 +1711,7 @@ class MultiProcContinuousTest(TestCase):
@classmethod
def _init_pg(cls, rank, world_size, rdvz_file):
assert rdvz_file is not None
# rank should be local_rank for tests running on <= 8gpus which is how all these tests are designed
# rank should be local_rank for tests running on <= 8 gpus which is how all these tests are designed
# and we expect LOCAL_RANK set by torchrun. Setting it lets init_device_mesh set the device without
# issuing a warning
os.environ["LOCAL_RANK"] = str(rank)