Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-13 02:35:10 +08:00)

Compare commits: documentat...trunk/ad7d
15 Commits
| SHA1 |
|---|
| ad7db3617e |
| 5320ca3725 |
| 3e4faca130 |
| 0c2f206ded |
| 6cf21fa331 |
| cdc8460f2c |
| 86130aa2ca |
| 9491830c79 |
| 04a85b4c21 |
| a4437d76f0 |
| 3ea829a337 |
| 3966b5ad05 |
| f6a79b2a4a |
| 2fcf41dd8e |
| 31ccd8f13e |
.github/labeler.yml (vendored) | 13
@@ -165,3 +165,16 @@
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm

"ciflow/mps":
- aten/src/ATen/mps/**
- aten/src/ATen/native/mps/**
- torch/_inductor/codegen/mps.py
- test/test_mps.py
- test/inductor/test_mps_basic.py

"ciflow/h100-symm-mem":
- torch/csrc/distributed/c10d/symm_mem/**
- torch/distributed/_symmetric_memory/**
- test/distributed/**/*mem*
- test/distributed/**/*mem*/**
.github/workflows/h100-distributed.yml (vendored) | 1
@@ -37,7 +37,6 @@ jobs:
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      runner: "linux.c7i.12xlarge"
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '9.0'
.github/workflows/test-h100.yml (vendored) | 1
@@ -41,7 +41,6 @@ jobs:
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      runner: linux.12xlarge.memory
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '9.0'
@@ -736,6 +736,44 @@ if(NOT DEFINED USE_BLAS)
  set(USE_BLAS ON)
endif()

# Prioritized Text Linker Optimization
if(USE_PRIORITIZED_TEXT_FOR_LD)

  set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
  set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")

  execute_process(
    COMMAND ${Python_EXECUTABLE}
      ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py
      --filein "${LINKER_SCRIPT_FILE_IN}"
      --fout "${LINKER_SCRIPT_FILE_OUT}"
    RESULT_VARIABLE _gen_result
    OUTPUT_VARIABLE _gen_output
    ERROR_VARIABLE _gen_error
  )

  if(NOT _gen_result EQUAL 0)
    message(FATAL_ERROR
      "Failed to generate linker script:\n${_gen_output}\n${_gen_error}")
  endif()

  append_cxx_flag_if_supported("-ffunction-sections" CMAKE_CXX_FLAGS)
  append_cxx_flag_if_supported("-fdata-sections" CMAKE_CXX_FLAGS)
  append_c_flag_if_supported("-ffunction-sections" CMAKE_C_FLAGS)
  append_c_flag_if_supported("-fdata-sections" CMAKE_C_FLAGS)

  set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
  set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")

else()
  if(LINUX AND CPU_AARCH64)
    message(WARNING [[
      It is strongly recommended to enable linker script optimization for all AArch64 Linux builds.
      To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
      ]])
  endif()
endif()

# Build libtorch mobile library, which contains ATen/TH ops and native support
# for TorchScript model, but doesn't contain not-yet-unified caffe2 ops;
if(INTERN_BUILD_MOBILE)

@@ -1402,9 +1440,6 @@ if(BUILD_JNI)
  add_subdirectory(android/pytorch_android)
endif()

include(cmake/Summary.cmake)
caffe2_print_configuration_summary()

# Parse custom debug info
if(DEFINED USE_CUSTOM_DEBINFO)
  string(REPLACE ";" " " SOURCE_FILES "${USE_CUSTOM_DEBINFO}")

@@ -1444,56 +1479,5 @@ if(BUILD_BUNDLE_PTXAS AND USE_CUDA)
    DESTINATION "${CMAKE_INSTALL_BINDIR}")
endif()

if(USE_PRIORITIZED_TEXT_FOR_LD)
  add_compile_options(
    $<$<COMPILE_LANGUAGE:C,CXX>:-ffunction-sections>
    $<$<COMPILE_LANGUAGE:C,CXX>:-fdata-sections>
  )
  set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
  set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")

  add_custom_command(
    OUTPUT "${LINKER_SCRIPT_FILE_OUT}"
    COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py --filein "${LINKER_SCRIPT_FILE_IN}" --fout "${LINKER_SCRIPT_FILE_OUT}"
    DEPENDS ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py "${LINKER_SCRIPT_FILE_IN}"
    COMMENT "Generating prioritized text linker files"
    VERBATIM
  )

  add_custom_target(generate_linker_script DEPENDS "${LINKER_SCRIPT_FILE_OUT}")

  if(BUILD_PYTHON)
    set(LINKER_OPT_TARGETS torch_python)
  endif()

  if(NOT BUILD_LIBTORCHLESS)
    list(APPEND LINKER_OPT_TARGETS torch_cpu c10)
    if(USE_CUDA)
      list(APPEND LINKER_OPT_TARGETS torch_cuda c10_cuda)
    endif()
    if(USE_XPU)
      list(APPEND LINKER_OPT_TARGETS torch_xpu c10_xpu)
    endif()
    if(USE_ROCM)
      list(APPEND LINKER_OPT_TARGETS torch_hip c10_hip)
    endif()
  endif()

  foreach(tgt IN LISTS LINKER_OPT_TARGETS)
    if(TARGET ${tgt})
      add_dependencies("${tgt}" generate_linker_script)
      target_link_options_if_supported(${tgt} "-T,${LINKER_SCRIPT_FILE_OUT}")
      set_property(TARGET ${tgt} APPEND PROPERTY LINK_DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
    else()
      message(WARNING "Requested target '${tgt}' for linker script optimization was not found.")
    endif()
  endforeach()

else()
  if(LINUX AND CPU_AARCH64)
    message(WARNING [[
      It is strongly recommended to enable linker script optimization for all AArch64 Linux builds.
      To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
      ]])
  endif()
endif()
include(cmake/Summary.cmake)
caffe2_print_configuration_summary()
@@ -478,6 +478,7 @@ function(torch_update_find_cuda_flags)
endfunction()

include(CheckCXXCompilerFlag)
include(CheckCCompilerFlag)
include(CheckLinkerFlag)

##############################################################################

@@ -501,6 +502,24 @@ function(append_cxx_flag_if_supported flag outputvar)
  endif()
endfunction()

function(append_c_flag_if_supported flag outputvar)
  string(TOUPPER "HAS${flag}" _FLAG_NAME)
  string(REGEX REPLACE "[=-]" "_" _FLAG_NAME "${_FLAG_NAME}")

  # GCC silences unknown -Wno-XXX flags, so test the corresponding -WXXX.
  if(CMAKE_C_COMPILER_ID STREQUAL "GNU")
    string(REGEX REPLACE "^Wno-" "W" new_flag "${flag}")
  else()
    set(new_flag "${flag}")
  endif()

  check_c_compiler_flag("${new_flag}" ${_FLAG_NAME})
  if(${_FLAG_NAME})
    string(APPEND ${outputvar} " ${flag}")
    set(${outputvar} "${${outputvar}}" PARENT_SCOPE)
  endif()
endfunction()

function(target_compile_options_if_supported target flag)
  set(_compile_options "")
  append_cxx_flag_if_supported("${flag}" _compile_options)
@@ -260,6 +260,7 @@ select = [
    "TRY401", # verbose-log-message
    "UP",
    "YTT",
    "S101",
]

[tool.ruff.lint.pyupgrade]

@@ -339,6 +340,39 @@ keep-runtime-typing = true
"tools/linter/**" = [
    "LOG015" # please fix
]
"benchmarks/**" = [
    "S101"
]
"test/**" = [
    "S101"
]
"torchgen/**" = [
    "S101"
]
"torch/**" = [
    "S101"
]
"tools/**" = [
    "S101"
]
"setup.py" = [
    "S101"
]
"functorch/**" = [
    "S101"
]
"docs/**" = [
    "S101"
]
"android/**" = [
    "S101"
]
".github/**" = [
    "S101"
]
".ci/**" = [
    "S101"
]

[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"
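S101 is Ruff's flake8-bandit check for bare `assert` statements, which are stripped when Python runs with `-O`; enabling it in `select` while adding the per-directory ignores above keeps asserts legal in tests, tooling, and docs but not in library code. A minimal sketch of the pattern this pushes library code toward (function names here are only illustrative):

```python
# Library code: S101 flags a bare assert, so validate with an explicit raise instead.
def scale(values: list[float], factor: float) -> list[float]:
    if factor == 0:
        raise ValueError("factor must be non-zero")  # survives `python -O`
    return [v * factor for v in values]


# Test code (covered by the "test/**" ignore above): plain asserts remain fine.
def test_scale() -> None:
    assert scale([1.0, 2.0], 2.0) == [2.0, 4.0]
```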
test/distributed/launcher/script_deviceid.py (new file) | 44
@@ -0,0 +1,44 @@
# Owner(s): ["oncall: r2p"]

# This is a helper script for
# test_run.py::ElasticLaunchTest::test_virtual_local_rank. It prints out the
# generated inductor output for a simple function.

import os
from unittest.mock import patch

import torch
import torch.distributed as dist
from torch._inductor import codecache


@torch.compile
def myfn(x: torch.Tensor) -> torch.Tensor:
    return x + x


dist.init_process_group(backend="nccl")

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)


def print_output_code(original_fn):
    def wrapper(msg, *args, **kwargs):
        # Check if this is the "Output code:" message
        if args and "Output code:" in msg:
            print(args[0])

    return wrapper


x = torch.rand(2, 2, device="cuda")

with patch.object(
    codecache.output_code_log,
    "debug",
    side_effect=print_output_code(codecache.output_code_log.debug),
):
    y = myfn(x)

dist.destroy_process_group()
@@ -16,7 +16,7 @@ import sys
import tempfile
import uuid
from contextlib import closing, redirect_stderr, redirect_stdout
from unittest import mock
from unittest import mock, skipIf
from unittest.mock import MagicMock, Mock, patch

import torch.distributed.run as launch

@@ -28,6 +28,7 @@ from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_CUDA,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)

@@ -677,6 +678,96 @@ class ElasticLaunchTest(TestCase):
        for i in range(nproc_per_node):
            self.assertTrue(f"[rank{i}]: creating " in captured_out.getvalue())

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    @skipIf(not TEST_CUDA, "requires CUDA")
    def test_virtual_local_rank(self):
        """
        Test that virtual-local-rank ensures consistent device IDs across ranks.
        Without it, ranks may compile to different devices, leading to different code.
        """
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 2

        # Helper function to run and capture output
        def run_test(use_virtual_local_rank):
            args = [
                f"--nnodes={nnodes}",
                f"--nproc-per-node={nproc_per_node}",
                f"--rdzv-id={run_id}",
                "--monitor-interval=1",
                "--start-method=spawn",
                "--redirect=3",
                "--tee=3",
            ]
            if use_virtual_local_rank:
                args.append("--virtual-local-rank")

            args.append(path("script_deviceid.py"))

            captured_out = io.StringIO()
            captured_err = io.StringIO()
            with redirect_stdout(captured_out), redirect_stderr(captured_err):
                launch.main(args)

            return captured_out.getvalue()

        def split_ranks(output):
            default0 = []
            default1 = []
            for line in output.splitlines():
                if "cuda:" not in line:
                    continue
                if line.startswith("[default0]:"):
                    default0.append(line[11:])
                elif line.startswith("[default1]:"):
                    default1.append(line[11:])
            return default0, default1

        # First, run WITHOUT virtual-local-rank - outputs should differ
        output = run_test(use_virtual_local_rank=False)
        rank0, rank1 = split_ranks(output)

        # Verify we actually captured compiled code from both ranks
        self.assertGreater(
            len(rank0), 0, "Expected to capture compiled code from rank 0"
        )
        self.assertGreater(
            len(rank1), 0, "Expected to capture compiled code from rank 1"
        )

        # Without virtual-local-rank, the ranks should have DIFFERENT compiled code
        # because they see different device IDs (cuda:0 vs cuda:1)
        self.assertNotEqual(
            rank0,
            rank1,
            "Expected different compiled code without --virtual-local-rank",
        )

        # Now run WITH virtual-local-rank - outputs should be identical
        output = run_test(use_virtual_local_rank=True)
        rank0, rank1 = split_ranks(output)

        # Verify we actually captured compiled code from both ranks
        self.assertGreater(
            len(rank0),
            0,
            "Expected to capture compiled code from rank 0 with --virtual-local-rank",
        )
        self.assertGreater(
            len(rank1),
            0,
            "Expected to capture compiled code from rank 1 with --virtual-local-rank",
        )

        # With virtual-local-rank, both ranks should have IDENTICAL compiled code
        # because they both see cuda:0 during compilation
        self.assertEqual(
            rank0, rank1, "Expected identical compiled code with --virtual-local-rank"
        )


if __name__ == "__main__":
    run_tests()
@@ -7522,6 +7522,38 @@ class AOTInductorTestsTemplate:
        eager_outputs = model(*example_inputs)
        torch.testing.assert_close(eager_outputs, compiled_outputs)

    @requires_gpu
    def test_mixed_device_1(self):
        if self.device != GPU_TYPE:
            raise unittest.SkipTest("Mixed-device test requires GPU")

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Buffers are on CPU
                self.register_buffer(
                    "index", torch.tensor([1, 4, 1, 7], device="cpu", dtype=torch.int64)
                )
                self.register_buffer(
                    "src", torch.ones(4, device="cpu", dtype=torch.int64)
                )

            def forward(self, matrix, vector):
                # Inputs are on CUDA
                # 1. Operation on CPU tensors
                z = torch.zeros((vector.shape[0],), device="cpu", dtype=torch.int64)
                scatter_result = z.scatter_add(0, self.index, self.src)

                # 2. Move result to CUDA and continue on CUDA
                v = vector + scatter_result.to(vector.dtype).to(GPU_TYPE)
                return torch.matmul(matrix, v)

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, device=self.device),
        )
        self.check_model(Model(), example_inputs, move_model_to_device=False)


class AOTInductorLoggingTest(LoggingTestCase):
    @make_logging_test(dynamic=logging.DEBUG)
@@ -218,6 +218,7 @@ def check_model(
    dynamic_shapes=None,
    atol=None,
    rtol=None,
    move_model_to_device=True,
):
    with (
        torch.no_grad(),

@@ -229,7 +230,7 @@ def check_model(
        ),
    ):
        torch.manual_seed(0)
        if not isinstance(model, types.FunctionType):
        if not isinstance(model, types.FunctionType) and move_model_to_device:
            model = model.to(self.device)

        # For non mixed device inputs with default "cpu", set the device manually.
@@ -20,9 +20,11 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
from torch.testing._internal.triton_utils import requires_gpu


def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1):
def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1, device=None):
    result = []
    device = inputs[0].device
    if len(inputs) != 0:
        device = inputs[0].device
    assert device
    # iterate over the cartesian product of predicate values
    for values in itertools.product(*([possible_values] * num_to_prepend)):
        prepended = [torch.tensor(v, device=device) for v in values]

@@ -30,8 +32,8 @@ def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1):
    return result


def prepend_predicates(inputs, num_predicates=1):
    return _prepend_product_of_values(inputs, [False, True], num_predicates)
def prepend_predicates(inputs, num_predicates=1, device=None):
    return _prepend_product_of_values(inputs, [False, True], num_predicates, device)


def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)):

@@ -308,7 +310,9 @@ class CondTests(TestCase):
            torch._dynamo.mark_dynamic(inp, 0)

        for inputs in input_sets:
            for inputs_with_predicates in prepend_predicates(inputs, num_predicates):
            for inputs_with_predicates in prepend_predicates(
                inputs, num_predicates, device=device
            ):
                cloned_inputs = [inp.clone() for inp in inputs_with_predicates]
                result = model(*inputs_with_predicates)
                result_compiled = compiled_model(*inputs_with_predicates)

@@ -768,6 +772,26 @@ class CondTests(TestCase):
            dynamic=dynamic,
        )

    @requires_gpu
    def test_output_on_different_device(self):
        class FactoryBranches(torch.nn.Module):
            def forward(self, pred):
                tensor = torch.cond(
                    pred,
                    lambda: torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32).to(
                        GPU_TYPE
                    ),
                    lambda: torch.zeros(5, dtype=torch.float32).to(GPU_TYPE),
                )
                return tensor + 1

        self._run_test(
            model=FactoryBranches(),
            inputs=(),
            device="cpu",  # device for predicate
            dynamic=True,
        )


class WhileLoopModels:
    class Simple(torch.nn.Module):
@@ -726,8 +726,7 @@ class CPUReproTests(TestCase):
            seq_len,
        )

    @parametrize(
        "unbatched, input_size, hidden_size, num_layers, bidirectional, bias, empty_state, batch_first, batch_size, seq_len",
    _test_lstm_packed_change_input_sizes_cpu_params = list(
        itertools.product(
            *[
                [False],

@@ -741,7 +740,12 @@ class CPUReproTests(TestCase):
                [2],
                [3],
            ]
        ),
    )
    )

    @parametrize(
        "unbatched, input_size, hidden_size, num_layers, bidirectional, bias, empty_state, batch_first, batch_size, seq_len",
        _test_lstm_packed_change_input_sizes_cpu_params,
    )
    def test_lstm_packed_change_input_sizes_cpu(
        self,
@@ -1,5 +1,6 @@
# Owner(s): ["oncall: pt2"]
import functools
import re
import sys
import unittest

@@ -230,6 +231,33 @@ class PallasTestsMixin:
        self.assertIn("import jax.numpy as jnp", code)
        self.assertIn("from jax.experimental import pallas as pl", code)

    def test_jax_jit_wrapper_is_emitted(self):
        """Ensure generated Pallas code wraps pl.pallas_call in jax.jit."""

        key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"

        @torch.compile(backend="inductor", options={key: "pallas"})
        def pallas_fn(a, b):
            return a + b

        _, (code,) = run_and_get_code(
            pallas_fn,
            torch.randn(32, device=self.DEVICE),
            torch.randn(32, device=self.DEVICE),
        )

        kernel_match = re.search(r"def (pallas_[A-Za-z0-9_]+)_kernel", code)
        self.assertIsNotNone(kernel_match)
        kernel_name = kernel_match.group(1)
        wrapper_name = f"{kernel_name}_jit_wrapper"
        self.assertIn(wrapper_name, code)
        start = code.index(f"def {wrapper_name}")
        end = code.index(f"def {kernel_name}_main", start)
        wrapper_block = code[start:end]

        self.assertIn("jax.jit", code)
        self.assertNotIn("torch.", wrapper_block)

    def test_2d_tensor(self):
        """Test with 2D tensors (though current implementation flattens)."""
@@ -7455,6 +7455,34 @@ class TestCudaDeviceParametrized(TestCase):
class TestFXMemoryProfiler(TestCase):
    """Tests for memory profiler augmentation with original stack traces."""

    class MLPModule(nn.Module):
        def __init__(self, device):
            super().__init__()
            torch.manual_seed(5)
            self.net1 = nn.Linear(10, 16, bias=True, device=device)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(16, 10, bias=True, device=device)

        def forward(self, x):
            a = self.net1(x)
            b = self.relu(a)
            c = self.net2(b)
            return c

    class MLPModule2(nn.Module):
        def __init__(self, device):
            super().__init__()
            torch.manual_seed(5)
            self.net1 = nn.Linear(10, 16, bias=True, device=device)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(16, 10, bias=True, device=device)

        def forward(self, x):
            d = self.net1(x)
            e = self.relu(d)
            f = self.net2(e)
            return f

    def collect_frames(
        self, augmented_snapshot, collect_device_traces=True, collect_segments=True
    ):

@@ -7490,99 +7518,64 @@ class TestFXMemoryProfiler(TestCase):
    def test_fx_memory_profiler_augmentation(self):
        """Test that memory snapshots are augmented with FX debug information."""

        # Create a simple model
        class MLPModule(nn.Module):
            def __init__(self, device):
                super().__init__()
                torch.manual_seed(5)
                self.net1 = nn.Linear(10, 16, bias=True, device=device)
                self.relu = nn.ReLU()
                self.net2 = nn.Linear(16, 10, bias=True, device=device)

            def forward(self, x):
                a = self.net1(x)
                b = self.relu(a)
                c = self.net2(b)
                return c

        device = "cuda"
        mod = MLPModule(device)
        with tempfile.TemporaryDirectory() as tmpdir:
            # reset cache to start fresh
            torch.cuda.memory.empty_cache()
            torch.cuda.memory._record_memory_history()
            compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
            result = compiled(torch.randn(10, 10, device=device))
            augmented_snapshot = torch.cuda.memory._snapshot(
                augment_with_fx_traces=True
            )
            torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
            torch.cuda.empty_cache()
        mod = self.MLPModule(device)
        # reset cache to start fresh
        torch.cuda.memory.empty_cache()
        torch.cuda.memory._record_memory_history()
        compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
        result = compiled(torch.randn(10, 10, device=device))
        augmented_snapshot = torch.cuda.memory._snapshot(augment_with_fx_traces=True)
        torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
        torch.cuda.empty_cache()

            fx_frames = self.collect_frames(augmented_snapshot)
            self.assertGreater(len(fx_frames), 2)
        fx_frames = self.collect_frames(augmented_snapshot)
        self.assertGreater(len(fx_frames), 2)

            for frame in fx_frames:
                # Every FX frame should have both node_op and node_name
                self.assertIn("fx_node_op", frame)
                self.assertIn("fx_node_name", frame)
                self.assertIn("fx_node_target", frame)
                self.assertIn("fx_original_trace", frame)
        for frame in fx_frames:
            # Every FX frame should have both node_op and node_name
            self.assertIn("fx_node_op", frame)
            self.assertIn("fx_node_name", frame)
            self.assertIn("fx_node_target", frame)
            self.assertIn("fx_original_trace", frame)

                self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
                fx_node_name = frame["fx_node_name"]
                if fx_node_name == "addmm":
                    self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
                elif fx_node_name == "addmm_1":
                    self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
                elif fx_node_name == "relu":
                    self.assertIn("b = self.relu(a)", frame["fx_original_trace"])
            self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
            fx_node_name = frame["fx_node_name"]
            if fx_node_name == "addmm":
                self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
            elif fx_node_name == "addmm_1":
                self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
            elif fx_node_name == "relu":
                self.assertIn("b = self.relu(a)", frame["fx_original_trace"])

        # Test that when we have two graphs with the same src_code, they're not hashed
        # to the same metadata
        class MLPModule2(nn.Module):
            def __init__(self, device):
                super().__init__()
                torch.manual_seed(5)
                self.net1 = nn.Linear(10, 16, bias=True, device=device)
                self.relu = nn.ReLU()
                self.net2 = nn.Linear(16, 10, bias=True, device=device)
        mod = self.MLPModule2(device)
        torch.cuda.memory._record_memory_history()
        compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
        result = compiled(torch.randn(10, 10, device=device))
        augmented_snapshot = torch.cuda.memory._snapshot(augment_with_fx_traces=True)
        torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)

            def forward(self, x):
                d = self.net1(x)
                e = self.relu(d)
                f = self.net2(e)
                return f
        # avoid collecting segments from previous run for unit test purpose
        fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
        self.assertGreater(len(fx_frames), 0)

        mod = MLPModule2(device)
        with tempfile.TemporaryDirectory() as tmpdir:
            torch.cuda.memory._record_memory_history()
            compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
            result = compiled(torch.randn(10, 10, device=device))
            augmented_snapshot = torch.cuda.memory._snapshot(
                augment_with_fx_traces=True
            )
            torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
        for frame in fx_frames:
            # Every FX frame should have both node_op and node_name
            self.assertIn("fx_node_op", frame)
            self.assertIn("fx_node_name", frame)
            self.assertIn("fx_node_target", frame)
            self.assertIn("fx_original_trace", frame)

            # avoid collecting segments from previous run for unit test purpose
            fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
            self.assertGreater(len(fx_frames), 0)

            for frame in fx_frames:
                # Every FX frame should have both node_op and node_name
                self.assertIn("fx_node_op", frame)
                self.assertIn("fx_node_name", frame)
                self.assertIn("fx_node_target", frame)
                self.assertIn("fx_original_trace", frame)

                self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
                fx_node_name = frame["fx_node_name"]
                if fx_node_name == "addmm":
                    self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
                elif fx_node_name == "addmm_1":
                    self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
                elif fx_node_name == "relu":
                    self.assertIn("e = self.relu(d)", frame["fx_original_trace"])
            self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
            fx_node_name = frame["fx_node_name"]
            if fx_node_name == "addmm":
                self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
            elif fx_node_name == "addmm_1":
                self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
            elif fx_node_name == "relu":
                self.assertIn("e = self.relu(d)", frame["fx_original_trace"])


instantiate_parametrized_tests(TestCuda)
@@ -9465,7 +9465,7 @@ class TestSDPA(TestCaseMPS):
        torch.manual_seed(1729)
        causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps'))
        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            i = 42
            i = 42 if S > 42 else S // 2

            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
            k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
@@ -4506,7 +4506,7 @@ class TestSerialization(TestCase, SerializationMixin):
        exc = pickle.PicklingError if sys.version_info >= (3, 14) else AttributeError
        with self.assertRaisesRegex(
            exc,
            "Can't (get|pickle) local object (<function |')WeakValueDictionary.__init__.<locals>.remove"
            r"Can't (get|pickle) local object (<function |')WeakValueDictionary\.__init__\.<locals>\.remove"
        ):
            with skip_data(), BytesIOContext() as f:
                torch.save(ft, f)
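The raw-string/escaping change matters because `assertRaisesRegex` interprets the pattern as a regular expression, where an unescaped `.` matches any character. A small self-contained illustration of the difference:

```python
import re

msg = "Can't get local object <function WeakValueDictionary.__init__.<locals>.remove"
loose = "WeakValueDictionary.__init__"     # '.' acts as a wildcard here
strict = r"WeakValueDictionary\.__init__"  # raw string keeps the backslash; the dot is literal

assert re.search(loose, msg) and re.search(strict, msg)
# The unescaped pattern also matches strings it should not:
assert re.search(loose, "WeakValueDictionaryX__init__") is not None
assert re.search(strict, "WeakValueDictionaryX__init__") is None
```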
@@ -974,6 +974,41 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
        subgraph.recompile()


def _assign_new_node_names(
    gm: torch.fx.GraphModule,
    name_map: dict[str, str],
    custom_meta: dict[str, Any],
) -> None:
    """
    Assign new names to all nodes in the graph module from the name map.
    """
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            assert node.name in name_map
            node.name = node.target = name_map[node.name]
            if node.name in custom_meta:
                if node.meta.get("custom") is None:
                    node.meta["custom"] = {}
                else:
                    # Assert if any existing key has different value
                    for k, v in node.meta["custom"].items():
                        if (
                            k in custom_meta[node.name]
                            and v != custom_meta[node.name][k]
                        ):
                            raise AssertionError(
                                f"Mismatch in custom metadata for key {k}. Value in "
                                f"node.meta is {v} and value in custom_meta is {custom_meta[node.name][k]}."
                            )
                node.meta["custom"].update(custom_meta[node.name])
            # if the constant obj is an input, we also need to update meta["val"]
            # because this is created before the placeholder naming pass
            if isinstance(node.meta["val"], CustomObjArgument):
                node.meta["val"].name = node.name
        elif node.name in name_map:
            node.name = name_map[node.name]


def placeholder_naming_pass(
    gm: torch.fx.GraphModule,
    export_graph_signature: "ExportGraphSignature",

@@ -1091,31 +1126,7 @@ def placeholder_naming_pass(
    )

    # assign new node names
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            assert node.name in name_map
            node.name = node.target = name_map[node.name]
            if node.name in custom_meta:
                if node.meta.get("custom") is None:
                    node.meta["custom"] = {}
                else:
                    # Assert if any existing key has different value
                    for k, v in node.meta["custom"].items():
                        if (
                            k in custom_meta[node.name]
                            and v != custom_meta[node.name][k]
                        ):
                            raise AssertionError(
                                f"Mismatch in custom metadata for key {k}. Value in "
                                f"node.meta is {v} and value in custom_meta is {custom_meta[node.name][k]}."
                            )
                node.meta["custom"].update(custom_meta[node.name])
            # if the constant obj is an input, we also need to update meta["val"]
            # because this is created before the placeholder naming pass
            if isinstance(node.meta["val"], CustomObjArgument):
                node.meta["val"].name = node.name
        elif node.name in name_map:
            node.name = name_map[node.name]
    _assign_new_node_names(gm, name_map, custom_meta)

    # propagate names to higher order op subgraphs
    _name_hoo_subgraph_placeholders(gm)
@@ -624,7 +624,7 @@ class FxGraphCachePickler(pickle.Pickler):
        try:
            self.dump(obj)
            return self._stream.getvalue()
        except (TypeError, AttributeError) as e:
        except (TypeError, AttributeError, pickle.PicklingError) as e:
            # Some configs options may not pickle.
            log.warning("Failed to pickle cache key", exc_info=True)
            raise BypassFxGraphCache("Failed to pickle cache key") from e
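Catching `pickle.PicklingError` here lines up with the newer CPython behavior noted in the serialization test above, where pickling a local object raises `PicklingError` instead of `AttributeError` from 3.14 onward. A tiny, hedged illustration of the failure mode the cache-key path now bypasses (the exact exception type depends on the Python version):

```python
import pickle


def make_local():
    def local_fn():
        return 1
    return local_fn


try:
    pickle.dumps(make_local())  # local objects are not picklable
except (AttributeError, pickle.PicklingError) as e:
    # Older CPython raises AttributeError("Can't pickle local object ...");
    # newer versions report the same condition as pickle.PicklingError.
    print(type(e).__name__, e)
```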
@@ -221,7 +221,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
            """
        )

        self.add_device_include(self.device)
        for device in V.graph.device_types:
            if device != "meta":
                self.add_device_include(device)

        if V.graph.aot_mode:
            if config.aot_inductor.dynamic_linkage:

@@ -1423,11 +1425,13 @@ class CppWrapperCpu(PythonWrapperCodegen):
        src_is_tensor,
        reduce,
        kwargs,
        device,
    ):
        reduce = self._get_scatter_reduce_enum(reduce)

        # call the ABI shim function instead of the ATen one
        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
        self.add_device_include(device)
        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)
        # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
        cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
        inputs_wrapped = [str(x) for x in inputs]

@@ -708,11 +708,14 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
        src_is_tensor,
        reduce,
        kwargs,
        device,
    ):
        reduce = self._get_scatter_reduce_enum(reduce)

        # call the ABI shim function instead of the ATen one
        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device)
        self.add_device_include(device)
        cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device)

        # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py
        cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out"
        self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor()
@@ -287,6 +287,7 @@ class PallasKernel(SIMDKernel):
        code = IndentedBuffer()
        code.splice(
            """
            import functools
            import torch
            import jax
            import jax.numpy as jnp

@@ -301,6 +302,9 @@ class PallasKernel(SIMDKernel):
        kernel_params = [a.name for a in arg_defs]

        kernel_name = name or "<KERNEL_NAME>"
        interpret_literal = (
            "True" if V.graph.get_current_device_or_throw().type == "cpu" else "False"
        )
        code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):")
        with code.indent():
            # Emit compute (CSE) and store lines; they reference *_ptr[...] directly

@@ -309,16 +313,22 @@ class PallasKernel(SIMDKernel):
            for line in self.stores._lines:
                code.writeline(str(line))

        jit_wrapper_name = f"{kernel_name}_jit_wrapper"
        code.writeline("@functools.partial(jax.jit, static_argnums=(0, 1))")
        code.writeline(f"def {jit_wrapper_name}(out_shape, out_dtype, *kernel_refs):")
        with code.indent():
            code.writeline("out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)")
            code.writeline("return pl.pallas_call(")
            code.writeline(f" {kernel_name}_kernel,")
            code.writeline(" out_shape=out_spec,")
            code.writeline(f" interpret={interpret_literal},")
            code.writeline(" grid=(1,),")
            code.writeline(")(*kernel_refs)")

        # Host entry: convert torch tensors <-> jax, call pallas_call and copy back
        main_name = f"{kernel_name}_main"
        code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):")
        with code.indent():
            # Determine interpret statically based on codegen device
            interpret_literal = (
                "True"
                if V.graph.get_current_device_or_throw().type == "cpu"
                else "False"
            )
            # Identify inputs (in_ptr*) and output (out_ptr*)
            input_params = [
                p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr"))

@@ -337,9 +347,9 @@ class PallasKernel(SIMDKernel):
            for inp in input_params:
                code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")

            # Get output spec from PyTorch tensor
            code.writeline("# Prepare output spec from PyTorch tensor")
            code.writeline("# Map PyTorch dtype to JAX dtype string")
            # Get output metadata from PyTorch tensor
            code.writeline("# Prepare output metadata from PyTorch tensor")
            code.writeline("# Map PyTorch dtype to JAX dtype")
            code.writeline("_torch_dtype_to_jax = {")
            code.writeline(
                " torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,"

@@ -349,21 +359,14 @@ class PallasKernel(SIMDKernel):
            )
            code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,")
            code.writeline("}")
            code.writeline(
                f"out_spec = jax.ShapeDtypeStruct({output_param}.shape, _torch_dtype_to_jax[{output_param}.dtype])"
            )
            code.writeline(f"out_shape = tuple({output_param}.shape)")
            code.writeline(f"out_dtype = _torch_dtype_to_jax[{output_param}.dtype]")

            # Call pallas
            # Pass interpret=True on CPU, False otherwise (single call, no duplication)
            code.writeline("compiled = pl.pallas_call(")
            code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),")
            code.writeline(" out_shape=out_spec,")
            code.writeline(f" interpret={interpret_literal},")
            code.writeline(" grid=(1,),")
            code.writeline(")")

            jax_input_args = ", ".join([f"{inp}_jax" for inp in input_params])
            code.writeline(f"res = compiled({jax_input_args})")
            call_args = ["out_shape", "out_dtype"] + [
                f"{inp}_jax" for inp in input_params
            ]
            call_arg_str = ", ".join(call_args)
            code.writeline(f"res = {jit_wrapper_name}({call_arg_str})")

            # Copy result back
            code.writeline("# Copy result back into the provided torch output tensor")
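For orientation, the wrapper codegen above produces host-side Python shaped roughly like the sketch below for a simple elementwise kernel; the kernel name, parameter names, and the `interpret=True` choice are illustrative assumptions, not text taken from the diff:

```python
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def pallas_add_kernel(in_ptr0, in_ptr1, out_ptr0):
    # Body emitted from Inductor IR: reads and writes go through the refs directly.
    out_ptr0[...] = in_ptr0[...] + in_ptr1[...]


@functools.partial(jax.jit, static_argnums=(0, 1))
def pallas_add_kernel_jit_wrapper(out_shape, out_dtype, *kernel_refs):
    # Shape/dtype are static arguments so jax.jit can specialize the call.
    out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)
    return pl.pallas_call(
        pallas_add_kernel,
        out_shape=out_spec,
        interpret=True,  # emitted as "True" for CPU codegen, "False" otherwise
        grid=(1,),
    )(*kernel_refs)
```

The generated `*_main` entry then converts the torch tensors to JAX arrays via dlpack, calls this jit wrapper with `out_shape`/`out_dtype`, and copies the result back into the provided output tensor, as the `code.writeline` calls above spell out.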
@@ -971,6 +971,7 @@ class ScatterFallbackLine(WrapperLine):
        else:
            (x, index) = (t.codegen_reference() for t in node.inputs)
            src = node.constant_args[1]
        device = d.type if (d := node.get_device()) else V.graph.device_type
        self.wrapper._generate_scatter_fallback(
            x,
            [x, node.constant_args[0], index, src],

@@ -979,6 +980,7 @@ class ScatterFallbackLine(WrapperLine):
            node.src_is_tensor,
            node.kwargs["reduce"],
            node.codegen_kwargs(),
            device,
        )

    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:

@@ -1632,6 +1634,7 @@ class PythonWrapperCodegen(CodeGen):
        src_is_tensor,
        reduce,
        kwargs,
        device,
    ):
        line = f"{python_kernel_name}({','.join(map(str, inputs))}"
        if python_kernel_name.startswith("aten.scatter_reduce"):
@@ -468,6 +468,8 @@ class _SerializedFxCompile(FxCompile):
        fake_mode = _current_fake_mode()
        fake_tensor_mode = _FakeTensorModeSerializer(fake_mode)

        from pickle import PicklingError

        try:
            input = _WireProtocolInput(
                gm,

@@ -483,7 +485,7 @@ class _SerializedFxCompile(FxCompile):
                fake_tensor_mode,
            ).serialize()
            return (input, constants)
        except (AttributeError, BypassFxGraphCache):
        except (AttributeError, BypassFxGraphCache, PicklingError):
            # For example: AttributeError: Can't pickle local object
            # 'make_opaque_unary_fn.<locals>.OpaqueUnaryFn'
@@ -8845,7 +8845,9 @@ class Conditional(ExternKernel):
        outputs = [
            MultiOutput(
                FixedLayout(
                    device=device,
                    device=output.get_device()
                    if output.get_device() is not None
                    else device,  # type: ignore[arg-type]
                    dtype=output.get_dtype(),
                    size=[Conditional._maybe_expr(sz) for sz in merged_output.size()],
                    stride=[
@@ -48,7 +48,8 @@ logger = get_logger(__name__)

@dataclass
class WorkerSpec:
    """Blueprint information about a particular type of worker.
    """
    Blueprint information about a particular type of worker.

    For a given role, there must only exist a single worker spec.
    Worker spec is expected to be homogeneous across all nodes (machine),

@@ -79,6 +80,10 @@ class WorkerSpec:
            that match _any_ of the filter strings.
        duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines
            that match _any_ of the filter strings.
        virtual_local_rank: Enable virtual local rank mode for workers (defaults to False).
            When enabled, LOCAL_RANK is set to 0 for all workers and
            CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its
            assigned GPU at device index 0.
    """

    role: str

@@ -97,6 +102,7 @@ class WorkerSpec:
    numa_options: Optional[NumaOptions] = None
    duplicate_stdout_filters: Optional[list[str]] = None
    duplicate_stderr_filters: Optional[list[str]] = None
    virtual_local_rank: bool = False

    def __post_init__(self):
        assert self.local_world_size > 0
@@ -303,7 +303,6 @@ class LocalElasticAgent(SimpleElasticAgent):
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),

@@ -322,6 +321,7 @@ class LocalElasticAgent(SimpleElasticAgent):
                    "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1)
                ),
            }
            self._set_local_rank_env(worker_env, local_rank, spec)
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]

@@ -362,6 +362,46 @@ class LocalElasticAgent(SimpleElasticAgent):

        return self._pcontext.pids()

    def _set_local_rank_env(
        self, worker_env: dict[str, str | None], local_rank: int, spec: WorkerSpec
    ) -> None:
        # Set CUDA_VISIBLE_DEVICES and LOCAL_RANK based on virtual_local_rank mode.
        # Virtual mode: Each worker sees only its assigned GPU as device 0, LOCAL_RANK=0
        # Traditional mode: Workers see all GPUs, LOCAL_RANK matches actual local rank

        if spec.virtual_local_rank:
            # Set LOCAL_RANK=0 and use CUDA_VISIBLE_DEVICES to control the actual GPU access.

            worker_env["LOCAL_RANK"] = "0"

            # Map local_rank through existing CUDA_VISIBLE_DEVICES
            # HIP uses CUDA_VISIBLE_DEVICES as a compatibility hack:
            # https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#cuda-visible-devices
            parent_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
            if parent_visible_devices is not None:
                # Parse comma-separated list of GPU IDs
                available_gpus = parent_visible_devices.split(",")
                if local_rank >= len(available_gpus):
                    raise ValueError(
                        f"local_rank {local_rank} exceeds available GPUs in "
                        f"CUDA_VISIBLE_DEVICES={parent_visible_devices}"
                    )

                visible_gpu = available_gpus[local_rank].strip()
            else:
                # No restriction, use local_rank directly
                visible_gpu = str(local_rank)

            worker_env["CUDA_VISIBLE_DEVICES"] = visible_gpu
            return

        # In traditional mode, don't override CUDA_VISIBLE_DEVICES
        # (inherit from parent environment)
        worker_env["LOCAL_RANK"] = str(local_rank)

        if "CUDA_VISIBLE_DEVICES" in os.environ:
            worker_env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]

    def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
        if self._worker_watchdog is not None:
            self._worker_watchdog.stop()
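A standalone sketch of the mapping `_set_local_rank_env` applies in virtual mode, useful for reasoning about the environment a worker ends up with (the helper name here is only illustrative):

```python
def map_virtual_local_rank(local_rank: int, parent_visible_devices: str | None) -> dict[str, str]:
    """Return the env a worker would see with virtual_local_rank enabled."""
    if parent_visible_devices is not None:
        available = [d.strip() for d in parent_visible_devices.split(",")]
        if local_rank >= len(available):
            raise ValueError("local_rank exceeds available GPUs")
        visible = available[local_rank]
    else:
        visible = str(local_rank)
    return {"LOCAL_RANK": "0", "CUDA_VISIBLE_DEVICES": visible}


# local_rank 1 under a parent restriction of "2,5,7" is pinned to physical GPU 5:
assert map_virtual_local_rank(1, "2,5,7") == {"LOCAL_RANK": "0", "CUDA_VISIBLE_DEVICES": "5"}
# With no parent restriction, the local rank itself becomes the visible device:
assert map_virtual_local_rank(1, None) == {"LOCAL_RANK": "0", "CUDA_VISIBLE_DEVICES": "1"}
```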
@@ -75,6 +75,10 @@ class LaunchConfig:
            that match _any_ of the filter strings.
        duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines
            that match _any_ of the filter strings.
        virtual_local_rank: Enable virtual local rank mode for workers (defaults to False).
            When enabled, LOCAL_RANK is set to 0 for all workers and
            CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its
            assigned GPU at device index 0.


    .. note::

@@ -104,6 +108,7 @@ class LaunchConfig:
    signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT"
    duplicate_stdout_filters: Optional[list[str]] = None
    duplicate_stderr_filters: Optional[list[str]] = None
    virtual_local_rank: bool = False

    def __post_init__(self):
        default_timeout = 900

@@ -288,6 +293,7 @@ def launch_agent(
        numa_options=config.numa_options,
        duplicate_stdout_filters=config.duplicate_stdout_filters,
        duplicate_stderr_filters=config.duplicate_stderr_filters,
        virtual_local_rank=config.virtual_local_rank,
    )

    agent = LocalElasticAgent(
@@ -688,6 +688,15 @@ def get_args_parser() -> ArgumentParser:
        "Common additional signals: SIGUSR1,SIGUSR2 (used in SLURM environments).",
    )

    parser.add_argument(
        "--virtual-local-rank",
        "--virtual_local_rank",
        action=check_env,
        help="Enable virtual local rank mode for workers. When enabled, LOCAL_RANK is set to 0 "
        "for all workers and CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its "
        "assigned GPU at device index 0.",
    )

    #
    # Positional arguments.
    #

@@ -907,6 +916,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
        signals_to_handle=args.signals_to_handle,
        duplicate_stdout_filters=args.duplicate_stdout_filters,
        duplicate_stderr_filters=args.duplicate_stderr_filters,
        virtual_local_rank=args.virtual_local_rank,
    )

    with_python = not args.no_python
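With the new flag, a worker script can be written against a fixed device index. A hedged sketch of such an entrypoint, launched for example with `torchrun --nproc-per-node=2 --virtual-local-rank worker.py` (script name and tensor shapes are illustrative):

```python
import os

import torch
import torch.distributed as dist

# Under --virtual-local-rank every worker receives LOCAL_RANK=0, and
# CUDA_VISIBLE_DEVICES maps device index 0 to its assigned physical GPU.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

dist.init_process_group(backend="nccl")
x = torch.ones(4, device="cuda") * dist.get_rank()
dist.all_reduce(x)
print(f"rank {dist.get_rank()}: {x.tolist()}")
dist.destroy_process_group()
```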
@@ -256,7 +256,7 @@ def bmm_strategy(op_schema: OpSchema) -> OpStrategy:


@register_op_strategy(aten.baddbmm.default)
def baddmm_strategy(op_schema: OpSchema) -> OpStrategy:
def baddbmm_strategy(op_schema: OpSchema) -> OpStrategy:
    mesh = op_schema.get_mesh_from_args()
    return _addmm_like_strategy("bmk,bkn->bmn", mesh, op_schema)
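For context, `baddbmm` is the batched bias-plus-matmul that the `bmk,bkn->bmn` einsum string describes; a quick numerical check of that equivalence (with the default `beta=1`, `alpha=1`):

```python
import torch

bias = torch.randn(4, 2, 5)
b1 = torch.randn(4, 2, 3)
b2 = torch.randn(4, 3, 5)

expected = bias + torch.einsum("bmk,bkn->bmn", b1, b2)
torch.testing.assert_close(torch.baddbmm(bias, b1, b2), expected)
```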
@@ -177,6 +177,26 @@ def split_const_subgraphs(
    else:
        mod_traced = module

    def _subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool:
        """
        Return True if a GraphModule type subgraph contains any impure op, else False.
        """
        assert isinstance(module, torch.fx.GraphModule), (
            "caller should only pass GraphModule to subgraph_has_impure_ops check"
        )
        for node in module.graph.nodes:
            if node.op == "call_function" and node.is_impure():
                return True
            if (
                # pyrefly: ignore [invalid-argument]
                node.op == "call_module"
                # pyrefly: ignore [not-callable]
                and (submodule := module.get_submodule(node.target))
                and isinstance(submodule, torch.fx.GraphModule)
            ):
                return _subgraph_has_impure_ops(submodule)
        return False

    # Build up a list of const_nodes, defined as nodes that are themselves
    # get_attrs, or have all get_attr or other constant node inputs.
    const_nodes: set[torch.fx.Node] = set()

@@ -206,6 +226,17 @@ def split_const_subgraphs(
        if isinstance(node.kwargs.get("fill_value", None), sympy.Expr):
            continue

        # Skip folding submodules that have impure ops
        if (
            # pyrefly: ignore [invalid-argument]
            node.op == "call_module"
            # pyrefly: ignore [not-callable]
            and (target_mod := mod_traced.get_submodule(node.target))
            and isinstance(target_mod, torch.fx.GraphModule)
            and _subgraph_has_impure_ops(target_mod)
        ):
            continue

        # Must be a constant foldable node at this point.
        const_nodes.add(node)
        if node.op != "get_attr":
@@ -754,26 +754,6 @@ class Node(_NodeBase):

            return self.target in _side_effectful_functions

        def subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool:
            """
            Return True if a GraphModule type subgraph contains any impure op, else False.
            """
            assert isinstance(module, torch.fx.GraphModule), (
                "caller should only pass GraphModule to subgraph_has_impure_ops check"
            )
            for node in module.graph.nodes:
                if node.op == "call_function" and node.is_impure(impure_random):
                    return True
                if (
                    # pyrefly: ignore [invalid-argument]
                    node.op == "call_module"
                    # pyrefly: ignore [not-callable]
                    and (submodule := module.get_submodule(node.target))
                    and isinstance(submodule, torch.fx.GraphModule)
                ):
                    return subgraph_has_impure_ops(submodule)
            return False

        # Check if an impure module.
        if self.op == "call_module":
            assert self.graph.owning_module is not None, (

@@ -783,10 +763,11 @@ class Node(_NodeBase):
            assert target_mod is not None, (
                f"Did not find expected submodule target {self.target}"
            )
            if isinstance(target_mod, torch.fx.GraphModule):
                return subgraph_has_impure_ops(target_mod)
            else:
                return getattr(target_mod, "_is_impure", False)
            # NOTE: here we can end up considering GraphModule submodules pure,
            # even if they contain impure ops. It may not be safe to change
            # because this function is used by graph.eliminate_dead_code,
            # and some users depend on current elimination behavior.
            return getattr(target_mod, "_is_impure", False)

        return False
@@ -1711,7 +1711,7 @@ class MultiProcContinuousTest(TestCase):
    @classmethod
    def _init_pg(cls, rank, world_size, rdvz_file):
        assert rdvz_file is not None
        # rank should be local_rank for tests running on <= 8gpus which is how all these tests are designed
        # rank should be local_rank for tests running on <= 8 gpus which is how all these tests are designed
        # and we expect LOCAL_RANK set by torchrun. Setting it lets init_device_mesh set the device without
        # issuing a warning
        os.environ["LOCAL_RANK"] = str(rank)