Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-12 14:54:55 +08:00)
Compare commits: viable/str...docs (18 commits)
| SHA1 |
|---|
| c93b821875 |
| 4957ae5838 |
| 31d6d3ef5c |
| 2325c511e7 |
| d865156967 |
| fbc0bd2e90 |
| 70f5f55abf |
| 69ecb562e7 |
| 5062abe4e7 |
| c7007e7584 |
| 09705ca9b2 |
| ea6b0b5d0f |
| bbf852d87f |
| 6392b986e7 |
| 32d30d96cf |
| 46516efa85 |
| 84b2147b85 |
| 1727a71cb6 |
@@ -36,11 +36,7 @@ case ${DOCKER_TAG_PREFIX} in
        ;;
    rocm*)
        BASE_TARGET=rocm
-       PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
-       # add gfx950, gfx115x conditionally starting in ROCm 7.0
-       if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
-           PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
-       fi
+       PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
        EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
        ;;
    *)

@@ -260,6 +260,12 @@ case "$tag" in
        HALIDE=yes
        TRITON=yes
        ;;
+   pytorch-linux-jammy-cuda13.0-py3.12-pallas)
+       CUDA_VERSION=13.0.0
+       ANACONDA_PYTHON_VERSION=3.12
+       GCC_VERSION=11
+       PALLAS=yes
+       ;;
    pytorch-linux-jammy-py3.12-triton-cpu)
        CUDA_VERSION=12.6
        ANACONDA_PYTHON_VERSION=3.12
@@ -381,6 +387,7 @@ docker build \
    --build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
    --build-arg "EXECUTORCH=${EXECUTORCH}" \
    --build-arg "HALIDE=${HALIDE}" \
+   --build-arg "PALLAS=${PALLAS}" \
    --build-arg "XPU_VERSION=${XPU_VERSION}" \
    --build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
    --build-arg "ACL=${ACL:-}" \

.ci/docker/ci_commit_pins/jax.txt (new file, +1)
@@ -0,0 +1 @@
+0.8.0
.ci/docker/common/install_jax.sh (new executable file, +40)
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+set -ex
+
+source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
+
+# Get the pinned JAX version (same for all CUDA versions)
+JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
+
+function install_jax_12() {
+  echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
+  pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+  # Verify installation
+  python -c "import jax"  # check for errors
+  echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
+}
+
+function install_jax_13() {
+  echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
+  pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+  # Verify installation
+  python -c "import jax"  # check for errors
+  echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
+}
+
+# idiomatic parameter and option handling in sh
+while test $# -gt 0
+do
+  case "$1" in
+    12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
+      ;;
+    13.0|13.0.*) install_jax_13;
+      ;;
+    *) echo "bad argument $1"; exit 1
+      ;;
+  esac
+  shift
+done
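Note (not part of the diff): the script's own smoke test is the `python -c "import jax"` line above. A slightly fuller verification that could be run inside the resulting image looks like the sketch below; `jax.__version__` and `jax.devices()` are standard JAX APIs, and the device list depends on the CUDA runtime actually present.

```python
# Illustrative check: confirm the pinned JAX build imports and sees devices.
import jax

print(jax.__version__)  # expected to match the 0.8.0 pin above
print(jax.devices())    # lists CUDA devices when the GPU plugin is installed
```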
@@ -49,11 +49,7 @@ case ${DOCKER_TAG_PREFIX} in
        fi
        BASE_TARGET=rocm
        GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
-       PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
-       # add gfx950, gfx115x conditionally starting in ROCm 7.0
-       if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
-           PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
-       fi
+       PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
        DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}"
        ;;
    *)

@@ -87,11 +87,7 @@ case ${image} in
        MANY_LINUX_VERSION="2_28"
        DEVTOOLSET_VERSION="11"
        GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
-       PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
-       # add gfx950, gfx115x conditionally starting in ROCm 7.0
-       if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
-           PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
-       fi
+       PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
        DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
        ;;
    manylinux2_28-builder:xpu)

@@ -143,6 +143,15 @@ COPY ci_commit_pins/halide.txt halide.txt
RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
RUN rm install_halide.sh common_utils.sh halide.txt

+ARG PALLAS
+ARG CUDA_VERSION
+# Install JAX with CUDA support (for Pallas)
+COPY ./common/install_jax.sh install_jax.sh
+COPY ./common/common_utils.sh common_utils.sh
+COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
+RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
+RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
+
ARG ONNX
# Install ONNX dependencies
COPY ./common/install_onnx.sh ./common/common_utils.sh ./

@@ -824,6 +824,11 @@ test_inductor_halide() {
  assert_git_not_dirty
}

+test_inductor_pallas() {
+  python test/run_test.py --include inductor/test_pallas.py --verbose
+  assert_git_not_dirty
+}
+
test_inductor_triton_cpu() {
  python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
  assert_git_not_dirty
@@ -1724,6 +1729,8 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
  test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
  test_inductor_halide
+elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
+  test_inductor_pallas
elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
  test_inductor_triton_cpu
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then

.github/nitpicks.yml (vendored, +1)
@@ -10,3 +10,4 @@
  pathFilter:
    - 'torch/csrc/inductor/aoti_torch/c/*'
    - 'torch/csrc/inductor/aoti_torch/generated/*'
+   - 'torch/csrc/stable/c/*'

.github/workflows/docker-builds.yml (vendored, +1)
@@ -67,6 +67,7 @@ jobs:
          pytorch-linux-jammy-py3.10-gcc11,
          pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
          pytorch-linux-jammy-py3.12-halide,
+         pytorch-linux-jammy-cuda13.0-py3.12-pallas,
          pytorch-linux-jammy-xpu-n-1-py3,
          pytorch-linux-noble-xpu-n-py3,
          pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,

.github/workflows/inductor-unittest.yml (vendored, +26)
@@ -81,6 +81,32 @@ jobs:
      test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
    secrets: inherit

+  inductor-pallas-build:
+    name: inductor-pallas-build
+    uses: ./.github/workflows/_linux-build.yml
+    needs: get-label-type
+    with:
+      build-environment: linux-jammy-py3.12-gcc11
+      docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas
+      cuda-arch-list: '8.9'
+      runner: linux.8xlarge.memory
+      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+      test-matrix: |
+        { include: [
+          { config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
+        ]}
+    secrets: inherit
+
+  inductor-pallas-test:
+    name: inductor-pallas-test
+    uses: ./.github/workflows/_linux-test.yml
+    needs: inductor-pallas-build
+    with:
+      build-environment: linux-jammy-py3.12-gcc11
+      docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
+      test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
+    secrets: inherit
+
  inductor-triton-cpu-build:
    name: inductor-triton-cpu-build
    uses: ./.github/workflows/_linux-build.yml

@@ -1,6 +1,8 @@
#pragma once

+#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>
+#include <optional>

namespace c10 {

@@ -15,7 +17,8 @@ struct C10_API AutogradState {
      bool inference_mode,
      bool fw_grad_mode,
      bool multithreading_enabled)
-     : grad_mode_(grad_mode),
+     : graph_exec_group_(std::nullopt),
+       grad_mode_(grad_mode),
        inference_mode_(inference_mode),
        fw_grad_mode_(fw_grad_mode),
        multithreading_enabled_(multithreading_enabled),
@@ -41,6 +44,10 @@ struct C10_API AutogradState {
    view_replay_enabled_ = view_replay_enabled;
  }

+ void set_graph_exec_group(std::optional<SafePyObject> group) {
+   graph_exec_group_ = std::move(group);
+ }
+
  bool get_grad_mode() const {
    return grad_mode_;
  }
@@ -61,7 +68,12 @@ struct C10_API AutogradState {
    return view_replay_enabled_;
  }

+ const std::optional<SafePyObject>& get_graph_exec_group() const {
+   return graph_exec_group_;
+ }
+
 private:
+ std::optional<SafePyObject> graph_exec_group_;
  bool grad_mode_ : 1;
  bool inference_mode_ : 1;
  bool fw_grad_mode_ : 1;

@@ -382,20 +382,6 @@ coverage_ignore_functions = [
    # torch.ao.quantization.backend_config.tensorrt
    "get_tensorrt_backend_config",
    "get_tensorrt_backend_config_dict",
-   # torch.ao.quantization.backend_config.utils
-   "entry_to_pretty_str",
-   "get_fused_module_classes",
-   "get_fuser_method_mapping",
-   "get_fusion_pattern_to_extra_inputs_getter",
-   "get_fusion_pattern_to_root_node_getter",
-   "get_module_to_qat_module",
-   "get_pattern_to_dtype_configs",
-   "get_pattern_to_input_type_to_index",
-   "get_qat_module_classes",
-   "get_root_module_to_quantized_reference_module",
-   "pattern_to_human_readable",
-   "remove_boolean_dispatch_from_name",
    # torch.ao.quantization.backend_config.x86
    "get_x86_backend_config",
    # torch.ao.quantization.fuse_modules
    "fuse_known_modules",
@@ -426,25 +412,6 @@ coverage_ignore_functions = [
    "insert_observers_for_model",
    "prepare",
    "propagate_dtypes_for_known_nodes",
-   # torch.ao.quantization.fx.utils
-   "all_node_args_except_first",
-   "all_node_args_have_no_tensors",
-   "assert_and_get_unique_device",
-   "collect_producer_nodes",
-   "create_getattr_from_value",
-   "create_node_from_old_node_preserve_meta",
-   "get_custom_module_class_keys",
-   "get_linear_prepack_op_for_dtype",
-   "get_new_attr_name_with_prefix",
-   "get_non_observable_arg_indexes_and_types",
-   "get_qconv_prepack_op",
-   "get_skipped_module_name_and_classes",
-   "graph_module_from_producer_nodes",
-   "maybe_get_next_module",
-   "node_arg_is_bias",
-   "node_arg_is_weight",
-   "return_arg_list",
    # torch.ao.quantization.pt2e.graph_utils
    "bfs_trace_with_node_process",
    "find_sequential_partitions",
    "get_equivalent_types",
@@ -860,80 +827,10 @@ coverage_ignore_functions = [
    "get_latency_of_one_partition",
    "get_latency_of_partitioned_graph",
    "get_partition_to_latency_mapping",
-   # torch.fx.experimental.proxy_tensor
-   "decompose",
-   "disable_autocast_cache",
-   "disable_proxy_modes_tracing",
-   "dispatch_trace",
-   "extract_val",
-   "fake_signature",
-   "fetch_sym_proxy",
-   "fetch_object_proxy",
-   "get_innermost_proxy_mode",
-   "get_isolated_graphmodule",
-   "get_proxy_slot",
-   "get_torch_dispatch_modes",
-   "has_proxy_slot",
-   "is_sym_node",
-   "maybe_handle_decomp",
-   "proxy_call",
-   "set_meta",
-   "set_original_aten_op",
-   "set_proxy_slot",
-   "snapshot_fake",
-   "thunkify",
-   "track_tensor",
-   "track_tensor_tree",
-   "wrap_key",
-   "wrapper_and_args_for_make_fx",
-   # torch.fx.experimental.recording
-   "record_shapeenv_event",
-   "replay_shape_env_events",
-   "shape_env_check_state_equal",
-   # torch.fx.experimental.sym_node
-   "ceil_impl",
-   "floor_ceil_helper",
-   "floor_impl",
-   "method_to_operator",
-   "sympy_is_channels_last_contiguous_2d",
-   "sympy_is_channels_last_contiguous_3d",
-   "sympy_is_channels_last_strides_2d",
-   "sympy_is_channels_last_strides_3d",
-   "sympy_is_channels_last_strides_generic",
-   "sympy_is_contiguous",
-   "sympy_is_contiguous_generic",
-   "to_node",
-   "wrap_node",
-   "sym_sqrt",
    # torch.fx.experimental.symbolic_shapes
    "bind_symbols",
-   "cast_symbool_to_symint_guardless",
-   "create_contiguous",
-   "error",
-   "eval_guards",
-   "eval_is_non_overlapping_and_dense",
    "expect_true",
-   "find_symbol_binding_fx_nodes",
-   "free_symbols",
-   "free_unbacked_symbols",
-   "fx_placeholder_targets",
-   "fx_placeholder_vals",
-   "guard_bool",
-   "guard_float",
-   "guard_int",
-   "guard_scalar",
-   "has_hint",
-   "has_symbolic_sizes_strides",
-   "is_channels_last_contiguous_2d",
-   "is_channels_last_contiguous_3d",
-   "is_channels_last_strides_2d",
-   "is_channels_last_strides_3d",
-   "is_contiguous",
-   "is_non_overlapping_and_dense_indicator",
-   "is_nested_int",
-   "is_symbol_binding_fx_node",
-   "is_symbolic",
    # torch.fx.experimental.unification.core
    "reify",
    # torch.fx.experimental.unification.match
    "edge",
@@ -971,24 +868,6 @@ coverage_ignore_functions = [
    "reverse_dict",
    # torch.fx.experimental.unification.multipledispatch.variadic
    "isvariadic",
-   # torch.fx.experimental.unification.unification_tools
-   "assoc",
-   "assoc_in",
-   "dissoc",
-   "first",
-   "get_in",
-   "getter",
-   "groupby",
-   "itemfilter",
-   "itemmap",
-   "keyfilter",
-   "keymap",
-   "merge",
-   "merge_with",
-   "update_in",
-   "valfilter",
-   "valmap",
    # torch.fx.experimental.unification.utils
    "freeze",
    "hashable",
    "raises",

@@ -12,6 +12,37 @@ These APIs are experimental and subject to change without notice.
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
```

+## torch.fx.experimental.sym_node
+
+```{eval-rst}
+.. currentmodule:: torch.fx.experimental.sym_node
+```
+
+```{eval-rst}
+.. automodule:: torch.fx.experimental.sym_node
+```
+
+```{eval-rst}
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    is_channels_last_contiguous_2d
+    is_channels_last_contiguous_3d
+    is_channels_last_strides_2d
+    is_channels_last_strides_3d
+    is_contiguous
+    is_non_overlapping_and_dense_indicator
+    method_to_operator
+    sympy_is_channels_last_contiguous_2d
+    sympy_is_channels_last_contiguous_3d
+    sympy_is_channels_last_strides_2d
+    sympy_is_channels_last_strides_3d
+    sympy_is_channels_last_strides_generic
+    sympy_is_contiguous
+    sympy_is_contiguous_generic
+```
+
## torch.fx.experimental.symbolic_shapes

```{eval-rst}
@@ -69,6 +100,25 @@ These APIs are experimental and subject to change without notice.
    rebind_unbacked
    resolve_unbacked_bindings
    is_accessor_node
+   cast_symbool_to_symint_guardless
+   create_contiguous
+   error
+   eval_guards
+   eval_is_non_overlapping_and_dense
+   find_symbol_binding_fx_nodes
+   free_symbols
+   free_unbacked_symbols
+   fx_placeholder_targets
+   fx_placeholder_vals
+   guard_bool
+   guard_float
+   guard_int
+   guard_scalar
+   has_hint
+   has_symbolic_sizes_strides
+   is_nested_int
+   is_symbol_binding_fx_node
+   is_symbolic
```

## torch.fx.experimental.proxy_tensor
@@ -91,4 +141,46 @@ These APIs are experimental and subject to change without notice.
    get_proxy_mode
    maybe_enable_thunkify
    maybe_disable_thunkify
+   decompose
+   disable_autocast_cache
+   disable_proxy_modes_tracing
+   extract_val
+   fake_signature
+   fetch_object_proxy
+   fetch_sym_proxy
+   has_proxy_slot
+   is_sym_node
+   maybe_handle_decomp
+   proxy_call
+   set_meta
+   set_original_aten_op
+   set_proxy_slot
+   snapshot_fake
```

+## torch.fx.experimental.unification.unification_tools
+
+```{eval-rst}
+.. currentmodule:: torch.fx.experimental.unification.unification_tools
+```
+
+```{eval-rst}
+.. automodule:: torch.fx.experimental.unification.unification_tools
+```
+
+```{eval-rst}
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    assoc
+    assoc_in
+    dissoc
+    first
+    keyfilter
+    keymap
+    merge
+    merge_with
+    update_in
+    valfilter
+    valmap

@@ -1134,7 +1134,6 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.refinement_types
.. py:module:: torch.fx.experimental.rewriter
.. py:module:: torch.fx.experimental.schema_type_annotation
-.. py:module:: torch.fx.experimental.sym_node
.. py:module:: torch.fx.experimental.unification.core
.. py:module:: torch.fx.experimental.unification.dispatch
.. py:module:: torch.fx.experimental.unification.match
@@ -1144,7 +1143,6 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
-.. py:module:: torch.fx.experimental.unification.unification_tools
.. py:module:: torch.fx.experimental.unification.utils
.. py:module:: torch.fx.experimental.unification.variable
.. py:module:: torch.fx.experimental.unify_refinements

@@ -134,6 +134,23 @@ Quantization to work with this as well.
    ObservationType
```

+## torch.ao.quantization.backend_config.utils
+```{eval-rst}
+.. currentmodule:: torch.ao.quantization.backend_config.utils
+```
+
+```{eval-rst}
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classtemplate.rst
+
+    entry_to_pretty_str
+    pattern_to_human_readable
+    remove_boolean_dispatch_from_name
+
+```
+
## torch.ao.quantization.fx.custom_config

This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization
@@ -154,6 +171,30 @@ This module contains a few CustomConfig classes that's used in both eager mode a
    StandaloneModuleConfigEntry
```

+## torch.ao.quantization.fx.utils
+
+```{eval-rst}
+.. currentmodule:: torch.ao.quantization.fx.utils
+```
+
+```{eval-rst}
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classtemplate.rst
+
+    all_node_args_except_first
+    all_node_args_have_no_tensors
+    collect_producer_nodes
+    create_getattr_from_value
+    create_node_from_old_node_preserve_meta
+    graph_module_from_producer_nodes
+    maybe_get_next_module
+    node_arg_is_bias
+    node_arg_is_weight
+    return_arg_list
+```
+
## torch.ao.quantization.quantizer

```{eval-rst}
@@ -67,13 +67,13 @@ Tensor sgd_out_of_place(

void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  Tensor res = sgd_out_of_place(
-     to<Tensor>(stack[0]),
-     to<Tensor>(stack[1]),
-     float(to<double>(stack[2])),
-     to<double>(stack[3]),
-     to<bool>(stack[4]));
+     torch::stable::detail::to<Tensor>(stack[0]),
+     torch::stable::detail::to<Tensor>(stack[1]),
+     float(torch::stable::detail::to<double>(stack[2])),
+     torch::stable::detail::to<double>(stack[3]),
+     torch::stable::detail::to<bool>(stack[4]));

- stack[0] = from(res);
+ stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
@@ -89,8 +89,8 @@ Tensor identity(Tensor t) {
}

void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- Tensor res = identity(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ Tensor res = identity(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -108,14 +108,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
Tensor my_abs(Tensor t) {
  const auto num_args = 1;
  StableIValue stack[num_args];
- stack[0] = from(t);
+ stack[0] = torch::stable::detail::from(t);
  aoti_torch_call_dispatcher("aten::abs", "", stack);
- return to<Tensor>(stack[0]);
+ return torch::stable::detail::to<Tensor>(stack[0]);
}

void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- Tensor tensor_res = my_abs(to<Tensor>(stack[0]));
- stack[0] = from(tensor_res);
+ Tensor tensor_res = my_abs(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(tensor_res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -132,21 +132,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {

  auto mf = aoti_torch_memory_format_contiguous_format();

- stack[0] = from(t);
- stack[1] = from(std::optional(t.scalar_type())); // dtype
- stack[2] = from(std::nullopt); // layout
- stack[3] = from(std::optional(device)); // device
- stack[4] = from(std::optional(false)); // pin_memory
- stack[5] = from(std::optional(mf)); // memory_format
+ stack[0] = torch::stable::detail::from(t);
+ stack[1] = torch::stable::detail::from(std::optional(t.scalar_type())); // dtype
+ stack[2] = torch::stable::detail::from(std::nullopt); // layout
+ stack[3] = torch::stable::detail::from(std::optional(device)); // device
+ stack[4] = torch::stable::detail::from(std::optional(false)); // pin_memory
+ stack[5] = torch::stable::detail::from(std::optional(mf)); // memory_format

  aoti_torch_call_dispatcher("aten::ones_like", "", stack);

- return to<Tensor>(stack[0]);
+ return torch::stable::detail::to<Tensor>(stack[0]);
}

void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
- stack[0] = from(res);
+ Tensor res = my_ones_like(torch::stable::detail::to<Tensor>(stack[0]), stack[1]);
+ stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -159,28 +159,28 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {

std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
  StableIValue stack_exp[1];
- stack_exp[0] = from(t1);
+ stack_exp[0] = torch::stable::detail::from(t1);
  aoti_torch_call_dispatcher("aten::exp", "", stack_exp);

  StableIValue stack_neg[1];
- stack_neg[0] = from(t2);
+ stack_neg[0] = torch::stable::detail::from(t2);
  aoti_torch_call_dispatcher("aten::neg", "", stack_neg);

  StableIValue stack_is_leaf[1];
- stack_is_leaf[0] = from(t3);
+ stack_is_leaf[0] = torch::stable::detail::from(t3);
  aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);

  return std::make_tuple(
-     to<Tensor>(stack_exp[0]),
-     to<Tensor>(stack_neg[0]),
-     to<bool>(stack_is_leaf[0]));
+     torch::stable::detail::to<Tensor>(stack_exp[0]),
+     torch::stable::detail::to<Tensor>(stack_neg[0]),
+     torch::stable::detail::to<bool>(stack_is_leaf[0]));
}

void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
- stack[0] = from(std::get<0>(tuple));
- stack[1] = from(std::get<1>(tuple));
- stack[2] = from(std::get<2>(tuple));
+ auto tuple = exp_neg_is_leaf(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<Tensor>(stack[2]));
+ stack[0] = torch::stable::detail::from(std::get<0>(tuple));
+ stack[1] = torch::stable::detail::from(std::get<1>(tuple));
+ stack[2] = torch::stable::detail::from(std::get<2>(tuple));
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -193,15 +193,15 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {

Tensor neg_exp(Tensor t) {
  StableIValue stack[1];
- stack[0] = from(t);
+ stack[0] = torch::stable::detail::from(t);
  aoti_torch_call_dispatcher("aten::exp", "", stack);
  aoti_torch_call_dispatcher("aten::neg", "", stack);
- return to<Tensor>(stack[0]);
+ return torch::stable::detail::to<Tensor>(stack[0]);
}

void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- Tensor res = neg_exp(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ Tensor res = neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -214,10 +214,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {

Tensor divide_neg_exp(Tensor t) {
  StableIValue stack_neg[1];
- stack_neg[0] = from(t);
+ stack_neg[0] = torch::stable::detail::from(t);

  StableIValue stack_exp[1];
- stack_exp[0] = from(t);
+ stack_exp[0] = torch::stable::detail::from(t);
  aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
  aoti_torch_call_dispatcher("aten::neg", "", stack_neg);

@@ -225,12 +225,12 @@ Tensor divide_neg_exp(Tensor t) {
  stack_div[0] = stack_neg[0];
  stack_div[1] = stack_exp[0];
  aoti_torch_call_dispatcher("aten::divide", "Tensor", stack_div);
- return to<Tensor>(stack_div[0]);
+ return torch::stable::detail::to<Tensor>(stack_div[0]);
}

void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- Tensor res = divide_neg_exp(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ Tensor res = divide_neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -246,8 +246,8 @@ bool is_contiguous(Tensor t) {
}

void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- bool res = is_contiguous(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ bool res = is_contiguous(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -263,9 +263,9 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
}

void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
+ auto res = my_transpose(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<int64_t>(stack[1]), torch::stable::detail::to<int64_t>(stack[2]));

- stack[0] = from(res);
+ stack[0] = torch::stable::detail::from(res);
}

Tensor my_empty_like(Tensor t) {
@@ -273,8 +273,8 @@ Tensor my_empty_like(Tensor t) {
}

void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- auto res = my_empty_like(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ auto res = my_empty_like(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

bool my_is_cpu(Tensor t) {
@@ -283,8 +283,8 @@ bool my_is_cpu(Tensor t) {


void boxed_my_is_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- auto res = my_is_cpu(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ auto res = my_is_cpu(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

Tensor fill_infinity(Tensor t) {
@@ -296,8 +296,8 @@ void boxed_fill_infinity(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
- auto res = fill_infinity(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ auto res = fill_infinity(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

Tensor my_pad(Tensor t) {
@@ -310,8 +310,8 @@ void boxed_my_pad(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
- auto res = my_pad(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ auto res = my_pad(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
@@ -323,11 +323,11 @@ void boxed_my_narrow(
    uint64_t num_args,
    uint64_t num_outputs) {
  auto res = my_narrow(
-     to<Tensor>(stack[0]),
-     to<int64_t>(stack[1]),
-     to<int64_t>(stack[2]),
-     to<int64_t>(stack[3]));
- stack[0] = from(res);
+     torch::stable::detail::to<Tensor>(stack[0]),
+     torch::stable::detail::to<int64_t>(stack[1]),
+     torch::stable::detail::to<int64_t>(stack[2]),
+     torch::stable::detail::to<int64_t>(stack[3]));
+ stack[0] = torch::stable::detail::from(res);
}

Tensor my_new_empty_dtype_variant(Tensor t) {
@@ -342,8 +342,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
}

void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- auto res = my_new_empty_dtype_variant(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ auto res = my_new_empty_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

Tensor my_new_zeros_dtype_variant(Tensor t) {
@@ -352,8 +352,8 @@ Tensor my_new_zeros_dtype_variant(Tensor t) {
}

void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- auto res = my_new_zeros_dtype_variant(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ auto res = my_new_zeros_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
@@ -361,8 +361,8 @@ Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
}

void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- Tensor tensor_res = my_copy_(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<bool>(stack[2]));
- stack[0] = from(tensor_res);
+ Tensor tensor_res = my_copy_(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<bool>(stack[2]));
+ stack[0] = torch::stable::detail::from(tensor_res);
}

Tensor my_clone(Tensor t) {
@@ -370,8 +370,8 @@ Tensor my_clone(Tensor t) {
}

void boxed_my_clone(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- Tensor tensor_res = my_clone(to<Tensor>(stack[0]));
- stack[0] = from(tensor_res);
+ Tensor tensor_res = my_clone(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(tensor_res);
}


@@ -408,8 +408,8 @@ Tensor my_zero_(Tensor t) {
}

void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- auto res = my_zero_(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ auto res = my_zero_(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

Tensor my_amax(Tensor t) {
@@ -417,8 +417,8 @@ Tensor my_amax(Tensor t) {
}

void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- auto res = my_amax(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ auto res = my_amax(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

Tensor my_amax_vec(Tensor t) {
@@ -426,8 +426,8 @@ Tensor my_amax_vec(Tensor t) {
}

void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
- auto res = my_amax_vec(to<Tensor>(stack[0]));
- stack[0] = from(res);
+ auto res = my_amax_vec(torch::stable::detail::to<Tensor>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -464,8 +464,8 @@ void boxed_test_default_constructor(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
- bool res = test_default_constructor(to<bool>(stack[0]));
- stack[0] = from(res);
+ bool res = test_default_constructor(torch::stable::detail::to<bool>(stack[0]));
+ stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -478,6 +478,56 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_amax_vec", &boxed_my_amax_vec);
}

+std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
+  std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
+  aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
+  return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
+}
+
+void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  // Why is the following NOT torch::stable::detail::to<HeaderOnlyArrayRef<Tensor>>(stack[0])? Because calling `to`
+  // on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef
+  // is not owning, so it cannot safely steward the result of the torch::stable::detail::to<>.
+  auto res = my__foreach_mul(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
+  stack[0] = torch::stable::detail::from(res);
+}
+
+void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
+  std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
+  aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
+}
+
+void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  my__foreach_mul_(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
+}
+
+std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
+  // This function tests that my__foreach_mul can take in std::initializer_lists
+  // in addition to std::vectors.
+  Tensor t1_1 = my_clone(t1);
+  Tensor t1_2 = my_clone(t1);
+  Tensor t2_1 = my_clone(t2);
+  Tensor t2_2 = my_clone(t2);
+  return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
+}
+
+void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  auto res = make_tensor_clones_and_call_foreach(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]));
+  stack[0] = torch::stable::detail::from(res);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
+  m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
+  m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
+  m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
+  m.impl("my__foreach_mul", &boxed_my__foreach_mul);
+  m.impl("my__foreach_mul_", &boxed_my__foreach_mul_);
+  m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach);
+}
+
// Test functions for torch::stable::accelerator APIs

#ifdef LAE_USE_CUDA
@@ -500,8 +550,8 @@ void boxed_test_device_guard(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
- int res = test_device_guard(static_cast<int64_t>(to<int64_t>(stack[0])));
- stack[0] = from(res);
+ int res = test_device_guard(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
+ stack[0] = torch::stable::detail::from(res);
}

int64_t test_device_guard_set_index() {
@@ -520,7 +570,7 @@ void boxed_test_device_guard_set_index(
    uint64_t num_args,
    uint64_t num_outputs) {
  int64_t res = test_device_guard_set_index();
- stack[0] = from(res);
+ stack[0] = torch::stable::detail::from(res);
}

int64_t test_stream(int32_t device_index) {
@@ -536,8 +586,8 @@ void boxed_test_stream(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
- int64_t res = test_stream(static_cast<int64_t>(to<int64_t>(stack[0])));
- stack[0] = from(res);
+ int64_t res = test_stream(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
+ stack[0] = torch::stable::detail::from(res);
}

int64_t test_get_current_device_index() {
@@ -549,7 +599,7 @@ void boxed_test_get_current_device_index(
    uint64_t num_args,
    uint64_t num_outputs) {
  int64_t res = test_get_current_device_index();
- stack[0] = from(res);
+ stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@@ -565,4 +615,5 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("test_stream", &boxed_test_stream);
  m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
}

#endif // LAE_USE_CUDA
@@ -333,3 +333,45 @@ def my_new_zeros_dtype_variant(t) -> Tensor:
    Returns: New zeros tensor
    """
    return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t)
+
+
+def my__foreach_mul_(tensors, others) -> ():
+    """
+    Updates tensors to be the result of pointwise multiplying with others.
+
+    Args:
+        tensors: list of tensors
+        others: list of tensors (with the same corresponding shapes as tensors)
+
+    Returns: nothing, tensors is updated in place.
+    """
+    torch.ops.libtorch_agnostic.my__foreach_mul_.default(tensors, others)
+
+
+def my__foreach_mul(tensors, others) -> list[Tensor]:
+    """
+    Returns a list of tensors that are the results of pointwise multiplying
+    tensors and others.
+
+    Args:
+        tensors: list of tensors
+        others: list of tensors (with the same corresponding shapes as tensors)
+
+    Returns: list of multiplied tensors
+    """
+    return torch.ops.libtorch_agnostic.my__foreach_mul.default(tensors, others)
+
+
+def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]:
+    """
+    Returns a list of 2 tensors corresponding to the square of the inputs.
+
+    Args:
+        t1: Tensor
+        t2: Tensor
+
+    Returns: list of [t1^2, t2^2]
+    """
+    return torch.ops.libtorch_agnostic.make_tensor_clones_and_call_foreach.default(
+        t1, t2
+    )
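Note (not part of the diff): putting the docstrings above together with the C++ registrations, calling these ops from Python looks roughly like the sketch below. It assumes the `libtorch_agnostic` test extension has been built and is importable; shapes are arbitrary.

```python
import torch
import libtorch_agnostic  # loading the built test extension registers the ops

a = [torch.rand(32, 16) for _ in range(3)]
b = [torch.rand(32, 16) for _ in range(3)]

out = libtorch_agnostic.ops.my__foreach_mul(a, b)  # out-of-place list multiply
libtorch_agnostic.ops.my__foreach_mul_(a, b)       # in-place: mutates each a[i]

for o, x in zip(out, a):
    assert torch.equal(o, x)  # both paths computed the same products
```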
@@ -367,6 +367,57 @@ if not IS_WINDOWS:
            self.assertNotEqual(result.data_ptr(), expected.data_ptr())
            self.assertEqual(result.stride(), expected.stride())

+        def test_my__foreach_mul_(self, device):
+            import libtorch_agnostic
+
+            N = 5
+            tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
+            tensors_c = [t.clone() for t in tensors]
+            others = [torch.rand(32, 16, device=device) for _ in range(N)]
+
+            libtorch_agnostic.ops.my__foreach_mul_(tensors, others)
+            expected_values = torch._foreach_mul(tensors_c, others)
+
+            for tensor_t, expected_t in zip(tensors, expected_values):
+                self.assertEqual(tensor_t, expected_t)
+
+        def test_my__foreach_mul(self, device):
+            import libtorch_agnostic
+
+            N = 5
+            tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
+            others = [torch.rand(32, 16, device=device) for _ in range(N)]
+
+            result = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
+            expected = torch._foreach_mul(tensors, others)
+
+            for result_t, expected_t in zip(result, expected):
+                self.assertEqual(result_t, expected_t)
+
+            def _make_cuda_tensors(prior_mem):
+                cuda_res = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
+                self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
+
+                expected = torch._foreach_mul(tensors, others)
+                for result_t, expected_t in zip(cuda_res, expected):
+                    self.assertEqual(result_t, expected_t)
+
+            if tensors[0].is_cuda:
+                init_mem = torch.cuda.memory_allocated(device)
+                for _ in range(3):
+                    _make_cuda_tensors(init_mem)
+                    curr_mem = torch.cuda.memory_allocated(device)
+                    self.assertEqual(curr_mem, init_mem)
+
+        def test_make_tensor_clones_and_call_foreach(self, device):
+            import libtorch_agnostic
+
+            t1 = torch.rand(2, 5, device=device)
+            t2 = torch.rand(3, 4, device=device)
+            result = libtorch_agnostic.ops.make_tensor_clones_and_call_foreach(t1, t2)
+            self.assertEqual(result[0], t1 * t1)
+            self.assertEqual(result[1], t2 * t2)
+
    instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)

if __name__ == "__main__":
@@ -1672,6 +1672,61 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        # The mutation is not reapplied in the backward because the flag was on.
        self.assertEqual(counter, 1)

+    @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
+    def test_nonlocal_list_mutation(self):
+        def gn(x, z):
+            out = x.sin()
+            z.append(out)
+            return torch.cos(torch.sin(torch.matmul(x, x) @ x)), out
+
+        def fn(x):
+            z = []
+
+            out1, out2 = torch.utils.checkpoint.checkpoint(
+                gn,
+                x,
+                z,
+                use_reentrant=False,
+            )
+
+            return out1, z[0]
+
+        x = torch.randn(4, 4, requires_grad=True)
+        ref = fn(x)
+
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        res = opt_fn(x)
+        self.assertEqual(ref[0], res[0])
+        self.assertEqual(ref[1], res[1])
+
+    @unittest.expectedFailure
+    @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
+    def test_nonlocal_list_mutation_hidden(self):
+        def gn(x, z):
+            out = x.sin()
+            z.append(out)
+            return torch.cos(torch.sin(torch.matmul(x, x) @ x))
+
+        def fn(x):
+            z = []
+
+            out1 = torch.utils.checkpoint.checkpoint(
+                gn,
+                x,
+                z,
+                use_reentrant=False,
+            )
+
+            return out1, z[0]
+
+        x = torch.randn(4, 4, requires_grad=True)
+        ref = fn(x)
+
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        res = opt_fn(x)
+        self.assertEqual(ref[0], res[0])
+        self.assertEqual(ref[1], res[1])
+

devices = ["cuda", "hpu"]
instantiate_device_type_tests(
@@ -2155,6 +2155,43 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
        torch.compile(model)
        torch.compile(other_model)

+    def test_dynamo_disable_annotations(self):
+        class SimpleModel(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_buffer("buffer", torch.rand(2, 2))
+
+            @torch._dynamo.disable()
+            def f1(self, x) -> torch.Tensor:
+                return x + self.buffer + 1
+
+            @torch._dynamo.disable()
+            def f2(self, x) -> torch.Tensor:
+                return x + self.buffer + 2
+
+            def forward(self, x) -> torch.Tensor:
+                return self.f1(x) + self.f2(x)
+
+        model = SimpleModel()
+        inp = torch.rand(2, 2)
+        with torch.fx.traceback.preserve_node_meta():
+            exported_model = torch.export.export(model, (inp,))
+            graph = exported_model.graph_module.graph
+            found_f1 = False
+            found_f2 = False
+            for node in graph.nodes:
+                if "custom" in node.meta:
+                    if "_torchdynamo_disable_method" in node.meta["custom"]:
+                        if node.meta["custom"]["_torchdynamo_disable_method"] == "f1":
+                            found_f1 = True
+                        elif node.meta["custom"]["_torchdynamo_disable_method"] == "f2":
+                            found_f2 = True
+            self.assertTrue(found_f1)
+            self.assertTrue(found_f2)
+        model.forward = torch._dynamo.disable(model.forward, recursive=False)
+        with self.assertRaises(RuntimeError):
+            exported_model = torch.export.export(model, (inp,))
+

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
-from collections.abc import Callable, Sequence
-from typing import Any, Union
+from collections.abc import Sequence
+from typing import Any, Callable, Union

import torch
import torch._dynamo

@@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
-from typing import NamedTuple, Optional, TYPE_CHECKING
+from typing import Callable, NamedTuple, Optional

import torch
import torch._dynamo
@@ -7,10 +7,6 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same


-if TYPE_CHECKING:
-    from collections.abc import Callable
-

"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo

@@ -742,11 +742,14 @@ class TestExport(TestCase):
        self.assertExpectedInline(
            str(custom_metadata),
            """\
-('call_function', 'cat', {'moo': 0})
-('call_function', 'item', {'moo': 0})
-('call_function', 'ge_1', {'moo': 0})
-('call_function', '_assert_scalar_default', {'moo': 0})
-('call_function', 'mul', {'moo': 0})""",
+('placeholder', 'x', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
+('placeholder', 'y', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
+('call_function', 'cat', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
+('call_function', 'item', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
+('call_function', 'ge_1', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
+('call_function', '_assert_scalar_default', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
+('call_function', 'mul', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
+('output', 'output', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})""",
        )

    @requires_gpu
@@ -5222,6 +5222,7 @@ xfail_by_backend = {
    "test_reentrant_with_callbacks_both_depths",  # queue_callback
    "test_reentrant_with_callbacks_depth_0",  # queue_callback
    "test_reentrant_with_callbacks_depth_1",  # queue_callback
+   "test_checkpoint_graph_execution_group",  # Attempted to call function marked as skipped
    "test_current_graph_task_execution_order",  # nodes are already freed by the time dynamo traces the lifted hook
    "test_autograd_inplace_views_cross_dtype",  # view_fn not supported by compiled autograd
    "test_post_accumulate_grad_hook_ordering",  # accuracy error

@@ -148,6 +148,24 @@ class FxirTestCase(InductorTestCase):
        args = [torch.randn(8, device=self.device) for _ in range(2)]
        self._compile_and_check(torch.add, args)

+    def test_device_type(self):
+        """
+        Test that we allocate on a device type instead of a specific index.
+        """
+        # Pass in a tensor on an indexed device.
+        device_runtime = getattr(torch, self.device)
+        indexed_device = torch.device(self.device, device_runtime.current_device())
+        args = [torch.randn(8, device=indexed_device) for _ in range(2)]
+        (gm,) = self._compile_and_check(torch.add, args)
+        (empty_strided,) = gm.graph.find_nodes(
+            op="call_function", target=torch.empty_strided
+        )
+
+        # Check that the device of the output allocation is not indexed.
+        output_device = torch.device(empty_strided.kwargs["device"])
+        self.assertIs(output_device.index, None)
+        self.assertEqual(output_device.type, indexed_device.type)
+
    def test_multiple_kernels(self):
        def foo(x, y):
            return x.sum() + y.sum()

@@ -7364,6 +7364,62 @@ for shape in [(1,), ()]:
        ):
            checkpoint_sequential(modules_list, 3, a)

+    @skipIfTorchDynamo("GraphExecGroup does not support compile")
+    def test_checkpoint_graph_execution_group(self):
+        def run(use_graph_execution_group):
+            counter = [0]
+
+            def fn(x):
+                counter[0] += 1
+                y = x.sin().cos()
+                z = y.sin().cos()
+                return y, z
+
+            x = torch.randn(3, 3, requires_grad=True)
+
+            y, z = checkpoint(fn, x, use_reentrant=False)
+
+            group = torch.utils.checkpoint.GraphExecGroup()
+
+            ctx = contextlib.nullcontext()
+            if use_graph_execution_group:
+                ctx = group
+
+            with ctx:
+                (grad_y,) = torch.autograd.grad(
+                    z, inputs=(y,), grad_outputs=(torch.ones(3, 3),)
+                )
+
+                (grad_x,) = torch.autograd.grad(
+                    y,
+                    inputs=(x,),
+                    grad_outputs=(grad_y,),
+                )
+
+            if use_graph_execution_group:
+                self.assertEqual(counter[0], 2)
+            else:
+                self.assertEqual(counter[0], 3)
+
+        run(use_graph_execution_group=True)
+        run(use_graph_execution_group=False)
+
+        # Test the not actually disjoint case (using retain_graph=True since
+        # otherwise autograd itself will catch this)
+        def fn(x):
+            return x.sin().cos()
+
+        x = torch.randn(3, 3, requires_grad=True)
+        out = checkpoint(fn, x, use_reentrant=False)
+        with torch.utils.checkpoint.GraphExecGroup():
+            # Under this context, we will enforce that two backward are disjoint
+            # even if retain_graph=True.
+            out.sum().backward(retain_graph=True)
+            with self.assertRaisesRegex(
+                RuntimeError, "Performing two backward calls that overlap"
+            ):
+                out.sum().backward()
+
    def test_checkpoint_detects_non_determinism(self):
        def save_3_tensors(x):
            out = x.sin().exp()
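Note (not part of the diff): in prose, `GraphExecGroup` lets two `torch.autograd.grad` calls that cover disjoint pieces of one checkpointed graph share a single recomputation of the checkpointed forward; the test's counter goes 2 with the group versus 3 without. A condensed sketch of the pattern the test exercises, assuming a PyTorch build that contains this change:

```python
import torch
from torch.utils.checkpoint import checkpoint, GraphExecGroup

def fn(t):
    y = t.sin().cos()
    return y, y.sin().cos()

x = torch.randn(3, 3, requires_grad=True)
y, z = checkpoint(fn, x, use_reentrant=False)

# Both partial backwards run under one group, so the checkpointed forward
# reruns once for the pair instead of once per torch.autograd.grad call.
with GraphExecGroup():
    (grad_y,) = torch.autograd.grad(z, inputs=(y,), grad_outputs=(torch.ones(3, 3),))
    (grad_x,) = torch.autograd.grad(y, inputs=(x,), grad_outputs=(grad_y,))
```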
@@ -1281,7 +1281,7 @@ class TestSerialization(TestCase, SerializationMixin):
            torch.save(p, f)
            f.seek(0)
            with self.assertRaisesRegex(pickle.UnpicklingError,
-                                        "GLOBAL __main__.Point was not an allowed global by default"):
+                                        f"GLOBAL {__name__}.Point was not an allowed global by default"):
                torch.load(f, weights_only=True)
            f.seek(0)
            with torch.serialization.safe_globals([Point]):
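Note (not part of the diff): this and the next two hunks replace the hard-coded `__main__` with `{__name__}` so the expected error message survives the test module being imported under another name. For context, the behavior under test is roughly the following sketch; `Point` stands in for any custom class (the test's actual definition may differ).

```python
import io
import pickle
from typing import NamedTuple

import torch

class Point(NamedTuple):  # stand-in for the test's custom class
    x: int
    y: int

buf = io.BytesIO()
torch.save(Point(1, 2), buf)

buf.seek(0)
try:
    torch.load(buf, weights_only=True)  # custom globals are rejected by default
except pickle.UnpicklingError as e:
    print(e)  # "... GLOBAL <module>.Point was not an allowed global by default ..."

buf.seek(0)
with torch.serialization.safe_globals([Point]):
    p = torch.load(buf, weights_only=True)  # allowed once explicitly listed
```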
@@ -1300,7 +1300,7 @@ class TestSerialization(TestCase, SerializationMixin):
            torch.save(c, f)
            f.seek(0)
            with self.assertRaisesRegex(pickle.UnpicklingError,
-                                        "GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
+                                        f"GLOBAL {__name__}.ClassThatUsesBuildInstruction was not an allowed global by default"):
                torch.load(f, weights_only=True)
            try:
                with torch.serialization.safe_globals([ClassThatUsesBuildInstruction]):
@@ -1330,7 +1330,7 @@ class TestSerialization(TestCase, SerializationMixin):
            torch.save(obj, f)
            f.seek(0)
            with self.assertRaisesRegex(pickle.UnpicklingError,
-                                        f"GLOBAL __main__.{obj_cls.__name__} was not an allowed global by default"):
+                                        f"GLOBAL {__name__}.{obj_cls.__name__} was not an allowed global by default"):
                torch.load(f, weights_only=True)

            f.seek(0)
@@ -4501,9 +4501,10 @@ class TestSerialization(TestCase, SerializationMixin):
        # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx
        if not materialize_fake:
            ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device))
+           exc = pickle.PicklingError if sys.version_info >= (3, 14) else AttributeError
            with self.assertRaisesRegex(
-               AttributeError,
-               "Can't (get|pickle) local object 'WeakValueDictionary.__init__.<locals>.remove'"
+               exc,
+               "Can't (get|pickle) local object (<function |')WeakValueDictionary.__init__.<locals>.remove"
            ):
                with skip_data(), BytesIOContext() as f:
                    torch.save(ft, f)

@@ -1,5 +1,5 @@
-from typing import TypeAlias, Union
-from typing_extensions import assert_type
+from typing import Union
+from typing_extensions import assert_type, TypeAlias

from torch import randn, Tensor

third_party/kineto (vendored submodule)
Submodule third_party/kineto updated: 6fcbc53d33...57c561f4ca
@ -69,6 +69,7 @@ from torch.types import (
    Storage,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.checkpoint import GraphExecGroup

# This module is defined in torch/csrc/Module.cpp

@ -1491,6 +1492,8 @@ def _is_multithreading_enabled() -> _bool: ...
def _set_multithreading_enabled(enabled: _bool) -> None: ...
def _set_view_replay_enabled(enabled: _bool) -> None: ...
def _is_view_replay_enabled() -> _bool: ...
def _set_graph_exec_group(group: GraphExecGroup | None) -> None: ...
def _get_graph_exec_group() -> GraphExecGroup | None: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
@ -1,9 +1,8 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from typing import Any, Optional, overload, Union
from typing import Any, Callable, Optional, overload, Union

import torch
from torch import Tensor
@ -78,7 +78,7 @@ from torch.export.dynamic_shapes import (
    _RelaxedConstraint,
    Constraint,
)
from torch.fx import GraphModule
from torch.fx import GraphModule, traceback as fx_traceback
from torch.fx.experimental._dynamism import (
    clone_and_convert_to_meta,
    track_dynamism_across_examples,
@ -1134,6 +1134,17 @@ class DisableContext(_TorchDynamoContext):
        try:
            _maybe_set_eval_frame(_callback_from_stance(self.callback))
            try:
                if torch.compiler.is_exporting():
                    with fx_traceback.annotate(
                        {
                            "_torchdynamo_disable": True,
                            "_torchdynamo_disable_recursive": True,
                            "_torchdynamo_disable_method": getattr(
                                fn, "__name__", type(fn).__name__
                            ),
                        }
                    ):
                        return fn(*args, **kwargs)
                return fn(*args, **kwargs)
            finally:
                set_eval_frame(None)
@ -196,6 +196,10 @@ def get_nonrecursive_disable_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    # this function is in external_utils so that convert_frame doesn't skip it.
    @functools.wraps(fn)
    def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        if torch.compiler.is_exporting():
            raise RuntimeError(
                "Non-recursive torch.compiler.disable is not supported with torch.export."
            )
        return fn(*args, **kwargs)

    return nonrecursive_disable_wrapper
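For context, a sketch of the user-facing behavior this new guard enforces. The module and shapes below are illustrative; torch.compiler.disable(recursive=False) is the documented way to obtain a non-recursive disable:

import torch

@torch.compiler.disable(recursive=False)
def helper(x):
    return x + 1

class M(torch.nn.Module):
    def forward(self, x):
        return helper(x)

# Tracing this module with torch.export should now raise:
# RuntimeError: Non-recursive torch.compiler.disable is not supported with torch.export.
torch.export.export(M(), (torch.randn(2),))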
@ -2141,9 +2141,10 @@ class GuardBuilder(GuardBuilderBase):
        original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1])
        if hasattr(value, "__metadata_guard__"):
            verify_guard_fn_signature(value)
            cls = type(value)

            def metadata_checker(x: Any) -> bool:
                return value.__metadata_guard__(
                return cls.__metadata_guard__(
                    original_metadata, x.__tensor_flatten__()[1]
                )
@ -1169,7 +1169,7 @@ class VariableBuilder:
                    f"{sym_expr} is not a basic Symbol."
                )
            self.tx.output.tracked_fakes.append(TrackedFake(node, source, None))
            return SymNodeVariable(sym_node_proxy, node)
            return SymNodeVariable.create(self.tx, sym_node_proxy, node)
        elif is_torch_sym(value):
            # Note: this doesn't handle nested symints.
            # For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo.
@ -2454,7 +2454,7 @@ class VariableBuilder:
        sym_expr = wrapped_value.node.expr
        assert isinstance(sym_expr, sympy.Symbol), f"{sym_expr} is not a basic Symbol."
        self.tx.output.root_tracer.bound_symbols[sym_expr] = proxy
        unspec_var = SymNodeVariable(proxy, wrapped_value, **options)
        unspec_var = SymNodeVariable.create(self.tx, proxy, wrapped_value, **options)
        self.tx.output.unspec_variable_map[self.name] = unspec_var

        if not is_constant_source(self.get_source()):
@ -3002,7 +3002,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
    elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
        tx.output.current_tracer.track_produced_symints(example_value, proxy)
        set_example_value(proxy.node, example_value)
        return SymNodeVariable(proxy, example_value, **options)
        return SymNodeVariable.create(tx, proxy, example_value, **options)
    elif (
        isinstance(example_value, torch.Stream)
        and proxy.node.target is get_external_object_by_index
@ -182,9 +182,9 @@ its type to `common_constant_types`.

        if any(isinstance(x, SymNodeVariable) for x in args):
            # Promote to SymNodeVariable for operations involving dynamic shapes.
            return variables.SymNodeVariable(self.as_proxy(), self.value).call_method(
                tx, name, args, kwargs
            )
            return variables.SymNodeVariable.create(
                tx, self.as_proxy(), self.value
            ).call_method(tx, name, args, kwargs)

        try:
            const_args = [a.as_python_constant() for a in args]
@ -21,9 +21,9 @@ restoring state changes.
import inspect
import sys
import warnings
from collections.abc import Callable, Sequence, Sized
from collections.abc import Callable, Sequence
from contextlib import ExitStack
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union

import torch._C
from torch._guards import Guard
@ -247,7 +247,7 @@ def _make_inlined(tx: "InstructionTranslator", f):


def _call_function_and_unflatten_output(
    tx, fn, args, kwargs, flat_example_value, ret_spec
    tx, fn, args, kwargs, flat_example_value, ret_spec, body_r
):
    from .builder import wrap_fx_proxy

@ -263,6 +263,21 @@ def _call_function_and_unflatten_output(
        example_value=flat_example_value,
    )

    # wrap_fx_proxy creates fresh variable trackers. However, the main program
    # after the speculate subgraph can still use the original tensor vts that
    # are still pointing to the nodes present in the subgraph. So, we reproxify
    # the original tensor vts with the subgraph outputs. This way, whenever the
    # outer graph uses an original vt, it uses the subgraph output.
    if body_r is not None:
        for orig_vt, subgraph_vt in zip(body_r.items, flat_variable.items):
            if isinstance(
                orig_vt, (variables.SymNodeVariable, variables.TensorVariable)
            ):
                assert isinstance(
                    subgraph_vt, (variables.SymNodeVariable, variables.TensorVariable)
                )
                orig_vt.proxy = subgraph_vt.proxy

    if ret_spec.masks_to_filter_const_values:
        from torch._dynamo.external_utils import insert_const_values_with_mask

@ -572,6 +587,7 @@ def _call_while_loop(
        {},
        None,
        body_treespec,
        body_r,
    )


@ -1535,6 +1551,7 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
        {},
        None,
        true_spec,
        true_r,
    )


@ -1858,6 +1875,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
        {},
        None,
        OutputSpec(xs_treespec),
        None,
    )


@ -2090,7 +2108,13 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )

        return _call_function_and_unflatten_output(
            tx, torch.ops.higher_order.scan, p_args, {}, None, _combine_spec
            tx,
            torch.ops.higher_order.scan,
            p_args,
            {},
            None,
            _combine_spec,
            None,
        )


@ -2213,7 +2237,7 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )

        return _call_function_and_unflatten_output(
            tx, torch.ops.higher_order.map_impl, p_args, {}, None, body_spec
            tx, torch.ops.higher_order.map_impl, p_args, {}, None, body_spec, body_r
        )


@ -2419,7 +2443,13 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )

        return _call_function_and_unflatten_output(
            tx, self.value, tuple(p_args), p_kwargs, flat_example_value, treespec
            tx,
            self.value,
            tuple(p_args),
            p_kwargs,
            flat_example_value,
            treespec,
            body_r,
        )


@ -2506,7 +2536,7 @@ class WrapWithSetGradEnabledHigherOrderVariable(TorchHigherOrderOperatorVariable
            body_r.as_proxy(),
        )
        return _call_function_and_unflatten_output(
            tx, self.value, proxy_args, {}, example_value, treespec
            tx, self.value, proxy_args, {}, example_value, treespec, body_r
        )


@ -2601,7 +2631,7 @@ class WrapWithAutocastHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )

        return _call_function_and_unflatten_output(
            tx, self.value, proxy_args, {}, example_value, treespec
            tx, self.value, proxy_args, {}, example_value, treespec, body_r
        )


@ -2674,7 +2704,7 @@ class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )

        return _call_function_and_unflatten_output(
            tx, self.value, p_args, p_kwargs, flat_example_value, treespec
            tx, self.value, p_args, p_kwargs, flat_example_value, treespec, body_r
        )


@ -2793,6 +2823,7 @@ class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable):
        {},
        flat_example_value,
        ret_spec,
        ret_val,
    )


@ -2860,6 +2891,7 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
        checkpoint_kwargs,
        example_value,
        out_spec,
        _body_r,
    )


@ -2913,6 +2945,7 @@ class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
        {},
        example_value,
        out_spec,
        _body_r,
    )


@ -3652,7 +3685,13 @@ class BaseHOPVariable(WrapHigherOrderVariable):

        p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
        return _call_function_and_unflatten_output(
            tx, self.value, p_args, p_kwargs, flat_example_value, treespec
            tx,
            self.value,
            p_args,
            p_kwargs,
            flat_example_value,
            treespec,
            body_r,
        )


@ -3768,6 +3807,7 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
        p_kwargs,
        flat_example_value,
        treespec,
        body_r,
    )


@ -3991,7 +4031,7 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
        # Step 5: Install local_map subgraph
        p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
        out = _call_function_and_unflatten_output(
            tx, self.value, p_args, p_kwargs, flat_example_value, treespec
            tx, self.value, p_args, p_kwargs, flat_example_value, treespec, body_r
        )

        # Step 6: Restore inputs and outputs to global shapes
@ -1097,9 +1097,19 @@ def placeholder_naming_pass(
            node.name = node.target = name_map[node.name]
            if node.name in custom_meta:
                if node.meta.get("custom") is None:
                    node.meta["custom"] = custom_meta[node.name]
                    node.meta["custom"] = {}
                else:
                    assert node.meta["custom"] == custom_meta[node.name]
                # Assert if any existing key has different value
                for k, v in node.meta["custom"].items():
                    if (
                        k in custom_meta[node.name]
                        and v != custom_meta[node.name][k]
                    ):
                        raise AssertionError(
                            f"Mismatch in custom metadata for key {k}. Value in "
                            f"node.meta is {v} and value in custom_meta is {custom_meta[node.name][k]}."
                        )
                node.meta["custom"].update(custom_meta[node.name])
        # if the constant obj is an input, we also need to update meta["val"]
        # because this is created before the placeholder naming pass
        if isinstance(node.meta["val"], CustomObjArgument):
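Stripped of the pass's surrounding state, the merge policy above reduces to roughly the following standalone sketch (not the export code itself):

# Existing "custom" metadata is kept, new keys are merged in, and a
# conflicting value for a shared key raises.
def merge_custom_meta(existing: dict, incoming: dict) -> dict:
    for k, v in existing.items():
        if k in incoming and v != incoming[k]:
            raise AssertionError(
                f"Mismatch in custom metadata for key {k}: {v} vs {incoming[k]}"
            )
    existing.update(incoming)
    return existing

print(merge_custom_meta({"a": 1}, {"b": 2}))  # {'a': 1, 'b': 2}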
@ -2,7 +2,7 @@
from __future__ import annotations

import hashlib
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING

import sympy  # noqa: TC002

@ -17,8 +17,6 @@ from .simd import SIMDKernel, SIMDScheduling


if TYPE_CHECKING:
    from collections.abc import Callable, Sequence

    from ..ir import IRNode
    from ..scheduler import BaseSchedulerNode
@ -678,6 +678,7 @@ class FxConverter:
        assert name not in V.graph.removed_buffers

        device = buffer.get_device()
        assert device
        dtype = buffer.get_dtype()
        shape = self._generate_sym_nodes(buffer.get_size())
        stride = self._generate_sym_nodes(buffer.get_stride())
@ -685,7 +686,7 @@ class FxConverter:
        node = self.gm.graph.call_function(
            torch.empty_strided,
            args=(shape, stride),
            kwargs={"dtype": dtype, "device": device},
            kwargs={"dtype": dtype, "device": device.type},
        )
        assert name
        node.name = name
@ -1,6 +1,6 @@
import os
from collections.abc import Callable
from functools import cache, partial
from typing import Callable

import torch
from torch._environment import is_fbcode
@ -195,10 +195,12 @@ def get_new_attr_name_with_prefix(prefix: str) -> Callable:
def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
    r"""Starting from a target node, trace back until we hit input or
    getattr node. This is used to extract the chain of operators
    starting from getattr to the target node, for example
    def forward(self, x):
        observed = self.observer(self.weight)
        return F.linear(x, observed)
    starting from getattr to the target node, for example::

        def forward(self, x):
            observed = self.observer(self.weight)
            return F.linear(x, observed)

    collect_producer_nodes(observed) will either return a list of nodes that
    produces the observed node or None if we can't extract a self contained
    graph without free variables(inputs of the forward function).
@ -52,7 +52,26 @@ __all__ = [
    "MemRecordsAcc",
]

from contextlib import ContextDecorator
try:
    # Available in Python >= 3.2
    from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
    import functools

    class _ContextDecorator:  # type: ignore[no-redef]
        def __enter__(self):
            raise NotImplementedError

        def __exit__(self, exc_type, exc_val, exc_tb):
            raise NotImplementedError

        def __call__(self, func):
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return wrapped


# global python state - whether profiler is currently enabled
@ -725,7 +744,8 @@ class profile:
        return all_function_events


class record_function(ContextDecorator):
# pyrefly: ignore [invalid-inheritance]
class record_function(_ContextDecorator):
    """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.

    Label will only appear if CPU activity tracing is enabled.
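record_function needs ContextDecorator (or the fallback shim above) because it is documented to work both as a context manager and as a decorator. A small usage sketch with the public profiler API:

import torch
from torch.profiler import profile, record_function

@record_function("my_helper")  # decorator form
def helper(x):
    return x * 2

with profile() as prof:
    with record_function("outer_block"):  # context-manager form
        y = helper(torch.randn(4))

print(prof.key_averages().table(sort_by="cpu_time_total"))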
@ -1218,6 +1218,33 @@ static PyObject* is_view_replay_enabled(PyObject* self, PyObject* args) {
  END_HANDLE_TH_ERRORS
}

static PyObject* set_graph_exec_group(PyObject* self, PyObject* obj) {
  HANDLE_TH_ERRORS
  if (obj == Py_None) {
    c10::AutogradState::get_tls_state().set_graph_exec_group(std::nullopt);
  } else {
    Py_INCREF(obj);
    c10::AutogradState::get_tls_state().set_graph_exec_group(
        c10::SafePyObject(obj, getPyInterpreter()));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* get_graph_exec_group(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  const auto& group =
      c10::AutogradState::get_tls_state().get_graph_exec_group();
  if (group.has_value()) {
    PyObject* obj = group->ptr(getPyInterpreter());
    Py_INCREF(obj);
    return obj;
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  if (c10::InferenceMode::is_enabled()) {
@ -1598,6 +1625,8 @@ static PyMethodDef methods[] = {
     castPyCFunctionWithKeywords(set_view_replay_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_set_graph_exec_group", set_graph_exec_group, METH_O, nullptr},
    {"_get_graph_exec_group", get_graph_exec_group, METH_NOARGS, nullptr},
    {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
    {"_exit_dual_level",
     castPyCFunctionWithKeywords(python_exit_dual_level),
@ -4,12 +4,65 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/stable/library.h>
#include <torch/library.h>

#include <torch/csrc/shim_conversion_utils.h>
#include <torch/csrc/stable/c/shim.h>

AOTITorchError torch_new_list_reserve_size(size_t size, StableListHandle* ret) {
  auto list_ptr = std::make_unique<std::vector<StableIValue>>();
  list_ptr->reserve(size);
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret = list_pointer_to_list_handle(list_ptr.release()); });
}

AOTI_TORCH_EXPORT AOTITorchError
torch_list_size(StableListHandle list_handle, size_t* size) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    std::vector<StableIValue>* list = list_handle_to_list_pointer(list_handle);
    *size = list->size();
  });
}

AOTI_TORCH_EXPORT AOTITorchError torch_list_get_item(
    StableListHandle list_handle,
    size_t index,
    StableIValue* element) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    std::vector<StableIValue>* list = list_handle_to_list_pointer(list_handle);
    *element = list->at(index);
  });
}

AOTI_TORCH_EXPORT AOTITorchError torch_list_set_item(
    StableListHandle list_handle,
    size_t index,
    StableIValue element) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    std::vector<StableIValue>* list = list_handle_to_list_pointer(list_handle);
    list->at(index) = element;
  });
}

AOTITorchError torch_list_push_back(
    StableListHandle list_handle,
    StableIValue element) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    std::vector<StableIValue>* list = list_handle_to_list_pointer(list_handle);
    list->push_back(element);
  });
}

AOTI_TORCH_EXPORT AOTITorchError
torch_delete_list(StableListHandle list_handle) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    std::vector<StableIValue>* list_ptr =
        list_handle_to_list_pointer(list_handle);
    delete list_ptr;
  });
}

static StableIValue from_ivalue(
    const c10::TypePtr& type,
    const c10::IValue& ivalue,
@ -71,6 +124,19 @@ static StableIValue from_ivalue(
        from_ivalue(inner_type, ivalue, extension_build_version));
    return torch::stable::detail::_from(sivp, extension_build_version);
  }
  case c10::TypeKind::ListType: {
    auto inner_type = type->castRaw<c10::ListType>()->getElementType();
    auto ivalue_list = ivalue.toList();
    auto stableivalue_list = std::make_unique<std::vector<StableIValue>>();
    stableivalue_list->reserve(ivalue_list.size());
    for (const auto& elem : ivalue_list) {
      stableivalue_list->emplace_back(
          from_ivalue(inner_type, elem, extension_build_version));
    }
    return torch::stable::detail::_from(
        list_pointer_to_list_handle(stableivalue_list.release()),
        extension_build_version);
  }
  default: {
    TORCH_CHECK(
        false,
@ -145,6 +211,21 @@ static c10::IValue to_ivalue(
    delete sivp;
    return ival;
  }
  case c10::TypeKind::ListType: {
    auto inner_type = type->castRaw<c10::ListType>()->getElementType();
    auto list_handle = torch::stable::detail::_to<StableListHandle>(
        stable_ivalue, extension_build_version);
    std::vector<StableIValue>* stableivalue_list =
        list_handle_to_list_pointer(list_handle);
    auto ivalue_list = c10::impl::GenericList(inner_type);
    ivalue_list.reserve(stableivalue_list->size());
    for (const auto& elem : *stableivalue_list) {
      ivalue_list.emplace_back(
          to_ivalue(inner_type, elem, extension_build_version));
    }
    TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle));
    return ivalue_list;
  }
  default: {
    TORCH_CHECK(
        false,
22 torch/csrc/shim_conversion_utils.h Normal file
@ -0,0 +1,22 @@
#pragma once

#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>

#include <vector>

inline std::vector<StableIValue>* list_handle_to_list_pointer(
    StableListHandle handle) {
  return reinterpret_cast<std::vector<StableIValue>*>(handle);
}

inline StableListHandle list_pointer_to_list_handle(
    std::vector<StableIValue>* list_ptr) {
  return reinterpret_cast<StableListHandle>(list_ptr);
}

inline StableListHandle new_list_handle(std::vector<StableIValue>&& list) {
  std::vector<StableIValue>* new_list = new std::vector<StableIValue>(std::move(list));
  return list_pointer_to_list_handle(new_list);
}
@ -37,6 +37,34 @@ AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
    void (*fn)(StableIValue*, uint64_t, uint64_t),
    uint64_t extension_build_version);

struct StableListOpaque;
using StableListHandle = StableListOpaque*;

// returns an owning reference of a StableList. caller is responsible for
// freeing memory.
AOTI_TORCH_EXPORT AOTITorchError
torch_new_list_reserve_size(size_t size, StableListHandle* ret);

AOTI_TORCH_EXPORT AOTITorchError
torch_list_size(StableListHandle list_handle, size_t* size);

AOTI_TORCH_EXPORT AOTITorchError torch_list_get_item(
    StableListHandle list_handle,
    size_t index,
    StableIValue* element);

AOTI_TORCH_EXPORT AOTITorchError torch_list_set_item(
    StableListHandle list_handle,
    size_t index,
    StableIValue element);

AOTI_TORCH_EXPORT AOTITorchError
torch_list_push_back(StableListHandle list_handle, StableIValue element);

// deletes the underlying list referenced by list_handle
AOTI_TORCH_EXPORT AOTITorchError
torch_delete_list(StableListHandle list_handle);

#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0

#ifdef __cplusplus
@ -2,6 +2,7 @@

#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/tensor_struct.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
@ -192,6 +193,46 @@ struct FromImpl<torch::stable::Tensor> {
  }
};

// Specialization for torch::headeronly::HeaderOnlyArrayRef<T> => StableIValue
// Returns a new owning reference of the underlying list.
template <typename T>
struct FromImpl<torch::headeronly::HeaderOnlyArrayRef<T>> {
  static StableIValue call(
      const torch::headeronly::HeaderOnlyArrayRef<T>& val,
      [[maybe_unused]] uint64_t extension_build_version,
      [[maybe_unused]] bool is_internal) {
    StableListHandle new_list_handle = nullptr;
    try {
      TORCH_ERROR_CODE_CHECK(
          torch_new_list_reserve_size(val.size(), &new_list_handle));
      for (const auto& elem : val) {
        TORCH_ERROR_CODE_CHECK(
            torch_list_push_back(new_list_handle, from(elem)));
      }
      return from(new_list_handle);
    } catch (const std::runtime_error& e) {
      if (new_list_handle != nullptr) {
        // clean up memory if an error was thrown
        TORCH_ERROR_CODE_CHECK(torch_delete_list(new_list_handle));
      }
      throw;
    }
  }
};

// Specialization for std::vector<T> => StableIValue, which is implemented the
// same way as HeaderOnlyArrayRef<T> => StableIValue.
// Returns a new owning reference of the underlying list.
template <typename T>
struct FromImpl<std::vector<T>> {
  static StableIValue call(
      const std::vector<T>& val,
      [[maybe_unused]] uint64_t extension_build_version,
      [[maybe_unused]] bool is_internal) {
    return from<torch::headeronly::HeaderOnlyArrayRef<T>>(val);
  }
};

// =============================================================================
// TO CONVERSIONS (StableIValue -> T)
// =============================================================================
@ -342,6 +383,38 @@ struct ToImpl<torch::stable::Tensor> {
  }
};

// Specialization for StableIValue => std::vector<T>
// std::vector<T> should be represented as a StableListHandle
// filled with StableIValues.
// The new std::vector steals ownership of the underlying elements
// and we free the underlying list referred to by the input StableListHandle.
template <typename T>
struct ToImpl<std::vector<T>> {
  static std::vector<T> call(
      StableIValue val,
      [[maybe_unused]] uint64_t extension_build_version,
      [[maybe_unused]] bool is_internal) {
    auto list_handle = to<StableListHandle>(val);
    size_t size;
    try {
      TORCH_ERROR_CODE_CHECK(torch_list_size(list_handle, &size));
      std::vector<T> result;
      result.reserve(size);
      for (size_t i = 0; i < size; i++) {
        StableIValue element;
        TORCH_ERROR_CODE_CHECK(torch_list_get_item(list_handle, i, &element));
        result.push_back(to<T>(element));
      }
      TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle));
      return result;
    } catch (const std::runtime_error& e) {
      // clean up memory if an exception is thrown, and rethrow
      TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle));
      throw;
    }
  }
};

// =============================================================================
// end to helpers for converting between StableIValue and T
// =============================================================================
@ -1,8 +1,9 @@
import functools
import math
import operator
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from datetime import timedelta
from typing import Callable

import torch
from torch._C import ScriptObject
@ -441,15 +441,11 @@ def vector_norm_strategy(op_schema: OpSchema) -> OpStrategy:
    keepdim = args_schema[3] if len(args_schema) > 3 else False
    dims = _infer_reduction_dims(dim, input_strategy.ndim)
    reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims
    reduction_linear = all(
        all(not p.is_partial() for p in op_spec.output_spec.placements)
        for op_spec in input_strategy.strategies
    )
    return common_reduction_strategy(
        input_strategy,
        reduce_dims,
        keep_dim=cast(bool, keepdim),
        reduction_linear=reduction_linear,
        reduction_linear=True,
        reduction_op=NormReduction(norm_type),
    )

@ -472,14 +468,10 @@ def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy:
    if not isinstance(op_strategy, OpStrategy):
        raise AssertionError(f"Expected OpStrategy, got {type(op_strategy)}")
    reduce_dims = list(range(op_strategy.ndim))
    reduction_linear = all(
        all(not p.is_partial() for p in op_spec.output_spec.placements)
        for op_spec in op_strategy.strategies
    )
    output_strategy = common_reduction_strategy(
        op_strategy,
        reduce_dims,
        reduction_linear=reduction_linear,
        reduction_linear=True,
        reduction_op=NormReduction(norm_type),
    )
    output_tuple_strategy_children.append(output_strategy)
@ -3,6 +3,7 @@
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>

#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
@ -33,6 +33,7 @@ __all__ = [
    "SelectiveCheckpointContext",
    "create_selective_checkpoint_contexts",
    "SAC_IGNORED_OPS",
    "GraphExecGroup",
]

_DEFAULT_DETERMINISM_MODE = "default"
@ -1072,7 +1073,7 @@ class _StopRecomputationError(Exception):


class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, target_frame_ref: ReferenceType, gid: int) -> None:
    def __init__(self, target_frame_ref: ReferenceType, gid: Union["GraphExecGroup", int]) -> None:
        def pack_hook(x):
            x = x.detach() if x.requires_grad else x
            target_frame = target_frame_ref()
@ -1145,10 +1146,14 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
            return holder

        def unpack_hook(holder):
            gid = torch._C._current_graph_task_id()
            if gid == -1:
                # generate a temporary id if we trigger unpack outside of a backward call
                gid = int(uuid.uuid4())
            # First check if we're inside a GraphExecGroup context
            gid: Union[GraphExecGroup, None, int] = GraphExecGroup._get_current_group()
            if gid is None:
                # Fall back to using the current graph task id
                gid = torch._C._current_graph_task_id()
                if gid == -1:
                    # generate a temporary id if we trigger unpack outside of a backward call
                    gid = int(uuid.uuid4())

            if not frame.is_recomputed[gid]:
                ctx = frame.input_saver.grad_fn
@ -1168,10 +1173,17 @@ class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
            _internal_assert(gid in holder.handles)

            if holder.handles[gid] is None:
                extra = ""
                if torch._C._get_graph_exec_group() is not None:
                    extra = (
                        "Performing two backward calls that overlap (i.e. require the same "
                        "saved activation in order to compute gradients) is not allowed while "
                        "under the torch.utils.checkpoint.GraphExecGroup context. "
                    )
                raise CheckpointError(
                    "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
                    "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
                    "so only once. Otherwise please open an issue with details on your use case."
                    f"unpacked once. {extra}If you are calling ctx.saved_tensors in backward, make sure "
                    "to do so only once. Otherwise please open an issue with details on your use case."
                )
            _internal_assert(holder.handles[gid] in frame.recomputed[gid])
            ret = frame.recomputed[gid][holder.handles[gid]]
@ -1594,6 +1606,40 @@ def _checkpoint_without_reentrant_generator(

    return


class GraphExecGroup:
    """Any checkpointed regions encountered by backward under the same instance
    of this context manager will trigger recompute at most once, even if
    there are multiple calls to backward.

    Backward calls under the same instance of this context manager must execute
    over non-overlapping regions of the backward graph, even if retain_graph=True.
    In particular, no two backward calls can use the same saved activation for
    gradient computation.

    .. note::
        This context manager only affects checkpoint with use_reentrant=False, and
        is a no-op otherwise.
    """

    def __enter__(self) -> "GraphExecGroup":
        if torch._C._get_graph_exec_group() is not None:
            raise RuntimeError(
                "GraphExecGroup contexts cannot be nested. "
                f"Already inside group {torch._C._get_graph_exec_group()}"
            )
        torch._C._set_graph_exec_group(self)
        return self

    def __exit__(self, *args: object) -> None:
        torch._C._set_graph_exec_group(None)

    @classmethod
    def _get_current_group(cls) -> Optional["GraphExecGroup"]:
        # Private API to be used by utils like AC
        return torch._C._get_graph_exec_group()
# Note: [compiled autograd and checkpoint unpack hook]
# When tracing via compiled autograd, this hook will be visible to the
# compiler if the forward of this checkpointed region ran in eager.
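A minimal usage sketch of the new context manager, mirroring the test earlier in this diff (fn and the tensor shapes are illustrative):

import torch
from torch.utils.checkpoint import checkpoint, GraphExecGroup

def fn(x):
    return x.sin().cos()

x = torch.randn(3, 3, requires_grad=True)
y = checkpoint(fn, x, use_reentrant=False)  # GraphExecGroup requires use_reentrant=False
z = y.exp()

with GraphExecGroup():
    # Two backward calls over disjoint regions of the graph: first z -> y,
    # then y -> x. Under the group, any recompute of the checkpointed region
    # is shared across the calls instead of being triggered once per backward.
    (grad_y,) = torch.autograd.grad(z, inputs=(y,), grad_outputs=(torch.ones(3, 3),))
    (grad_x,) = torch.autograd.grad(y, inputs=(x,), grad_outputs=(grad_y,))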