Compare commits

...

18 Commits

Author SHA1 Message Date
c93b821875 adding documentation 2025-11-07 16:31:53 -08:00
4957ae5838 Add API to annotate disjoint backward and handle in AC (#166536)
This adds zero-bubble / DualPipeV support for (S)AC

Before:
- AC always retriggers recompute for every distinct backward call.

After:
- Any checkpointed regions encountered by backward under the same instance of this context manager will only trigger recompute at most once, even if there are multiple calls to backward.
- Backward calls under the same instance of this context manager must execute over non-overlapping regions of the backward graph even if retain_graph=True.
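A minimal usage sketch, based on the test added in this PR (`test_checkpoint_graph_execution_group`); the context manager is `torch.utils.checkpoint.GraphExecGroup`:

```python
import torch
from torch.utils.checkpoint import checkpoint, GraphExecGroup

def fn(x):
    y = x.sin().cos()
    z = y.sin().cos()
    return y, z

x = torch.randn(3, 3, requires_grad=True)
y, z = checkpoint(fn, x, use_reentrant=False)

# The two backward calls below cover disjoint regions of the backward graph.
# Under the same GraphExecGroup instance the checkpointed region is recomputed
# at most once; without the context, each call triggers its own recompute.
with GraphExecGroup():
    (grad_y,) = torch.autograd.grad(z, inputs=(y,), grad_outputs=(torch.ones(3, 3),))
    (grad_x,) = torch.autograd.grad(y, inputs=(x,), grad_outputs=(grad_y,))
```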

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166536
Approved by: https://github.com/albanD
2025-11-08 00:21:25 +00:00
31d6d3ef5c [easy] Add new torch/csrc/stable/c/shim.h to existing nitpick (#167367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167367
Approved by: https://github.com/janeyx99, https://github.com/malfet
2025-11-08 00:13:03 +00:00
2325c511e7 [dynamo] Make sym node vt creation via SymNodeVariable create (#167189)
This will help in the next PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167189
Approved by: https://github.com/williamwen42, https://github.com/zou3519
ghstack dependencies: #167160
2025-11-07 23:58:13 +00:00
d865156967 [dynamo][hops] Overwrite proxy of the original VT to the subgraph outputs (#167160)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167160
Approved by: https://github.com/zou3519
2025-11-07 23:58:13 +00:00
fbc0bd2e90 [DTensor][be] getting rid of unnecessary Partial check for norm functions (#167247)
**Summary:** While the implementation is correct, these checks are just a subset of the Partial placement checks that are done in https://github.com/pytorch/pytorch/pull/165962. This means for ops aten.linalg_vector_norm.default and aten._foreach_norm.Scalar, we're unnecessarily checking for Partial placements twice.

**Test Cases**
1. pytest test/distributed/tensor/test_math_ops.py -k test_vector_norm_partial
2. pytest test/distributed/tensor/test_math_ops.py -k test_foreach_norm_partial
3. pytest test/distributed/tensor/test_math_ops.py -k test_partial_reduction_ops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167247
Approved by: https://github.com/XilunWu
2025-11-07 23:49:29 +00:00
70f5f55abf [Inductor-FX] Allocate tensors on device type instead of indexed device (#167358)
# Problem
The FX backend currently allocates tensors on an exact device index, such as `"cuda:0"`. In contrast, the Python backend allocates on a device type, such as `"cuda"`. This avoids edge cases where fake tensor propagation can fail due to mismatched devices.

# Fix
Allocate tensors on `device.type` instead of the device.
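For illustration, a minimal sketch of the distinction (a standalone example, not code from the PR):

```python
import torch

indexed = torch.device("cuda", 0)        # exact indexed device, i.e. "cuda:0"
print(indexed.type)                      # "cuda": the device type only
print(torch.device(indexed.type).index)  # None: a type-only device carries no index
```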

# Test plan
Added a CI test that passes sample inputs on an indexed device and checks that the output device in the generated FX graph is not indexed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167358
Approved by: https://github.com/mlazos, https://github.com/nandesuka, https://github.com/eellison
2025-11-07 23:48:54 +00:00
69ecb562e7 [PT2 Compiler] Add annotation for dynamo disabled callables (#166341)
Summary: To make torch.export compatible with PT2 compile (which runs on top of the exported model), we need to store torch._dynamo.disable attributes in the exported model and restore them after unflattening. This diff adds annotations to all nodes traced under torch._dynamo.disable, and these annotations are preserved during export.
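A sketch of the resulting behavior, condensed from the test added in this diff (`test_dynamo_disable_annotations`):

```python
import torch

class SimpleModel(torch.nn.Module):
    @torch._dynamo.disable()
    def f1(self, x):
        return x + 1

    def forward(self, x):
        return self.f1(x)

model = SimpleModel()
inp = torch.rand(2, 2)
with torch.fx.traceback.preserve_node_meta():
    exported = torch.export.export(model, (inp,))

# Nodes traced under the disabled method carry custom metadata such as
# {"_torchdynamo_disable": True, "_torchdynamo_disable_method": "f1"},
# which survives export and can be restored after unflattening.
for node in exported.graph_module.graph.nodes:
    print(node.name, node.meta.get("custom", {}))
```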

Test Plan:
```
buck test mode/opt caffe2/test:test_export -- 'test_dynamo_disable_annotations'
```
https://www.internalfb.com/intern/testinfra/testrun/6473924770741560

Differential Revision: D85302730

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166341
Approved by: https://github.com/williamwen42, https://github.com/angelayi
2025-11-07 23:28:00 +00:00
5062abe4e7 [CI][serialization] Fix exception regexes with Python-3.14 (#167333)
For reasons not yet clear, running some tests (for example `test_weights_only_safe_globals_build`) with `pytest` on Python 3.14 produces the global name `test_serialization.ClassThatUsesBuildInstruction` instead of the expected `__main__.ClassThatUsesBuildInstruction`, so the exception regexes now use the actual module name.
Also changes the expected exception type from `AttributeError` to `PicklingError`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167333
Approved by: https://github.com/atalman
2025-11-07 23:22:36 +00:00
c7007e7584 Update Kineto Submodule (#167343)
Summary: Title

Test Plan: CI

Differential Revision: D86538778

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167343
Approved by: https://github.com/Skylion007, https://github.com/aaronenyeshi
2025-11-07 23:06:58 +00:00
09705ca9b2 [dynamo][guards] Fix mem leak in tensor subclass metadata guard (#167352)
Use the class instead of the object. Previously, the metadata guard was holding on to the DTensor, causing a memory leak.
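A simplified sketch of the change (the helper functions here are illustrative only; the real change is in the guard-builder diff below):

```python
# Before: the checker closure captured `value`, a live tensor-subclass instance
# (e.g. a DTensor), so the installed guard kept that tensor alive.
def make_metadata_checker_before(value, original_metadata):
    def metadata_checker(x):
        return value.__metadata_guard__(original_metadata, x.__tensor_flatten__()[1])
    return metadata_checker

# After: capture only the subclass type; calling __metadata_guard__ on the class
# is sufficient, and the guard no longer pins the DTensor in memory.
def make_metadata_checker_after(value, original_metadata):
    cls = type(value)
    def metadata_checker(x):
        return cls.__metadata_guard__(original_metadata, x.__tensor_flatten__()[1])
    return metadata_checker
```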

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167352
Approved by: https://github.com/Skylion007
2025-11-07 23:01:15 +00:00
ea6b0b5d0f add missing cpp standard lib in HeaderOnlyArrayRef.h (#167337)
Fixes #167315
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167337
Approved by: https://github.com/janeyx99
2025-11-07 23:00:08 +00:00
bbf852d87f Revert "Remove python workaround for ContextDecorator (#167049)"
This reverts commit 13d2cc7bd26e32cafff0377dda1c5ddc8d04c4ce.

Reverted https://github.com/pytorch/pytorch/pull/167049 on behalf of https://github.com/donigian due to breaking internal tests D86342845 ([comment](https://github.com/pytorch/pytorch/pull/167049#issuecomment-3505251296))
2025-11-07 22:32:45 +00:00
6392b986e7 Revert "[13/N] Apply ruff UP035 rule (#167048)"
This reverts commit ea44f12bce3eb05eaa9fa34943a3ffae04647fa5.

Reverted https://github.com/pytorch/pytorch/pull/167048 on behalf of https://github.com/donigian due to breaking internal tests D86342860 ([comment](https://github.com/pytorch/pytorch/pull/167048#issuecomment-3505232522))
2025-11-07 22:25:01 +00:00
32d30d96cf [ROCm][CI] unconditionally add gfx950, gfx115x to PYTORCH_ROCM_ARCH (#167299)
Included gfx950, gfx1150, and gfx1151 unconditionally in PYTORCH_ROCM_ARCH. Removed the ROCm 7.0 version check and refactored the architecture list.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167299
Approved by: https://github.com/jeffdaily
2025-11-07 21:47:59 +00:00
46516efa85 [BE] use undeprecated from/to in libtorch_agnostic tests (#167126)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167126
Approved by: https://github.com/Skylion007
ghstack dependencies: #164991, #165152, #165153, #165953
2025-11-07 21:31:30 +00:00
84b2147b85 Introducing the StableIValue representation of list :D (#165953)
Some important notes:
a) Just as IValue takes ownership of ArrayRefs and any std::vectors in order to convert the inner elements into IValues, we do the same with StableIValue. This O(N) traversal is unavoidable.
b) As a result, since StableIValues are owning and our contract is that to<T>(StableIValue) transfers ownership, you cannot ever convert from StableIValue to a nonowning HeaderOnlyArrayRef<V>.

We handle memory similarly to AtenTensorHandle, but with a StableListHandle!
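From Python, the new list support is exercised through the `libtorch_agnostic` test extension in this PR's test suite; a usage sketch (assuming that extension is built and importable, as in the tests below):

```python
import torch
import libtorch_agnostic  # test extension from this PR's test suite

tensors = [torch.rand(32, 16) for _ in range(5)]
others = [torch.rand(32, 16) for _ in range(5)]

# my__foreach_mul takes Tensor[] arguments, which now round-trip through the
# StableIValue list representation; the result should match the native op.
result = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
expected = torch._foreach_mul(tensors, others)
for r, e in zip(result, expected):
    torch.testing.assert_close(r, e)
```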

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165953
Approved by: https://github.com/malfet
ghstack dependencies: #164991, #165152, #165153
2025-11-07 21:31:30 +00:00
1727a71cb6 Create pallas test shard (#167143)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167143
Approved by: https://github.com/malfet
ghstack dependencies: #167243
2025-11-07 21:05:54 +00:00
54 changed files with 1066 additions and 291 deletions

View File

@ -36,11 +36,7 @@ case ${DOCKER_TAG_PREFIX} in
;;
rocm*)
BASE_TARGET=rocm
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
;;
*)

View File

@ -260,6 +260,12 @@ case "$tag" in
HALIDE=yes
TRITON=yes
;;
pytorch-linux-jammy-cuda13.0-py3.12-pallas)
CUDA_VERSION=13.0.0
ANACONDA_PYTHON_VERSION=3.12
GCC_VERSION=11
PALLAS=yes
;;
pytorch-linux-jammy-py3.12-triton-cpu)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.12
@ -381,6 +387,7 @@ docker build \
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
--build-arg "EXECUTORCH=${EXECUTORCH}" \
--build-arg "HALIDE=${HALIDE}" \
--build-arg "PALLAS=${PALLAS}" \
--build-arg "XPU_VERSION=${XPU_VERSION}" \
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
--build-arg "ACL=${ACL:-}" \

View File

@ -0,0 +1 @@
0.8.0

View File

@ -0,0 +1,40 @@
#!/bin/bash
set -ex
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
# Get the pinned JAX version (same for all CUDA versions)
JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
function install_jax_12() {
echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Verify installation
python -c "import jax" # check for errors
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
}
function install_jax_13() {
echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Verify installation
python -c "import jax" # check for errors
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
}
# idiomatic parameter and option handling in sh
while test $# -gt 0
do
case "$1" in
12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
;;
13.0|13.0.*) install_jax_13;
;;
*) echo "bad argument $1"; exit 1
;;
esac
shift
done

View File

@ -49,11 +49,7 @@ case ${DOCKER_TAG_PREFIX} in
fi
BASE_TARGET=rocm
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}"
;;
*)

View File

@ -87,11 +87,7 @@ case ${image} in
MANY_LINUX_VERSION="2_28"
DEVTOOLSET_VERSION="11"
GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
;;
manylinux2_28-builder:xpu)

View File

@ -143,6 +143,15 @@ COPY ci_commit_pins/halide.txt halide.txt
RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
RUN rm install_halide.sh common_utils.sh halide.txt
ARG PALLAS
ARG CUDA_VERSION
# Install JAX with CUDA support (for Pallas)
COPY ./common/install_jax.sh install_jax.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
ARG ONNX
# Install ONNX dependencies
COPY ./common/install_onnx.sh ./common/common_utils.sh ./

View File

@ -824,6 +824,11 @@ test_inductor_halide() {
assert_git_not_dirty
}
test_inductor_pallas() {
python test/run_test.py --include inductor/test_pallas.py --verbose
assert_git_not_dirty
}
test_inductor_triton_cpu() {
python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
assert_git_not_dirty
@ -1724,6 +1729,8 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
test_inductor_halide
elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
test_inductor_pallas
elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
test_inductor_triton_cpu
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then

View File

@ -10,3 +10,4 @@
pathFilter:
- 'torch/csrc/inductor/aoti_torch/c/*'
- 'torch/csrc/inductor/aoti_torch/generated/*'
- 'torch/csrc/stable/c/*'

View File

@ -67,6 +67,7 @@ jobs:
pytorch-linux-jammy-py3.10-gcc11,
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-cuda13.0-py3.12-pallas,
pytorch-linux-jammy-xpu-n-1-py3,
pytorch-linux-noble-xpu-n-py3,
pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,

View File

@ -81,6 +81,32 @@ jobs:
test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
secrets: inherit
inductor-pallas-build:
name: inductor-pallas-build
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
build-environment: linux-jammy-py3.12-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas
cuda-arch-list: '8.9'
runner: linux.8xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
]}
secrets: inherit
inductor-pallas-test:
name: inductor-pallas-test
uses: ./.github/workflows/_linux-test.yml
needs: inductor-pallas-build
with:
build-environment: linux-jammy-py3.12-gcc11
docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
secrets: inherit
inductor-triton-cpu-build:
name: inductor-triton-cpu-build
uses: ./.github/workflows/_linux-build.yml

View File

@ -1,6 +1,8 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>
#include <optional>
namespace c10 {
@ -15,7 +17,8 @@ struct C10_API AutogradState {
bool inference_mode,
bool fw_grad_mode,
bool multithreading_enabled)
: grad_mode_(grad_mode),
: graph_exec_group_(std::nullopt),
grad_mode_(grad_mode),
inference_mode_(inference_mode),
fw_grad_mode_(fw_grad_mode),
multithreading_enabled_(multithreading_enabled),
@ -41,6 +44,10 @@ struct C10_API AutogradState {
view_replay_enabled_ = view_replay_enabled;
}
void set_graph_exec_group(std::optional<SafePyObject> group) {
graph_exec_group_ = std::move(group);
}
bool get_grad_mode() const {
return grad_mode_;
}
@ -61,7 +68,12 @@ struct C10_API AutogradState {
return view_replay_enabled_;
}
const std::optional<SafePyObject>& get_graph_exec_group() const {
return graph_exec_group_;
}
private:
std::optional<SafePyObject> graph_exec_group_;
bool grad_mode_ : 1;
bool inference_mode_ : 1;
bool fw_grad_mode_ : 1;

View File

@ -382,20 +382,6 @@ coverage_ignore_functions = [
# torch.ao.quantization.backend_config.tensorrt
"get_tensorrt_backend_config",
"get_tensorrt_backend_config_dict",
# torch.ao.quantization.backend_config.utils
"entry_to_pretty_str",
"get_fused_module_classes",
"get_fuser_method_mapping",
"get_fusion_pattern_to_extra_inputs_getter",
"get_fusion_pattern_to_root_node_getter",
"get_module_to_qat_module",
"get_pattern_to_dtype_configs",
"get_pattern_to_input_type_to_index",
"get_qat_module_classes",
"get_root_module_to_quantized_reference_module",
"pattern_to_human_readable",
"remove_boolean_dispatch_from_name",
# torch.ao.quantization.backend_config.x86
"get_x86_backend_config",
# torch.ao.quantization.fuse_modules
"fuse_known_modules",
@ -426,25 +412,6 @@ coverage_ignore_functions = [
"insert_observers_for_model",
"prepare",
"propagate_dtypes_for_known_nodes",
# torch.ao.quantization.fx.utils
"all_node_args_except_first",
"all_node_args_have_no_tensors",
"assert_and_get_unique_device",
"collect_producer_nodes",
"create_getattr_from_value",
"create_node_from_old_node_preserve_meta",
"get_custom_module_class_keys",
"get_linear_prepack_op_for_dtype",
"get_new_attr_name_with_prefix",
"get_non_observable_arg_indexes_and_types",
"get_qconv_prepack_op",
"get_skipped_module_name_and_classes",
"graph_module_from_producer_nodes",
"maybe_get_next_module",
"node_arg_is_bias",
"node_arg_is_weight",
"return_arg_list",
# torch.ao.quantization.pt2e.graph_utils
"bfs_trace_with_node_process",
"find_sequential_partitions",
"get_equivalent_types",
@ -860,80 +827,10 @@ coverage_ignore_functions = [
"get_latency_of_one_partition",
"get_latency_of_partitioned_graph",
"get_partition_to_latency_mapping",
# torch.fx.experimental.proxy_tensor
"decompose",
"disable_autocast_cache",
"disable_proxy_modes_tracing",
"dispatch_trace",
"extract_val",
"fake_signature",
"fetch_sym_proxy",
"fetch_object_proxy",
"get_innermost_proxy_mode",
"get_isolated_graphmodule",
"get_proxy_slot",
"get_torch_dispatch_modes",
"has_proxy_slot",
"is_sym_node",
"maybe_handle_decomp",
"proxy_call",
"set_meta",
"set_original_aten_op",
"set_proxy_slot",
"snapshot_fake",
"thunkify",
"track_tensor",
"track_tensor_tree",
"wrap_key",
"wrapper_and_args_for_make_fx",
# torch.fx.experimental.recording
"record_shapeenv_event",
"replay_shape_env_events",
"shape_env_check_state_equal",
# torch.fx.experimental.sym_node
"ceil_impl",
"floor_ceil_helper",
"floor_impl",
"method_to_operator",
"sympy_is_channels_last_contiguous_2d",
"sympy_is_channels_last_contiguous_3d",
"sympy_is_channels_last_strides_2d",
"sympy_is_channels_last_strides_3d",
"sympy_is_channels_last_strides_generic",
"sympy_is_contiguous",
"sympy_is_contiguous_generic",
"to_node",
"wrap_node",
"sym_sqrt",
# torch.fx.experimental.symbolic_shapes
"bind_symbols",
"cast_symbool_to_symint_guardless",
"create_contiguous",
"error",
"eval_guards",
"eval_is_non_overlapping_and_dense",
"expect_true",
"find_symbol_binding_fx_nodes",
"free_symbols",
"free_unbacked_symbols",
"fx_placeholder_targets",
"fx_placeholder_vals",
"guard_bool",
"guard_float",
"guard_int",
"guard_scalar",
"has_hint",
"has_symbolic_sizes_strides",
"is_channels_last_contiguous_2d",
"is_channels_last_contiguous_3d",
"is_channels_last_strides_2d",
"is_channels_last_strides_3d",
"is_contiguous",
"is_non_overlapping_and_dense_indicator",
"is_nested_int",
"is_symbol_binding_fx_node",
"is_symbolic",
# torch.fx.experimental.unification.core
"reify",
# torch.fx.experimental.unification.match
"edge",
@ -971,24 +868,6 @@ coverage_ignore_functions = [
"reverse_dict",
# torch.fx.experimental.unification.multipledispatch.variadic
"isvariadic",
# torch.fx.experimental.unification.unification_tools
"assoc",
"assoc_in",
"dissoc",
"first",
"get_in",
"getter",
"groupby",
"itemfilter",
"itemmap",
"keyfilter",
"keymap",
"merge",
"merge_with",
"update_in",
"valfilter",
"valmap",
# torch.fx.experimental.unification.utils
"freeze",
"hashable",
"raises",

View File

@ -12,6 +12,37 @@ These APIs are experimental and subject to change without notice.
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
```
## torch.fx.experimental.sym_node
```{eval-rst}
.. currentmodule:: torch.fx.experimental.sym_node
```
```{eval-rst}
.. automodule:: torch.fx.experimental.sym_node
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
is_channels_last_contiguous_2d
is_channels_last_contiguous_3d
is_channels_last_strides_2d
is_channels_last_strides_3d
is_contiguous
is_non_overlapping_and_dense_indicator
method_to_operator
sympy_is_channels_last_contiguous_2d
sympy_is_channels_last_contiguous_3d
sympy_is_channels_last_strides_2d
sympy_is_channels_last_strides_3d
sympy_is_channels_last_strides_generic
sympy_is_contiguous
sympy_is_contiguous_generic
```
## torch.fx.experimental.symbolic_shapes
```{eval-rst}
@ -69,6 +100,25 @@ These APIs are experimental and subject to change without notice.
rebind_unbacked
resolve_unbacked_bindings
is_accessor_node
cast_symbool_to_symint_guardless
create_contiguous
error
eval_guards
eval_is_non_overlapping_and_dense
find_symbol_binding_fx_nodes
free_symbols
free_unbacked_symbols
fx_placeholder_targets
fx_placeholder_vals
guard_bool
guard_float
guard_int
guard_scalar
has_hint
has_symbolic_sizes_strides
is_nested_int
is_symbol_binding_fx_node
is_symbolic
```
## torch.fx.experimental.proxy_tensor
@ -91,4 +141,46 @@ These APIs are experimental and subject to change without notice.
get_proxy_mode
maybe_enable_thunkify
maybe_disable_thunkify
decompose
disable_autocast_cache
disable_proxy_modes_tracing
extract_val
fake_signature
fetch_object_proxy
fetch_sym_proxy
has_proxy_slot
is_sym_node
maybe_handle_decomp
proxy_call
set_meta
set_original_aten_op
set_proxy_slot
snapshot_fake
```
## torch.fx.experimental.unification.unification_tools
```{eval-rst}
.. currentmodule:: torch.fx.experimental.unification.unification_tools
```
```{eval-rst}
.. automodule:: torch.fx.experimental.unification.unification_tools
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
assoc
assoc_in
dissoc
first
keyfilter
keymap
merge
merge_with
update_in
valfilter
valmap

View File

@ -1134,7 +1134,6 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.refinement_types
.. py:module:: torch.fx.experimental.rewriter
.. py:module:: torch.fx.experimental.schema_type_annotation
.. py:module:: torch.fx.experimental.sym_node
.. py:module:: torch.fx.experimental.unification.core
.. py:module:: torch.fx.experimental.unification.dispatch
.. py:module:: torch.fx.experimental.unification.match
@ -1144,7 +1143,6 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
.. py:module:: torch.fx.experimental.unification.unification_tools
.. py:module:: torch.fx.experimental.unification.utils
.. py:module:: torch.fx.experimental.unification.variable
.. py:module:: torch.fx.experimental.unify_refinements

View File

@ -134,6 +134,23 @@ Quantization to work with this as well.
ObservationType
```
## torch.ao.quantization.backend_config.utils
```{eval-rst}
.. currentmodule:: torch.ao.quantization.backend_config.utils
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
entry_to_pretty_str
pattern_to_human_readable
remove_boolean_dispatch_from_name
```
## torch.ao.quantization.fx.custom_config
This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization
@ -154,6 +171,30 @@ This module contains a few CustomConfig classes that's used in both eager mode a
StandaloneModuleConfigEntry
```
## torch.ao.quantization.fx.utils
```{eval-rst}
.. currentmodule:: torch.ao.quantization.fx.utils
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
all_node_args_except_first
all_node_args_have_no_tensors
collect_producer_nodes
create_getattr_from_value
create_node_from_old_node_preserve_meta
graph_module_from_producer_nodes
maybe_get_next_module
node_arg_is_bias
node_arg_is_weight
return_arg_list
```
## torch.ao.quantization.quantizer
```{eval-rst}

View File

@ -67,13 +67,13 @@ Tensor sgd_out_of_place(
void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = sgd_out_of_place(
to<Tensor>(stack[0]),
to<Tensor>(stack[1]),
float(to<double>(stack[2])),
to<double>(stack[3]),
to<bool>(stack[4]));
torch::stable::detail::to<Tensor>(stack[0]),
torch::stable::detail::to<Tensor>(stack[1]),
float(torch::stable::detail::to<double>(stack[2])),
torch::stable::detail::to<double>(stack[3]),
torch::stable::detail::to<bool>(stack[4]));
stack[0] = from(res);
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
@ -89,8 +89,8 @@ Tensor identity(Tensor t) {
}
void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = identity(to<Tensor>(stack[0]));
stack[0] = from(res);
Tensor res = identity(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -108,14 +108,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
Tensor my_abs(Tensor t) {
const auto num_args = 1;
StableIValue stack[num_args];
stack[0] = from(t);
stack[0] = torch::stable::detail::from(t);
aoti_torch_call_dispatcher("aten::abs", "", stack);
return to<Tensor>(stack[0]);
return torch::stable::detail::to<Tensor>(stack[0]);
}
void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_abs(to<Tensor>(stack[0]));
stack[0] = from(tensor_res);
Tensor tensor_res = my_abs(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(tensor_res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -132,21 +132,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
auto mf = aoti_torch_memory_format_contiguous_format();
stack[0] = from(t);
stack[1] = from(std::optional(t.scalar_type())); // dtype
stack[2] = from(std::nullopt); // layout
stack[3] = from(std::optional(device)); // device
stack[4] = from(std::optional(false)); // pin_memory
stack[5] = from(std::optional(mf)); // memory_format
stack[0] = torch::stable::detail::from(t);
stack[1] = torch::stable::detail::from(std::optional(t.scalar_type())); // dtype
stack[2] = torch::stable::detail::from(std::nullopt); // layout
stack[3] = torch::stable::detail::from(std::optional(device)); // device
stack[4] = torch::stable::detail::from(std::optional(false)); // pin_memory
stack[5] = torch::stable::detail::from(std::optional(mf)); // memory_format
aoti_torch_call_dispatcher("aten::ones_like", "", stack);
return to<Tensor>(stack[0]);
return torch::stable::detail::to<Tensor>(stack[0]);
}
void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
stack[0] = from(res);
Tensor res = my_ones_like(torch::stable::detail::to<Tensor>(stack[0]), stack[1]);
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -159,28 +159,28 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
StableIValue stack_exp[1];
stack_exp[0] = from(t1);
stack_exp[0] = torch::stable::detail::from(t1);
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
StableIValue stack_neg[1];
stack_neg[0] = from(t2);
stack_neg[0] = torch::stable::detail::from(t2);
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
StableIValue stack_is_leaf[1];
stack_is_leaf[0] = from(t3);
stack_is_leaf[0] = torch::stable::detail::from(t3);
aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);
return std::make_tuple(
to<Tensor>(stack_exp[0]),
to<Tensor>(stack_neg[0]),
to<bool>(stack_is_leaf[0]));
torch::stable::detail::to<Tensor>(stack_exp[0]),
torch::stable::detail::to<Tensor>(stack_neg[0]),
torch::stable::detail::to<bool>(stack_is_leaf[0]));
}
void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
stack[0] = from(std::get<0>(tuple));
stack[1] = from(std::get<1>(tuple));
stack[2] = from(std::get<2>(tuple));
auto tuple = exp_neg_is_leaf(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<Tensor>(stack[2]));
stack[0] = torch::stable::detail::from(std::get<0>(tuple));
stack[1] = torch::stable::detail::from(std::get<1>(tuple));
stack[2] = torch::stable::detail::from(std::get<2>(tuple));
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -193,15 +193,15 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
Tensor neg_exp(Tensor t) {
StableIValue stack[1];
stack[0] = from(t);
stack[0] = torch::stable::detail::from(t);
aoti_torch_call_dispatcher("aten::exp", "", stack);
aoti_torch_call_dispatcher("aten::neg", "", stack);
return to<Tensor>(stack[0]);
return torch::stable::detail::to<Tensor>(stack[0]);
}
void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = neg_exp(to<Tensor>(stack[0]));
stack[0] = from(res);
Tensor res = neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -214,10 +214,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
Tensor divide_neg_exp(Tensor t) {
StableIValue stack_neg[1];
stack_neg[0] = from(t);
stack_neg[0] = torch::stable::detail::from(t);
StableIValue stack_exp[1];
stack_exp[0] = from(t);
stack_exp[0] = torch::stable::detail::from(t);
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
@ -225,12 +225,12 @@ Tensor divide_neg_exp(Tensor t) {
stack_div[0] = stack_neg[0];
stack_div[1] = stack_exp[0];
aoti_torch_call_dispatcher("aten::divide", "Tensor", stack_div);
return to<Tensor>(stack_div[0]);
return torch::stable::detail::to<Tensor>(stack_div[0]);
}
void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = divide_neg_exp(to<Tensor>(stack[0]));
stack[0] = from(res);
Tensor res = divide_neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -246,8 +246,8 @@ bool is_contiguous(Tensor t) {
}
void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
bool res = is_contiguous(to<Tensor>(stack[0]));
stack[0] = from(res);
bool res = is_contiguous(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -263,9 +263,9 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
}
void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
auto res = my_transpose(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<int64_t>(stack[1]), torch::stable::detail::to<int64_t>(stack[2]));
stack[0] = from(res);
stack[0] = torch::stable::detail::from(res);
}
Tensor my_empty_like(Tensor t) {
@ -273,8 +273,8 @@ Tensor my_empty_like(Tensor t) {
}
void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_empty_like(to<Tensor>(stack[0]));
stack[0] = from(res);
auto res = my_empty_like(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
bool my_is_cpu(Tensor t) {
@ -283,8 +283,8 @@ bool my_is_cpu(Tensor t) {
void boxed_my_is_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_is_cpu(to<Tensor>(stack[0]));
stack[0] = from(res);
auto res = my_is_cpu(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor fill_infinity(Tensor t) {
@ -296,8 +296,8 @@ void boxed_fill_infinity(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
auto res = fill_infinity(to<Tensor>(stack[0]));
stack[0] = from(res);
auto res = fill_infinity(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_pad(Tensor t) {
@ -310,8 +310,8 @@ void boxed_my_pad(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
auto res = my_pad(to<Tensor>(stack[0]));
stack[0] = from(res);
auto res = my_pad(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
@ -323,11 +323,11 @@ void boxed_my_narrow(
uint64_t num_args,
uint64_t num_outputs) {
auto res = my_narrow(
to<Tensor>(stack[0]),
to<int64_t>(stack[1]),
to<int64_t>(stack[2]),
to<int64_t>(stack[3]));
stack[0] = from(res);
torch::stable::detail::to<Tensor>(stack[0]),
torch::stable::detail::to<int64_t>(stack[1]),
torch::stable::detail::to<int64_t>(stack[2]),
torch::stable::detail::to<int64_t>(stack[3]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_new_empty_dtype_variant(Tensor t) {
@ -342,8 +342,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
}
void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_new_empty_dtype_variant(to<Tensor>(stack[0]));
stack[0] = from(res);
auto res = my_new_empty_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_new_zeros_dtype_variant(Tensor t) {
@ -352,8 +352,8 @@ Tensor my_new_zeros_dtype_variant(Tensor t) {
}
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_new_zeros_dtype_variant(to<Tensor>(stack[0]));
stack[0] = from(res);
auto res = my_new_zeros_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
@ -361,8 +361,8 @@ Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
}
void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_copy_(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<bool>(stack[2]));
stack[0] = from(tensor_res);
Tensor tensor_res = my_copy_(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<bool>(stack[2]));
stack[0] = torch::stable::detail::from(tensor_res);
}
Tensor my_clone(Tensor t) {
@ -370,8 +370,8 @@ Tensor my_clone(Tensor t) {
}
void boxed_my_clone(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_clone(to<Tensor>(stack[0]));
stack[0] = from(tensor_res);
Tensor tensor_res = my_clone(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(tensor_res);
}
@ -408,8 +408,8 @@ Tensor my_zero_(Tensor t) {
}
void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_zero_(to<Tensor>(stack[0]));
stack[0] = from(res);
auto res = my_zero_(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_amax(Tensor t) {
@ -417,8 +417,8 @@ Tensor my_amax(Tensor t) {
}
void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax(to<Tensor>(stack[0]));
stack[0] = from(res);
auto res = my_amax(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
Tensor my_amax_vec(Tensor t) {
@ -426,8 +426,8 @@ Tensor my_amax_vec(Tensor t) {
}
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax_vec(to<Tensor>(stack[0]));
stack[0] = from(res);
auto res = my_amax_vec(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -464,8 +464,8 @@ void boxed_test_default_constructor(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
bool res = test_default_constructor(to<bool>(stack[0]));
stack[0] = from(res);
bool res = test_default_constructor(torch::stable::detail::to<bool>(stack[0]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -478,6 +478,56 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_amax_vec", &boxed_my_amax_vec);
}
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
}
void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
// Why is the following NOT torch::stable::detail::to<HeaderOnlyArrayRef<Tensor>>(stack[0])? Because calling `to`
// on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef
// is not owning, so it cannot safely steward the result of the torch::stable::detail::to<>.
auto res = my__foreach_mul(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
stack[0] = torch::stable::detail::from(res);
}
void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
}
void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
my__foreach_mul_(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
}
std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
// This function tests that my__foreach_mul can take in std::initializer_lists
// in addition to std::vectors.
Tensor t1_1 = my_clone(t1);
Tensor t1_2 = my_clone(t1);
Tensor t2_1 = my_clone(t2);
Tensor t2_2 = my_clone(t2);
return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
}
void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = make_tensor_clones_and_call_foreach(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my__foreach_mul", &boxed_my__foreach_mul);
m.impl("my__foreach_mul_", &boxed_my__foreach_mul_);
m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach);
}
// Test functions for torch::stable::accelerator APIs
#ifdef LAE_USE_CUDA
@ -500,8 +550,8 @@ void boxed_test_device_guard(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int res = test_device_guard(static_cast<int64_t>(to<int64_t>(stack[0])));
stack[0] = from(res);
int res = test_device_guard(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
stack[0] = torch::stable::detail::from(res);
}
int64_t test_device_guard_set_index() {
@ -520,7 +570,7 @@ void boxed_test_device_guard_set_index(
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_device_guard_set_index();
stack[0] = from(res);
stack[0] = torch::stable::detail::from(res);
}
int64_t test_stream(int32_t device_index) {
@ -536,8 +586,8 @@ void boxed_test_stream(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_stream(static_cast<int64_t>(to<int64_t>(stack[0])));
stack[0] = from(res);
int64_t res = test_stream(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
stack[0] = torch::stable::detail::from(res);
}
int64_t test_get_current_device_index() {
@ -549,7 +599,7 @@ void boxed_test_get_current_device_index(
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_get_current_device_index();
stack[0] = from(res);
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -565,4 +615,5 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_stream", &boxed_test_stream);
m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
}
#endif // LAE_USE_CUDA

View File

@ -333,3 +333,45 @@ def my_new_zeros_dtype_variant(t) -> Tensor:
Returns: New zeros tensor
"""
return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t)
def my__foreach_mul_(tensors, others) -> ():
"""
Updates tensors to be the result of pointwise multiplying with others.
Args:
tensors: list of tensors
others: list of tensors (with the same corresponding shapes as tensors)
Returns: nothing, tensors is updated in place.
"""
torch.ops.libtorch_agnostic.my__foreach_mul_.default(tensors, others)
def my__foreach_mul(tensors, others) -> list[Tensor]:
"""
Returns a list of tensors that are the results of pointwise multiplying
tensors and others.
Args:
tensors: list of tensors
others: list of tensors (with the same corresponding shapes as tensors)
Returns: list of multiplied tensors
"""
return torch.ops.libtorch_agnostic.my__foreach_mul.default(tensors, others)
def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]:
"""
Returns a list of 2 tensors corresponding to the square of the inputs.
Args:
t1: Tensor
t2: Tensor
Returns: list of [t1^2, t2^2]
"""
return torch.ops.libtorch_agnostic.make_tensor_clones_and_call_foreach.default(
t1, t2
)

View File

@ -367,6 +367,57 @@ if not IS_WINDOWS:
self.assertNotEqual(result.data_ptr(), expected.data_ptr())
self.assertEqual(result.stride(), expected.stride())
def test_my__foreach_mul_(self, device):
import libtorch_agnostic
N = 5
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
tensors_c = [t.clone() for t in tensors]
others = [torch.rand(32, 16, device=device) for _ in range(N)]
libtorch_agnostic.ops.my__foreach_mul_(tensors, others)
expected_values = torch._foreach_mul(tensors_c, others)
for tensor_t, expected_t in zip(tensors, expected_values):
self.assertEqual(tensor_t, expected_t)
def test_my__foreach_mul(self, device):
import libtorch_agnostic
N = 5
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
others = [torch.rand(32, 16, device=device) for _ in range(N)]
result = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
expected = torch._foreach_mul(tensors, others)
for result_t, expected_t in zip(result, expected):
self.assertEqual(result_t, expected_t)
def _make_cuda_tensors(prior_mem):
cuda_res = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
expected = torch._foreach_mul(tensors, others)
for result_t, expected_t in zip(cuda_res, expected):
self.assertEqual(result_t, expected_t)
if tensors[0].is_cuda:
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
def test_make_tensor_clones_and_call_foreach(self, device):
import libtorch_agnostic
t1 = torch.rand(2, 5, device=device)
t2 = torch.rand(3, 4, device=device)
result = libtorch_agnostic.ops.make_tensor_clones_and_call_foreach(t1, t2)
self.assertEqual(result[0], t1 * t1)
self.assertEqual(result[1], t2 * t2)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@ -1672,6 +1672,61 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
# The mutation is not reapplied in the backward because the flag was on.
self.assertEqual(counter, 1)
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
def test_nonlocal_list_mutation(self):
def gn(x, z):
out = x.sin()
z.append(out)
return torch.cos(torch.sin(torch.matmul(x, x) @ x)), out
def fn(x):
z = []
out1, out2 = torch.utils.checkpoint.checkpoint(
gn,
x,
z,
use_reentrant=False,
)
return out1, z[0]
x = torch.randn(4, 4, requires_grad=True)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
@unittest.expectedFailure
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
def test_nonlocal_list_mutation_hidden(self):
def gn(x, z):
out = x.sin()
z.append(out)
return torch.cos(torch.sin(torch.matmul(x, x) @ x))
def fn(x):
z = []
out1 = torch.utils.checkpoint.checkpoint(
gn,
x,
z,
use_reentrant=False,
)
return out1, z[0]
x = torch.randn(4, 4, requires_grad=True)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
devices = ["cuda", "hpu"]
instantiate_device_type_tests(

View File

@ -2155,6 +2155,43 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
torch.compile(model)
torch.compile(other_model)
def test_dynamo_disable_annotations(self):
class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("buffer", torch.rand(2, 2))
@torch._dynamo.disable()
def f1(self, x) -> torch.Tensor:
return x + self.buffer + 1
@torch._dynamo.disable()
def f2(self, x) -> torch.Tensor:
return x + self.buffer + 2
def forward(self, x) -> torch.Tensor:
return self.f1(x) + self.f2(x)
model = SimpleModel()
inp = torch.rand(2, 2)
with torch.fx.traceback.preserve_node_meta():
exported_model = torch.export.export(model, (inp,))
graph = exported_model.graph_module.graph
found_f1 = False
found_f2 = False
for node in graph.nodes:
if "custom" in node.meta:
if "_torchdynamo_disable_method" in node.meta["custom"]:
if node.meta["custom"]["_torchdynamo_disable_method"] == "f1":
found_f1 = True
elif node.meta["custom"]["_torchdynamo_disable_method"] == "f2":
found_f2 = True
self.assertTrue(found_f1)
self.assertTrue(found_f2)
model.forward = torch._dynamo.disable(model.forward, recursive=False)
with self.assertRaises(RuntimeError):
exported_model = torch.export.export(model, (inp,))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Callable, Sequence
from typing import Any, Union
from collections.abc import Sequence
from typing import Any, Callable, Union
import torch
import torch._dynamo

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import NamedTuple, Optional, TYPE_CHECKING
from typing import Callable, NamedTuple, Optional
import torch
import torch._dynamo
@ -7,10 +7,6 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
if TYPE_CHECKING:
from collections.abc import Callable
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo

View File

@ -742,11 +742,14 @@ class TestExport(TestCase):
self.assertExpectedInline(
str(custom_metadata),
"""\
('call_function', 'cat', {'moo': 0})
('call_function', 'item', {'moo': 0})
('call_function', 'ge_1', {'moo': 0})
('call_function', '_assert_scalar_default', {'moo': 0})
('call_function', 'mul', {'moo': 0})""",
('placeholder', 'x', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('placeholder', 'y', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('call_function', 'cat', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'item', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'ge_1', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', '_assert_scalar_default', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'mul', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('output', 'output', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})""",
)
@requires_gpu

View File

@ -5222,6 +5222,7 @@ xfail_by_backend = {
"test_reentrant_with_callbacks_both_depths", # queue_callback
"test_reentrant_with_callbacks_depth_0", # queue_callback
"test_reentrant_with_callbacks_depth_1", # queue_callback
"test_checkpoint_graph_execution_group", # Attempted to call function marked as skipped
"test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook
"test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd
"test_post_accumulate_grad_hook_ordering", # accuracy error

View File

@ -148,6 +148,24 @@ class FxirTestCase(InductorTestCase):
args = [torch.randn(8, device=self.device) for _ in range(2)]
self._compile_and_check(torch.add, args)
def test_device_type(self):
"""
Test that we allocate on a device type instead of a specific index.
"""
# Pass in a tensor on an indexed device.
device_runtime = getattr(torch, self.device)
indexed_device = torch.device(self.device, device_runtime.current_device())
args = [torch.randn(8, device=indexed_device) for _ in range(2)]
(gm,) = self._compile_and_check(torch.add, args)
(empty_strided,) = gm.graph.find_nodes(
op="call_function", target=torch.empty_strided
)
# Check that the device of the output allocation is not indexed.
output_device = torch.device(empty_strided.kwargs["device"])
self.assertIs(output_device.index, None)
self.assertEqual(output_device.type, indexed_device.type)
def test_multiple_kernels(self):
def foo(x, y):
return x.sum() + y.sum()

View File

@ -7364,6 +7364,62 @@ for shape in [(1,), ()]:
):
checkpoint_sequential(modules_list, 3, a)
@skipIfTorchDynamo("GraphExecGroup does not support compile")
def test_checkpoint_graph_execution_group(self):
def run(use_graph_execution_group):
counter = [0]
def fn(x):
counter[0] += 1
y = x.sin().cos()
z = y.sin().cos()
return y, z
x = torch.randn(3, 3, requires_grad=True)
y, z = checkpoint(fn, x, use_reentrant=False)
group = torch.utils.checkpoint.GraphExecGroup()
ctx = contextlib.nullcontext()
if use_graph_execution_group:
ctx = group
with ctx:
(grad_y,) = torch.autograd.grad(
z, inputs=(y,), grad_outputs=(torch.ones(3, 3),)
)
(grad_x,) = torch.autograd.grad(
y,
inputs=(x,),
grad_outputs=(grad_y,),
)
if use_graph_execution_group:
self.assertEqual(counter[0], 2)
else:
self.assertEqual(counter[0], 3)
run(use_graph_execution_group=True)
run(use_graph_execution_group=False)
# Test the not actually disjoint case (using retain_graph=True since
# otherwise autograd itself will catch this)
def fn(x):
return x.sin().cos()
x = torch.randn(3, 3, requires_grad=True)
out = checkpoint(fn, x, use_reentrant=False)
with torch.utils.checkpoint.GraphExecGroup():
# Under this context, we will enforce that two backward are disjoint
# even if retain_graph=True.
out.sum().backward(retain_graph=True)
with self.assertRaisesRegex(
RuntimeError, "Performing two backward calls that overlap"
):
out.sum().backward()
def test_checkpoint_detects_non_determinism(self):
def save_3_tensors(x):
out = x.sin().exp()

View File

@ -1281,7 +1281,7 @@ class TestSerialization(TestCase, SerializationMixin):
torch.save(p, f)
f.seek(0)
with self.assertRaisesRegex(pickle.UnpicklingError,
"GLOBAL __main__.Point was not an allowed global by default"):
f"GLOBAL {__name__}.Point was not an allowed global by default"):
torch.load(f, weights_only=True)
f.seek(0)
with torch.serialization.safe_globals([Point]):
@ -1300,7 +1300,7 @@ class TestSerialization(TestCase, SerializationMixin):
torch.save(c, f)
f.seek(0)
with self.assertRaisesRegex(pickle.UnpicklingError,
"GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
f"GLOBAL {__name__}.ClassThatUsesBuildInstruction was not an allowed global by default"):
torch.load(f, weights_only=True)
try:
with torch.serialization.safe_globals([ClassThatUsesBuildInstruction]):
@ -1330,7 +1330,7 @@ class TestSerialization(TestCase, SerializationMixin):
torch.save(obj, f)
f.seek(0)
with self.assertRaisesRegex(pickle.UnpicklingError,
f"GLOBAL __main__.{obj_cls.__name__} was not an allowed global by default"):
f"GLOBAL {__name__}.{obj_cls.__name__} was not an allowed global by default"):
torch.load(f, weights_only=True)
f.seek(0)
@ -4501,9 +4501,10 @@ class TestSerialization(TestCase, SerializationMixin):
# Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx
if not materialize_fake:
ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device))
exc = pickle.PicklingError if sys.version_info >= (3, 14) else AttributeError
with self.assertRaisesRegex(
AttributeError,
"Can't (get|pickle) local object 'WeakValueDictionary.__init__.<locals>.remove'"
exc,
"Can't (get|pickle) local object (<function |')WeakValueDictionary.__init__.<locals>.remove"
):
with skip_data(), BytesIOContext() as f:
torch.save(ft, f)

View File

@ -1,5 +1,5 @@
from typing import TypeAlias, Union
from typing_extensions import assert_type
from typing import Union
from typing_extensions import assert_type, TypeAlias
from torch import randn, Tensor

View File

@ -69,6 +69,7 @@ from torch.types import (
Storage,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.checkpoint import GraphExecGroup
# This module is defined in torch/csrc/Module.cpp
@ -1491,6 +1492,8 @@ def _is_multithreading_enabled() -> _bool: ...
def _set_multithreading_enabled(enabled: _bool) -> None: ...
def _set_view_replay_enabled(enabled: _bool) -> None: ...
def _is_view_replay_enabled() -> _bool: ...
def _set_graph_exec_group(group: GraphExecGroup | None) -> None: ...
def _get_graph_exec_group() -> GraphExecGroup | None: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...

View File

@ -1,9 +1,8 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from typing import Any, Optional, overload, Union
from typing import Any, Callable, Optional, overload, Union
import torch
from torch import Tensor

View File

@ -78,7 +78,7 @@ from torch.export.dynamic_shapes import (
_RelaxedConstraint,
Constraint,
)
from torch.fx import GraphModule
from torch.fx import GraphModule, traceback as fx_traceback
from torch.fx.experimental._dynamism import (
clone_and_convert_to_meta,
track_dynamism_across_examples,
@ -1134,6 +1134,17 @@ class DisableContext(_TorchDynamoContext):
try:
_maybe_set_eval_frame(_callback_from_stance(self.callback))
try:
if torch.compiler.is_exporting():
with fx_traceback.annotate(
{
"_torchdynamo_disable": True,
"_torchdynamo_disable_recursive": True,
"_torchdynamo_disable_method": getattr(
fn, "__name__", type(fn).__name__
),
}
):
return fn(*args, **kwargs)
return fn(*args, **kwargs)
finally:
set_eval_frame(None)

View File

@ -196,6 +196,10 @@ def get_nonrecursive_disable_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
# this function is in external_utils so that convert_frame doesn't skip it.
@functools.wraps(fn)
def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
if torch.compiler.is_exporting():
raise RuntimeError(
"Non-recursive torch.compiler.disable is not supported with torch.export."
)
return fn(*args, **kwargs)
return nonrecursive_disable_wrapper

View File

@ -2141,9 +2141,10 @@ class GuardBuilder(GuardBuilderBase):
original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1])
if hasattr(value, "__metadata_guard__"):
verify_guard_fn_signature(value)
cls = type(value)
def metadata_checker(x: Any) -> bool:
return value.__metadata_guard__(
return cls.__metadata_guard__(
original_metadata, x.__tensor_flatten__()[1]
)

View File

@ -1169,7 +1169,7 @@ class VariableBuilder:
f"{sym_expr} is not a basic Symbol."
)
self.tx.output.tracked_fakes.append(TrackedFake(node, source, None))
return SymNodeVariable(sym_node_proxy, node)
return SymNodeVariable.create(self.tx, sym_node_proxy, node)
elif is_torch_sym(value):
# Note: this doesn't handle nested symints.
# For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo.
@ -2454,7 +2454,7 @@ class VariableBuilder:
sym_expr = wrapped_value.node.expr
assert isinstance(sym_expr, sympy.Symbol), f"{sym_expr} is not a basic Symbol."
self.tx.output.root_tracer.bound_symbols[sym_expr] = proxy
unspec_var = SymNodeVariable(proxy, wrapped_value, **options)
unspec_var = SymNodeVariable.create(self.tx, proxy, wrapped_value, **options)
self.tx.output.unspec_variable_map[self.name] = unspec_var
if not is_constant_source(self.get_source()):
@ -3002,7 +3002,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
tx.output.current_tracer.track_produced_symints(example_value, proxy)
set_example_value(proxy.node, example_value)
return SymNodeVariable(proxy, example_value, **options)
return SymNodeVariable.create(tx, proxy, example_value, **options)
elif (
isinstance(example_value, torch.Stream)
and proxy.node.target is get_external_object_by_index

View File

@ -182,9 +182,9 @@ its type to `common_constant_types`.
if any(isinstance(x, SymNodeVariable) for x in args):
# Promote to SymNodeVariable for operations involving dynamic shapes.
return variables.SymNodeVariable(self.as_proxy(), self.value).call_method(
tx, name, args, kwargs
)
return variables.SymNodeVariable.create(
tx, self.as_proxy(), self.value
).call_method(tx, name, args, kwargs)
try:
const_args = [a.as_python_constant() for a in args]

View File

@ -21,9 +21,9 @@ restoring state changes.
import inspect
import sys
import warnings
from collections.abc import Callable, Sequence, Sized
from collections.abc import Callable, Sequence
from contextlib import ExitStack
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union
import torch._C
from torch._guards import Guard

View File

@ -247,7 +247,7 @@ def _make_inlined(tx: "InstructionTranslator", f):
def _call_function_and_unflatten_output(
tx, fn, args, kwargs, flat_example_value, ret_spec
tx, fn, args, kwargs, flat_example_value, ret_spec, body_r
):
from .builder import wrap_fx_proxy
@ -263,6 +263,21 @@ def _call_function_and_unflatten_output(
example_value=flat_example_value,
)
# wrap_fx_proxy creates fresh variable trackers. However, the main program
# after the speculate subgraph can still use the original tensor vts that
# are still pointing to the nodes present in the subgraph. So, we reproxify
# the original tensor vts with the subgraph outputs. This way, whenever the
# outer graph uses an original vt, it uses the subgraph output.
if body_r is not None:
for orig_vt, subgraph_vt in zip(body_r.items, flat_variable.items):
if isinstance(
orig_vt, (variables.SymNodeVariable, variables.TensorVariable)
):
assert isinstance(
subgraph_vt, (variables.SymNodeVariable, variables.TensorVariable)
)
orig_vt.proxy = subgraph_vt.proxy
if ret_spec.masks_to_filter_const_values:
from torch._dynamo.external_utils import insert_const_values_with_mask
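
The comment in the hunk above describes re-pointing the original variable trackers' proxies at the subgraph outputs. A toy sketch of that re-pointing pattern, using hypothetical stand-in classes rather than dynamo's real VariableTracker and Proxy types:

class Proxy:
    def __init__(self, name):
        self.name = name

class VT:
    # stand-in for a tensor/symnode variable tracker holding a proxy
    def __init__(self, proxy):
        self.proxy = proxy

# trackers created while speculating the subgraph body
body_vts = [VT(Proxy("body_sin")), VT(Proxy("body_cos"))]
# fresh trackers wrapping the subgraph's outputs in the outer graph
flat_vts = [VT(Proxy("subgraph_out_0")), VT(Proxy("subgraph_out_1"))]

for orig_vt, subgraph_vt in zip(body_vts, flat_vts):
    orig_vt.proxy = subgraph_vt.proxy  # the outer graph now reads the subgraph outputs

print([vt.proxy.name for vt in body_vts])  # ['subgraph_out_0', 'subgraph_out_1']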
@ -572,6 +587,7 @@ def _call_while_loop(
{},
None,
body_treespec,
body_r,
)
@ -1535,6 +1551,7 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
{},
None,
true_spec,
true_r,
)
@ -1858,6 +1875,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
{},
None,
OutputSpec(xs_treespec),
None,
)
@ -2090,7 +2108,13 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
return _call_function_and_unflatten_output(
tx, torch.ops.higher_order.scan, p_args, {}, None, _combine_spec
tx,
torch.ops.higher_order.scan,
p_args,
{},
None,
_combine_spec,
None,
)
@ -2213,7 +2237,7 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
return _call_function_and_unflatten_output(
tx, torch.ops.higher_order.map_impl, p_args, {}, None, body_spec
tx, torch.ops.higher_order.map_impl, p_args, {}, None, body_spec, body_r
)
@ -2419,7 +2443,13 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
return _call_function_and_unflatten_output(
tx, self.value, tuple(p_args), p_kwargs, flat_example_value, treespec
tx,
self.value,
tuple(p_args),
p_kwargs,
flat_example_value,
treespec,
body_r,
)
@ -2506,7 +2536,7 @@ class WrapWithSetGradEnabledHigherOrderVariable(TorchHigherOrderOperatorVariable
body_r.as_proxy(),
)
return _call_function_and_unflatten_output(
tx, self.value, proxy_args, {}, example_value, treespec
tx, self.value, proxy_args, {}, example_value, treespec, body_r
)
@ -2601,7 +2631,7 @@ class WrapWithAutocastHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
return _call_function_and_unflatten_output(
tx, self.value, proxy_args, {}, example_value, treespec
tx, self.value, proxy_args, {}, example_value, treespec, body_r
)
@ -2674,7 +2704,7 @@ class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
return _call_function_and_unflatten_output(
tx, self.value, p_args, p_kwargs, flat_example_value, treespec
tx, self.value, p_args, p_kwargs, flat_example_value, treespec, body_r
)
@ -2793,6 +2823,7 @@ class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable):
{},
flat_example_value,
ret_spec,
ret_val,
)
@ -2860,6 +2891,7 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
checkpoint_kwargs,
example_value,
out_spec,
_body_r,
)
@ -2913,6 +2945,7 @@ class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
{},
example_value,
out_spec,
_body_r,
)
@ -3652,7 +3685,13 @@ class BaseHOPVariable(WrapHigherOrderVariable):
p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
return _call_function_and_unflatten_output(
tx, self.value, p_args, p_kwargs, flat_example_value, treespec
tx,
self.value,
p_args,
p_kwargs,
flat_example_value,
treespec,
body_r,
)
@ -3768,6 +3807,7 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
p_kwargs,
flat_example_value,
treespec,
body_r,
)
@ -3991,7 +4031,7 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
# Step 5: Install local_map subgraph
p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
out = _call_function_and_unflatten_output(
tx, self.value, p_args, p_kwargs, flat_example_value, treespec
tx, self.value, p_args, p_kwargs, flat_example_value, treespec, body_r
)
# Step 6: Restore inputs and outputs to global shapes

View File

@ -1097,9 +1097,19 @@ def placeholder_naming_pass(
node.name = node.target = name_map[node.name]
if node.name in custom_meta:
if node.meta.get("custom") is None:
node.meta["custom"] = custom_meta[node.name]
node.meta["custom"] = {}
else:
assert node.meta["custom"] == custom_meta[node.name]
# Raise if any existing key has a different value
for k, v in node.meta["custom"].items():
if (
k in custom_meta[node.name]
and v != custom_meta[node.name][k]
):
raise AssertionError(
f"Mismatch in custom metadata for key {k}. Value in "
f"node.meta is {v} and value in custom_meta is {custom_meta[node.name][k]}."
)
node.meta["custom"].update(custom_meta[node.name])
# if the constant obj is an input, we also need to update meta["val"]
# because this is created before the placeholder naming pass
if isinstance(node.meta["val"], CustomObjArgument):
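
The new branch above merges incoming custom metadata into whatever is already on the node and raises if an existing key disagrees. The same merge-with-conflict-check pattern in isolation, as a hypothetical helper (not the exporter's own code):

def merge_custom_meta(existing, incoming):
    # refuse to silently overwrite an existing key with a different value
    for k, v in existing.items():
        if k in incoming and v != incoming[k]:
            raise AssertionError(
                f"Mismatch in custom metadata for key {k}. "
                f"Existing value is {v} and incoming value is {incoming[k]}."
            )
    merged = dict(existing)
    merged.update(incoming)
    return merged

print(merge_custom_meta({"_torchdynamo_disable": True}, {"other_key": 1}))
# {'_torchdynamo_disable': True, 'other_key': 1}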

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import hashlib
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
import sympy # noqa: TC002
@ -17,8 +17,6 @@ from .simd import SIMDKernel, SIMDScheduling
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ..ir import IRNode
from ..scheduler import BaseSchedulerNode

View File

@ -678,6 +678,7 @@ class FxConverter:
assert name not in V.graph.removed_buffers
device = buffer.get_device()
assert device
dtype = buffer.get_dtype()
shape = self._generate_sym_nodes(buffer.get_size())
stride = self._generate_sym_nodes(buffer.get_stride())
@ -685,7 +686,7 @@ class FxConverter:
node = self.gm.graph.call_function(
torch.empty_strided,
args=(shape, stride),
kwargs={"dtype": dtype, "device": device},
kwargs={"dtype": dtype, "device": device.type},
)
assert name
node.name = name
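
The key change in the second hunk is passing device.type (for example "cuda") rather than the indexed device (for example "cuda:0") into the generated torch.empty_strided call. A small sketch of the distinction using only the public torch API, falling back to CPU when CUDA is unavailable:

import torch

dev = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(dev, "->", dev.type)  # e.g. cuda:0 -> cuda

# allocating on the device *type* leaves index resolution to the current device,
# mirroring what the generated FX graph now emits
buf = torch.empty_strided((2, 3), (3, 1), dtype=torch.float32, device=dev.type)
print(buf.device)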

View File

@ -1,6 +1,6 @@
import os
from collections.abc import Callable
from functools import cache, partial
from typing import Callable
import torch
from torch._environment import is_fbcode

View File

@ -195,10 +195,12 @@ def get_new_attr_name_with_prefix(prefix: str) -> Callable:
def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
r"""Starting from a target node, trace back until we hit input or
getattr node. This is used to extract the chain of operators
starting from getattr to the target node, for example
def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)
starting from getattr to the target node, for example::
def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)
collect_producer_nodes(observed) will either return a list of nodes that
produce the observed node, or None if we can't extract a self-contained
graph without free variables (inputs of the forward function).

View File

@ -52,7 +52,26 @@ __all__ = [
"MemRecordsAcc",
]
from contextlib import ContextDecorator
try:
# Available in Python >= 3.2
from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
import functools
class _ContextDecorator: # type: ignore[no-redef]
def __enter__(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError
def __call__(self, func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
with self:
return func(*args, **kwargs)
return wrapped
# global python state - whether profiler is currently enabled
@ -725,7 +744,8 @@ class profile:
return all_function_events
class record_function(ContextDecorator):
# pyrefly: ignore [invalid-inheritance]
class record_function(_ContextDecorator):
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
Label will only appear if CPU activity tracing is enabled.

View File

@ -1218,6 +1218,33 @@ static PyObject* is_view_replay_enabled(PyObject* self, PyObject* args) {
END_HANDLE_TH_ERRORS
}
static PyObject* set_graph_exec_group(PyObject* self, PyObject* obj) {
HANDLE_TH_ERRORS
if (obj == Py_None) {
c10::AutogradState::get_tls_state().set_graph_exec_group(std::nullopt);
} else {
Py_INCREF(obj);
c10::AutogradState::get_tls_state().set_graph_exec_group(
c10::SafePyObject(obj, getPyInterpreter()));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* get_graph_exec_group(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
const auto& group =
c10::AutogradState::get_tls_state().get_graph_exec_group();
if (group.has_value()) {
PyObject* obj = group->ptr(getPyInterpreter());
Py_INCREF(obj);
return obj;
} else {
Py_RETURN_NONE;
}
END_HANDLE_TH_ERRORS
}
static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (c10::InferenceMode::is_enabled()) {
@ -1598,6 +1625,8 @@ static PyMethodDef methods[] = {
castPyCFunctionWithKeywords(set_view_replay_enabled),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_set_graph_exec_group", set_graph_exec_group, METH_O, nullptr},
{"_get_graph_exec_group", get_graph_exec_group, METH_NOARGS, nullptr},
{"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
{"_exit_dual_level",
castPyCFunctionWithKeywords(python_exit_dual_level),

View File

@ -4,12 +4,65 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/stable/library.h>
#include <torch/library.h>
#include <torch/csrc/shim_conversion_utils.h>
#include <torch/csrc/stable/c/shim.h>
AOTITorchError torch_new_list_reserve_size(size_t size, StableListHandle* ret) {
auto list_ptr = std::make_unique<std::vector<StableIValue>>();
list_ptr->reserve(size);
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *ret = list_pointer_to_list_handle(list_ptr.release()); });
}
AOTI_TORCH_EXPORT AOTITorchError
torch_list_size(StableListHandle list_handle, size_t* size) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
std::vector<StableIValue>* list = list_handle_to_list_pointer(list_handle);
*size = list->size();
});
}
AOTI_TORCH_EXPORT AOTITorchError torch_list_get_item(
StableListHandle list_handle,
size_t index,
StableIValue* element) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
std::vector<StableIValue>* list = list_handle_to_list_pointer(list_handle);
*element = list->at(index);
});
}
AOTI_TORCH_EXPORT AOTITorchError torch_list_set_item(
StableListHandle list_handle,
size_t index,
StableIValue element) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
std::vector<StableIValue>* list = list_handle_to_list_pointer(list_handle);
list->at(index) = element;
});
}
AOTITorchError torch_list_push_back(
StableListHandle list_handle,
StableIValue element) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
std::vector<StableIValue>* list = list_handle_to_list_pointer(list_handle);
list->push_back(element);
});
}
AOTI_TORCH_EXPORT AOTITorchError
torch_delete_list(StableListHandle list_handle) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
std::vector<StableIValue>* list_ptr =
list_handle_to_list_pointer(list_handle);
delete list_ptr;
});
}
static StableIValue from_ivalue(
const c10::TypePtr& type,
const c10::IValue& ivalue,
@ -71,6 +124,19 @@ static StableIValue from_ivalue(
from_ivalue(inner_type, ivalue, extension_build_version));
return torch::stable::detail::_from(sivp, extension_build_version);
}
case c10::TypeKind::ListType: {
auto inner_type = type->castRaw<c10::ListType>()->getElementType();
auto ivalue_list = ivalue.toList();
auto stableivalue_list = std::make_unique<std::vector<StableIValue>>();
stableivalue_list->reserve(ivalue_list.size());
for (const auto& elem : ivalue_list) {
stableivalue_list->emplace_back(
from_ivalue(inner_type, elem, extension_build_version));
}
return torch::stable::detail::_from(
list_pointer_to_list_handle(stableivalue_list.release()),
extension_build_version);
}
default: {
TORCH_CHECK(
false,
@ -145,6 +211,21 @@ static c10::IValue to_ivalue(
delete sivp;
return ival;
}
case c10::TypeKind::ListType: {
auto inner_type = type->castRaw<c10::ListType>()->getElementType();
auto list_handle = torch::stable::detail::_to<StableListHandle>(
stable_ivalue, extension_build_version);
std::vector<StableIValue>* stableivalue_list =
list_handle_to_list_pointer(list_handle);
auto ivalue_list = c10::impl::GenericList(inner_type);
ivalue_list.reserve(stableivalue_list->size());
for (const auto& elem : *stableivalue_list) {
ivalue_list.emplace_back(
to_ivalue(inner_type, elem, extension_build_version));
}
TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle));
return ivalue_list;
}
default: {
TORCH_CHECK(
false,

View File

@ -0,0 +1,22 @@
#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <vector>
inline std::vector<StableIValue>* list_handle_to_list_pointer(
StableListHandle handle) {
return reinterpret_cast<std::vector<StableIValue>*>(handle);
}
inline StableListHandle list_pointer_to_list_handle(
std::vector<StableIValue>* list_ptr) {
return reinterpret_cast<StableListHandle>(list_ptr);
}
inline StableListHandle new_list_handle(std::vector<StableIValue>&& list) {
std::vector<StableIValue>* new_list = new std::vector<StableIValue>(list);
return list_pointer_to_list_handle(new_list);
}

View File

@ -37,6 +37,34 @@ AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
void (*fn)(StableIValue*, uint64_t, uint64_t),
uint64_t extension_build_version);
struct StableListOpaque;
using StableListHandle = StableListOpaque*;
// returns an owning reference of a StableList. callee is responsible for
// freeing memory.
AOTI_TORCH_EXPORT AOTITorchError
torch_new_list_reserve_size(size_t size, StableListHandle* ret);
AOTI_TORCH_EXPORT AOTITorchError
torch_list_size(StableListHandle list_handle, size_t* size);
AOTI_TORCH_EXPORT AOTITorchError torch_list_get_item(
StableListHandle list_handle,
size_t index,
StableIValue* element);
AOTI_TORCH_EXPORT AOTITorchError torch_list_set_item(
StableListHandle list_handle,
size_t index,
StableIValue element);
AOTI_TORCH_EXPORT AOTITorchError
torch_list_push_back(StableListHandle list_handle, StableIValue element);
// deletes the underlying list referenced by list_handle
AOTI_TORCH_EXPORT AOTITorchError
torch_delete_list(StableListHandle list_handle);
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
#ifdef __cplusplus

View File

@ -2,6 +2,7 @@
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/tensor_struct.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
@ -192,6 +193,46 @@ struct FromImpl<torch::stable::Tensor> {
}
};
// Specialization for torch::headeronly::HeaderOnlyArrayRef<T> => StableIValue
// Returns a new owning reference of the underlying list.
template <typename T>
struct FromImpl<torch::headeronly::HeaderOnlyArrayRef<T>> {
static StableIValue call(
const torch::headeronly::HeaderOnlyArrayRef<T>& val,
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
StableListHandle new_list_handle;
try {
TORCH_ERROR_CODE_CHECK(
torch_new_list_reserve_size(val.size(), &new_list_handle));
for (const auto& elem : val) {
TORCH_ERROR_CODE_CHECK(
torch_list_push_back(new_list_handle, from(elem)));
}
return from(new_list_handle);
} catch (const std::runtime_error& e) {
if (new_list_handle != nullptr) {
// clean up memory if an error was thrown
TORCH_ERROR_CODE_CHECK(torch_delete_list(new_list_handle));
}
throw;
}
}
};
// Specialization for std::vector<T> => StableIValue, which is implemented the
// same way as HeaderOnlyArrayRef<T> => StableIValue
// Returns a new owning reference of the underlying list.
template <typename T>
struct FromImpl<std::vector<T>> {
static StableIValue call(
const std::vector<T>& val,
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
return from<torch::headeronly::HeaderOnlyArrayRef<T>>(val);
}
};
// =============================================================================
// TO CONVERSIONS (StableIValue -> T)
// =============================================================================
@ -342,6 +383,38 @@ struct ToImpl<torch::stable::Tensor> {
}
};
// Specialization for StableIValue => std::vector<T>
// std::vector<T> should be represented as a StableListHandle
// filled with StableIValues
// The new std::vector steals ownership of the underlying elements
// and we free the underlying list referred to by the input StableListHandle.
template <typename T>
struct ToImpl<std::vector<T>> {
static std::vector<T> call(
StableIValue val,
[[maybe_unused]] uint64_t extension_build_version,
[[maybe_unused]] bool is_internal) {
auto list_handle = to<StableListHandle>(val);
size_t size;
try {
TORCH_ERROR_CODE_CHECK(torch_list_size(list_handle, &size));
std::vector<T> result;
result.reserve(size);
for (size_t i = 0; i < size; i++) {
StableIValue element;
TORCH_ERROR_CODE_CHECK(torch_list_get_item(list_handle, i, &element));
result.push_back(to<T>(element));
}
TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle));
return result;
} catch (const std::runtime_error& e) {
// clean up memory if an exception is thrown, and rethrow
TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle));
throw;
}
}
};
// =============================================================================
// end to helpers for converting between StableIValue and T
// =============================================================================

View File

@ -1,8 +1,9 @@
import functools
import math
import operator
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from datetime import timedelta
from typing import Callable
import torch
from torch._C import ScriptObject

View File

@ -441,15 +441,11 @@ def vector_norm_strategy(op_schema: OpSchema) -> OpStrategy:
keepdim = args_schema[3] if len(args_schema) > 3 else False
dims = _infer_reduction_dims(dim, input_strategy.ndim)
reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims
reduction_linear = all(
all(not p.is_partial() for p in op_spec.output_spec.placements)
for op_spec in input_strategy.strategies
)
return common_reduction_strategy(
input_strategy,
reduce_dims,
keep_dim=cast(bool, keepdim),
reduction_linear=reduction_linear,
reduction_linear=True,
reduction_op=NormReduction(norm_type),
)
@ -472,14 +468,10 @@ def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy:
if not isinstance(op_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(op_strategy)}")
reduce_dims = list(range(op_strategy.ndim))
reduction_linear = all(
all(not p.is_partial() for p in op_spec.output_spec.placements)
for op_spec in op_strategy.strategies
)
output_strategy = common_reduction_strategy(
op_strategy,
reduce_dims,
reduction_linear=reduction_linear,
reduction_linear=True,
reduction_op=NormReduction(norm_type),
)
output_tuple_strategy_children.append(output_strategy)

View File

@ -3,6 +3,7 @@
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>

View File

@ -33,6 +33,7 @@ __all__ = [
"SelectiveCheckpointContext",
"create_selective_checkpoint_contexts",
"SAC_IGNORED_OPS",
"GraphExecGroup",
]
_DEFAULT_DETERMINISM_MODE = "default"
@ -1072,7 +1073,7 @@ class _StopRecomputationError(Exception):
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
def __init__(self, target_frame_ref: ReferenceType, gid: int) -> None:
def __init__(self, target_frame_ref: ReferenceType, gid: Union["GraphExecGroup", int]) -> None:
def pack_hook(x):
x = x.detach() if x.requires_grad else x
target_frame = target_frame_ref()
@ -1145,10 +1146,14 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
return holder
def unpack_hook(holder):
gid = torch._C._current_graph_task_id()
if gid == -1:
# generate a temporary id if we trigger unpack outside of a backward call
gid = int(uuid.uuid4())
# First check if we're inside a GraphExecGroup context
gid: Union[GraphExecGroup, None, int] = GraphExecGroup._get_current_group()
if gid is None:
# Fallback to using the current graph task id
gid = torch._C._current_graph_task_id()
if gid == -1:
# generate a temporary id if we trigger unpack outside of a backward call
gid = int(uuid.uuid4())
if not frame.is_recomputed[gid]:
ctx = frame.input_saver.grad_fn
@ -1168,10 +1173,17 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
_internal_assert(gid in holder.handles)
if holder.handles[gid] is None:
extra = ""
if torch._C._get_graph_exec_group() is not None:
extra = (
"Performing two backward calls that overlap (i.e. require the same "
"saved activation in order to compute gradients) is not allowed while "
"under the torch.utils.checkpoint.GraphExecGroup context. "
)
raise CheckpointError(
"torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
"unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
"so only once. Otherwise please open an issue with details on your use case."
f"unpacked once. {extra}If you are calling ctx.saved_tensors in backward, make sure "
"to do so only once. Otherwise please open an issue with details on your use case."
)
_internal_assert(holder.handles[gid] in frame.recomputed[gid])
ret = frame.recomputed[gid][holder.handles[gid]]
@ -1594,6 +1606,40 @@ def _checkpoint_without_reentrant_generator(
return
class GraphExecGroup:
"""Any checkpointed regions encountered by backward under the same instance
of this context manager will trigger recompute at most once, even if
there are multiple calls to backward.
Backward calls under the same instance of this context manager must execute
over non-overlapping regions of the backward graph even if retain_graph=True.
In particular, no two backward calls may use the same saved activation for
gradient computation.
.. note::
This context manager only affects checkpoint with use_reentrant=False, and
is a no-op otherwise.
"""
def __enter__(self) -> "GraphExecGroup":
if torch._C._get_graph_exec_group() is not None:
raise RuntimeError(
"GraphExecGroup contexts cannot be nested. "
f"Already inside group {torch._C._get_graph_exec_group()}"
)
torch._C._set_graph_exec_group(self)
return self
def __exit__(self, *args: object) -> None:
torch._C._set_graph_exec_group(None)
@classmethod
def _get_current_group(cls) -> Optional["GraphExecGroup"]:
# Private API to be used by utils like AC
return torch._C._get_graph_exec_group()
# Note: [compiled autograd and checkpoint unpack hook]
# When tracing via compiled autograd, this hook will be visible to the
# compiler if the forward of this checkpointed region ran in eager.
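
A minimal sketch exercising just the context-manager mechanics added in this diff (the thread-local getter and setter plus the nesting check); it does not try to demonstrate the recompute-at-most-once behaviour, which requires a backward pass split into non-overlapping regions:

import torch
from torch.utils.checkpoint import GraphExecGroup

with GraphExecGroup() as group:
    # the active group is visible through the new accessors
    assert torch._C._get_graph_exec_group() is group
    try:
        with GraphExecGroup():  # nesting is rejected
            pass
    except RuntimeError as err:
        print(err)

assert torch._C._get_graph_exec_group() is None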