mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-02 06:24:59 +08:00
Update (base update)
[ghstack-poisoned]
2  .ci/docker/ci_commit_pins/huggingface-requirements.txt  Normal file
@@ -0,0 +1,2 @@
+transformers==4.54.0
+soxr==0.5.0
@@ -1 +0,0 @@
-v4.54.0
@@ -5,9 +5,7 @@ set -ex
 source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
 
 function install_huggingface() {
-  local version
-  commit=$(get_pinned_commit huggingface)
-  pip_install "git+https://github.com/huggingface/transformers@${commit}"
+  pip_install -r huggingface-requirements.txt
 }
 
 function install_timm() {
@@ -26,9 +24,6 @@ function install_torchbench() {
 
   python install.py --continue_on_fail
 
-  # soxr comes from https://github.com/huggingface/transformers/pull/39429
-  pip install transformers==4.54.0 soxr==0.5.0
-
   echo "Print all dependencies after TorchBench is installed"
   python -mpip freeze
   popd
@@ -96,11 +96,11 @@ ARG ANACONDA_PYTHON_VERSION
 ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION
 COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh
 COPY ./common/common_utils.sh common_utils.sh
-COPY ci_commit_pins/huggingface.txt huggingface.txt
+COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
 COPY ci_commit_pins/timm.txt timm.txt
 COPY ci_commit_pins/torchbench.txt torchbench.txt
 RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
-RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt
+RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
 
 # (optional) Install non-default Ninja version
 ARG NINJA_VERSION
@@ -56,10 +56,10 @@ RUN rm install_openssl.sh
 ARG INDUCTOR_BENCHMARKS
 COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh
 COPY ./common/common_utils.sh common_utils.sh
-COPY ci_commit_pins/huggingface.txt huggingface.txt
+COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
 COPY ci_commit_pins/timm.txt timm.txt
 RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
-RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt
+RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt
 
 # Install XPU Dependencies
 ARG XPU_VERSION
@@ -96,11 +96,11 @@ RUN rm install_openssl.sh
 ARG INDUCTOR_BENCHMARKS
 COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh
 COPY ./common/common_utils.sh common_utils.sh
-COPY ci_commit_pins/huggingface.txt huggingface.txt
+COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
 COPY ci_commit_pins/timm.txt timm.txt
 COPY ci_commit_pins/torchbench.txt torchbench.txt
 RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
-RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt
+RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
 
 ARG TRITON
 ARG TRITON_CPU
@@ -16,10 +16,11 @@ popd
 # enable debug asserts in serialization
 export TORCH_SERIALIZATION_DEBUG=1
 
-__TEST_PYTHON_HAS_SETUP=''
+__TEST_PYTHON_HAS_SETUP='' # marker for `setup_test_python`
 
 setup_test_python() {
   if [[ -n "${__TEST_PYTHON_HAS_SETUP}" ]]; then
     # Already set up, skip.
     return
   fi
@@ -40,7 +41,7 @@ setup_test_python() {
   # might help with intermittent compiler test failures
   ulimit -n 16384
 
-  __TEST_PYTHON_HAS_SETUP=1
+  __TEST_PYTHON_HAS_SETUP=1 # marker
 }
 
 test_python_all() {
@@ -188,20 +189,20 @@ checkout_install_torchbench() {
     # to install and test other models
     python install.py --continue_on_fail
   fi
   popd
 
-  # soxr comes from https://github.com/huggingface/transformers/pull/39429
-  pip install transformers==4.54.0 soxr==0.5.0
-
+  pip install -r .ci/docker/ci_commit_pins/huggingface-requirements.txt
   # https://github.com/pytorch/pytorch/issues/160689 to remove torchao because
   # its current version 0.12.0 doesn't work with transformers 4.54.0
   pip uninstall -y torchao
 
   echo "Print all dependencies after TorchBench is installed"
   python -mpip freeze
   popd
 }
 
 torchbench_setup_macos() {
   setup_test_python
 
   git clone --recursive https://github.com/pytorch/vision torchvision
   git clone --recursive https://github.com/pytorch/audio torchaudio
   brew install jpeg-turbo libpng
20  .github/dependabot.yml  vendored  Normal file
@@ -0,0 +1,20 @@
+version: 2
+updates:
+  # Update to the latest transformers version with dependabot
+  - package-ecosystem: "pip"
+    directory: "/.ci/docker/ci_commit_pins"
+    schedule:
+      interval: "daily"
+    target-branch: "main"
+    allow:
+      - dependency-name: "transformers"
+    commit-message:
+      prefix: "[Dependabot] Update"
+      include: "scope"
+    labels:
+      - "dependencies"
+      - "open source"
+      - "python"
+      - "topic: not user facing"
+      - "module: ci"
+      - "module: inductor"
6  .github/workflows/lint.yml  vendored
@@ -93,7 +93,7 @@ jobs:
       script: |
         CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
         echo "Running mypy"
-        ADDITIONAL_LINTRUNNER_ARGS="--take MYPY --all-files" .github/scripts/lintrunner.sh
+        ADDITIONAL_LINTRUNNER_ARGS="--take MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
 
   lintrunner-noclang:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
@@ -111,9 +111,9 @@
         CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
         echo "Running all other linters"
         if [ "$CHANGED_FILES" = '*' ]; then
-          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY --all-files" .github/scripts/lintrunner.sh
+          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
         else
-          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
+          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh
         fi
 
   quick-checks:
@@ -18,6 +18,7 @@
 #include <ATen/ops/is_set_to_native.h>
 #include <ATen/ops/size_native.h>
 #include <ATen/ops/stride_native.h>
+#include <ATen/ops/sym_is_contiguous_native.h>
 #include <ATen/ops/sym_numel_native.h>
 #include <ATen/ops/sym_size_native.h>
 #include <ATen/ops/sym_storage_offset_native.h>
@@ -57,6 +58,12 @@ c10::SymInt sym_size(const Tensor& self, int64_t dim) {
   return self.sym_size(dim);
 }
 
+c10::SymBool sym_is_contiguous(
+    const Tensor& self,
+    c10::MemoryFormat memory_format) {
+  return self.sym_is_contiguous(memory_format);
+}
+
 c10::SymInt sym_stride(const Tensor& self, int64_t dim) {
   return self.sym_stride(dim);
 }
@@ -5509,6 +5509,13 @@
   tags: core
   manual_cpp_binding: True
 
+- func: sym_is_contiguous(Tensor self, MemoryFormat memory_format=contiguous_format) -> SymBool
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+  manual_cpp_binding: True
+
 - func: sym_numel(Tensor self) -> SymInt
   variants: function
   device_check: NoCheck
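For orientation, the new overload can be exercised directly through the op registry; a minimal sketch (not part of the diff, assuming a build that includes this change), matching the call pattern used in the updated tests further below:

import torch

x = torch.randn(2, 3)
# On a plain dense tensor the SymBool result is backed by a concrete bool.
print(torch.ops.aten.sym_is_contiguous(x))      # True: freshly allocated, C-contiguous
print(torch.ops.aten.sym_is_contiguous(x.t()))  # False: the transposed view is not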
@@ -313,8 +313,15 @@ void TensorImpl::throw_data_ptr_access_error() const {
 c10::SymBool TensorImpl::sym_is_contiguous_custom(
     at::MemoryFormat memory_format) const {
   if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
-    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
-        this, memory_format);
+    // To reduce BC breakage and avoid introducing sym_is_contiguous everywhere,
+    // call is_contiguous when the tensor has no symbolic sizes/strides.
+    if (C10_UNLIKELY(has_symbolic_sizes_strides_)) {
+      return pyobj_slot_.load_pyobj_interpreter()->sym_is_contiguous(
+          this, memory_format);
+    } else {
+      return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
+          this, memory_format);
+    }
   }
 
   return sym_is_contiguous_default(memory_format);
@@ -60,6 +60,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
   bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override {
     PANIC(is_contiguous);
   }
+  c10::SymBool sym_is_contiguous(const TensorImpl* self, at::MemoryFormat)
+      const override {
+    PANIC(sym_is_contiguous);
+  }
   bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
       const override {
     PANIC(is_strides_like);
@@ -168,6 +168,9 @@ struct C10_API PyInterpreterVTable {
 
   virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
       const = 0;
+  virtual c10::SymBool sym_is_contiguous(
+      const TensorImpl* self,
+      at::MemoryFormat) const = 0;
   virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
       const = 0;
   virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;
@@ -208,6 +208,7 @@ xfail_not_implemented = {
     "aten::subtract_.Scalar",
     "aten::subtract_.Tensor",
     "aten::svd.U",
+    "aten::sym_is_contiguous",
     "aten::sym_size.int",
     "aten::sym_stride.int",
     "aten::sym_numel",
@@ -1958,6 +1958,8 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
             def __torch_dispatch__(cls, func, types, args, kwargs):
                 if func.overloadpacket == torch.ops.aten.is_contiguous:
                     return contiguous_data.is_contiguous()
+                if func.overloadpacket == torch.ops.aten.sym_is_contiguous:
+                    return torch.ops.aten.sym_is_contiguous(contiguous_data)
                 return NotImplemented
 
         class ExampleTensor3(torch.Tensor):
@@ -1971,6 +1973,8 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
             def __torch_dispatch__(cls, func, types, args, kwargs):
                 if func.overloadpacket == torch.ops.aten.is_contiguous:
                     return not_contiguous_data.is_contiguous()
+                if func.overloadpacket == torch.ops.aten.sym_is_contiguous:
+                    return torch.ops.aten.sym_is_contiguous(not_contiguous_data)
                 return NotImplemented
 
         err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'"
@@ -2003,6 +2007,7 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
             @classmethod
             def __torch_dispatch__(cls, func, types, args, kwargs):
                 if func in [
+                    torch.ops.aten.sym_is_contiguous.default,
                     torch.ops.aten.is_contiguous.default,
                     torch.ops.aten.is_contiguous.memory_format,
                     torch.ops.aten.is_strides_like_format.default,
@@ -97,6 +97,7 @@ _SKIP_PYTHON_BINDINGS = [
     "is_sparse_csr",
     "size",
     "stride",
+    "sym_is_contiguous",
    "sym_size",
     "sym_stride",
     "sym_storage_offset",
@@ -1560,7 +1560,6 @@ class CatchErrorsWrapper:
         frame_state: dict[str, Union[int, FrameStateSizeEntry]],
     ) -> ConvertFrameReturn:
-        assert frame_state is not None
 
         input_codes.add(frame.f_code)
 
         is_skipfile = trace_rules.check(frame.f_code)
@@ -1365,10 +1365,13 @@ def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusion
             print_output=False, include_stride=True, include_device=True
         )
 
+        name = f"optimus_{str(rule.__class__.__name__)}"
+        if "MTIA" in name:
+            name = f"cff_{str(rule.__class__.__name__)}"
         trace_structured(
             "artifact",
             metadata_fn=lambda: {
-                "name": f"optimus_{str(rule.__class__.__name__)}",
+                "name": name,
                 "encoding": "string",
             },
             payload_fn=lambda: graph_str,
@@ -265,12 +265,14 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
     from torch.fx.experimental.symbolic_shapes import (
         guard_or_false,
         guard_or_true,
-        guard_size_oblivious,
         is_nested_int,
     )
 
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    def eval_eager(x):
+        return bool(x)
+
+    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
+    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
 
     if maybe_guard_or_false(a.numel() < 2):
         return True
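To illustrate the switch above from guard_size_oblivious to plain eager evaluation, a minimal sketch (not part of the diff) of the helper's two modes on ordinary tensors:

import torch
from torch._prims_common import is_contiguous

x = torch.randn(4, 4)
# Eager path (false_if_dde=False): conditions such as a.numel() < 2 are simply bool()-ed.
print(is_contiguous(x))      # True
print(is_contiguous(x.t()))  # False
# Conservative path (false_if_dde=True): with unbacked/symbolic sizes, guard_or_false /
# guard_or_true answer without raising a data-dependent error; on concrete tensors the
# result is unchanged.
print(is_contiguous(x.t(), false_if_dde=True))  # False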
@@ -305,14 +307,13 @@ def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
     if a.ndim != 4:
         return False
 
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
 
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    def eval_eager(x):
+        return bool(x)
+
+    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
+    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
 
     expected_stride = 1
     for idx in (1, 3, 2, 0):
@@ -334,14 +335,13 @@ def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
     if a.ndim != 5:
         return False
 
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guard_or_true,
-        guard_size_oblivious,
-    )
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
 
-    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
-    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+    def eval_eager(x):
+        return bool(x)
+
+    maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager
+    maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager
 
     expected_stride = 1
     for idx in (1, 4, 3, 2, 0):
@@ -406,7 +406,7 @@ def is_channels_last_contiguous_or_false_3d(a: Tensor) -> bool:
 
 
 # similar to is_contiguous_for_memory_format but return false on data dependency.
-def contiguous_for_memory_format_or_false(  # type: ignore[return]
+def is_contiguous_for_memory_format_or_false(  # type: ignore[return]
     a: Tensor, *, memory_format: torch.memory_format
 ) -> bool:
     return is_contiguous_for_memory_format(
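A minimal usage sketch of the renamed helper (not part of the diff; memory_format is keyword-only, as in the signature above):

import torch
from torch._prims_common import is_contiguous_for_memory_format_or_false

x = torch.randn(1, 3, 4, 4)
print(is_contiguous_for_memory_format_or_false(x, memory_format=torch.contiguous_format))  # True
nhwc = x.to(memory_format=torch.channels_last)
print(is_contiguous_for_memory_format_or_false(nhwc, memory_format=torch.channels_last))   # True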
@@ -550,11 +550,14 @@ def compute_elementwise_output_logical_to_physical_perm(
     is_contiguous = True
     is_channels_last = True
     for t in tensors:
-        is_contiguous = is_contiguous and contiguous_for_memory_format_or_false(
+        is_contiguous = is_contiguous and is_contiguous_for_memory_format_or_false(
             t, memory_format=torch.contiguous_format
         )
-        is_channels_last = is_channels_last and contiguous_for_memory_format_or_false(
-            t, memory_format=torch.channels_last
+        is_channels_last = (
+            is_channels_last
+            and is_contiguous_for_memory_format_or_false(
+                t, memory_format=torch.channels_last
+            )
         )
 
     if is_contiguous and not is_channels_last:
@@ -19,7 +19,6 @@ import torch.utils._pytree as pytree
 from torch import sym_float, sym_int
 from torch._prims_common import (
     BoolLike,
-    contiguous_for_memory_format_or_false,
     DeviceLikeType,
     Dim,
     DimsSequenceType,
@@ -29,6 +28,7 @@ from torch._prims_common import (
     FloatLike,
     FloatWithoutSymFloat,
     IntLike,
+    is_contiguous_for_memory_format_or_false,
     is_contiguous_or_false,
     is_weakly_lesser_type,
     Number,
@@ -3000,7 +3000,7 @@ def contiguous(
     )
 
     # TODO: make logic consistent with aten contiguous
-    if contiguous_for_memory_format_or_false(a, memory_format=memory_format):
+    if is_contiguous_for_memory_format_or_false(a, memory_format=memory_format):
         return a
 
     return torch.clone(a, memory_format=memory_format)
@@ -15,11 +15,11 @@ import torch._prims_common as utils
 from torch._dispatch.python import no_python_dispatcher
 from torch._ops import OpOverload
 from torch._prims_common import (
-    contiguous_for_memory_format_or_false,
     elementwise_dtypes,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
     is_boolean_dtype,
     is_contiguous,
+    is_contiguous_for_memory_format_or_false,
     is_contiguous_or_false,
     is_float_dtype,
     is_integer_dtype,
@@ -1256,13 +1256,13 @@ def make_fast_binary_impl(
             continue
         definitely_contiguous = (
             definitely_contiguous
-            and contiguous_for_memory_format_or_false(
+            and is_contiguous_for_memory_format_or_false(
                 op, memory_format=torch.contiguous_format
             )
         )
         definitely_channels_last = (
             definitely_channels_last
-            and contiguous_for_memory_format_or_false(
+            and is_contiguous_for_memory_format_or_false(
                 op, memory_format=torch.channels_last
             )
         )
@@ -82,6 +82,8 @@ struct ConcretePyInterpreterVTable final
 
   bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
       const override;
+  c10::SymBool sym_is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
+      const override;
   bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
       const override;
   bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override;
@@ -476,6 +478,33 @@ bool ConcretePyInterpreterVTable::is_contiguous(
   return PyObject_IsTrue(out.ptr());
 }
 
+c10::SymBool ConcretePyInterpreterVTable::sym_is_contiguous(
+    const c10::TensorImpl* self,
+    at::MemoryFormat memory_format) const {
+  pybind11::gil_scoped_acquire gil;
+  at::impl::MaybeSetTLSOnEntryGuard guard;
+
+  py::object out;
+  out = torchDispatchFromTensorImpl(
+      self,
+      "sym_is_contiguous",
+      py::module::import("torch")
+          .attr("ops")
+          .attr("aten")
+          .attr("sym_is_contiguous")
+          .attr("default")
+          .ptr(),
+      "torch.ops.aten",
+      {py::cast(memory_format)});
+
+  if (out.is_none()) {
+    return self->sym_is_contiguous_default(memory_format);
+  }
+
+  return torch::is_symbool(out) ? out.cast<c10::SymBool>()
+                                : c10::SymBool{py::cast<bool>(out)};
+}
+
 bool ConcretePyInterpreterVTable::is_strides_like(
     const c10::TensorImpl* self,
     at::MemoryFormat memory_format) const {
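A minimal sketch (hypothetical subclass, not part of the diff) of the dispatch path wired up above, mirroring the updated subclass tests: a wrapper subclass with a custom sizes/strides policy now sees contiguity queries, including aten.sym_is_contiguous, in __torch_dispatch__:

import torch

class DeferredContiguity(torch.Tensor):
    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            dtype=data.dtype,
            device=data.device,
            dispatch_sizes_strides_policy="strides",
        )

    def __init__(self, data):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func.overloadpacket in (
            torch.ops.aten.is_contiguous,
            torch.ops.aten.sym_is_contiguous,
        ):
            # Defer the contiguity query to the wrapped dense tensor.
            return args[0]._data.is_contiguous()
        return NotImplemented

t = DeferredContiguity(torch.randn(2, 3))
print(t.is_contiguous())  # answered through __torch_dispatch__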
@@ -33,6 +33,7 @@ using c10::StorageType;
 using c10::StreamObjType;
 using c10::StringType;
 using c10::Symbol;
+using c10::SymBoolType;
 using c10::SymIntType;
 using c10::TensorType;
 using c10::TupleType;
@@ -66,6 +67,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
       {"int", c10::TypeFactory::get<IntType>()},
       {"SymInt", c10::TypeFactory::get<SymIntType>()},
       {"bool", c10::TypeFactory::get<BoolType>()},
+      {"SymBool", c10::TypeFactory::get<SymBoolType>()},
       {"None", c10::TypeFactory::get<NoneType>()},
       {"NoneType", c10::TypeFactory::get<NoneType>()},
       {"Capsule", c10::TypeFactory::get<CapsuleType>()},
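As a small cross-check (not part of the diff), the schema parser now accepts SymBool as a base type, which is what lets the new operator schema round-trip:

import torch

schema = torch._C.parse_schema(
    "aten::sym_is_contiguous(Tensor self, MemoryFormat memory_format=contiguous_format) -> SymBool"
)
print(schema.name)             # aten::sym_is_contiguous
print(schema.returns[0].type)  # SymBool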
@@ -7,7 +7,7 @@ import torch
 import torch.fx
 from torch._dispatch.python import enable_python_dispatcher
 from torch._guards import detect_fake_mode
-from torch._prims_common import contiguous_for_memory_format_or_false
+from torch._prims_common import is_contiguous_for_memory_format_or_false
 from torch._subclasses.meta_utils import is_sparse_any
 from torch.fx._compatibility import compatibility
 from torch.fx.node import map_aggregate, Node
@@ -57,7 +57,7 @@ def _extract_tensor_metadata(
         torch.channels_last_3d,
     }
     for query_format in memory_formats:
-        if contiguous_for_memory_format_or_false(
+        if is_contiguous_for_memory_format_or_false(
             result, memory_format=query_format
         ):
             memory_format = query_format
@@ -285,7 +285,9 @@ def layout(func, *args, **kwargs):
     return _get_data(args[0]).layout
 
 
-@register_dispatch_func([torch.ops.aten.is_contiguous])
+@register_dispatch_func(
+    [torch.ops.aten.is_contiguous, torch.ops.aten.sym_is_contiguous]
+)
 def is_contiguous(func, *args, **kwargs):
     data = _get_data(args[0])
     if data.is_sparse:
@@ -234,14 +234,25 @@ class NestedTensor(torch.Tensor):
         mt = self._min_seqlen_tensor
         return None if mt is None else _load_val_from_tensor(mt)
 
+    def _is_contiguous_or_false(self):
+        if self.lengths() is not None:
+            return False
+        from torch._prims_common import is_contiguous_for_memory_format_or_false
+
+        return is_contiguous_for_memory_format_or_false(
+            self._values, memory_format=torch.contiguous_format
+        )
+
     def __repr__(self):  # type: ignore[override]
         # We should implement this in torch/_tensor_str.py instead
         grad_fn_str = (
             f", requires_grad={self.requires_grad}" if self.requires_grad else ""
         )
         if self.grad_fn:
             grad_fn_str = f", grad_fn={self.grad_fn}"
-        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self.is_contiguous()})"
+        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._is_contiguous_or_false()})"
 
     # TODO: Remove this in favor of the default tensor subclass serialization logic.
     # We don't do this today because of https://github.com/pytorch/pytorch/issues/125622.
@@ -516,6 +516,29 @@ register_jagged_func(
 )(is_contiguous_general)
 
 
+@register_jagged_func(
+    torch.ops.aten.sym_is_contiguous.default, "self: jt_all, memory_format: any?"
+)
+def sym_is_contiguous_general(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+
+    # If created from narrow() check for lengths
+    if inp.lengths() is not None:
+        return False
+
+    new_kwargs["memory_format"] = new_kwargs.get(
+        "memory_format", torch.contiguous_format
+    )
+
+    if new_kwargs["memory_format"] == torch.preserve_format:
+        return True
+
+    return torch.ops.aten.sym_is_contiguous.default(inp._values, **new_kwargs)
+
+
 @register_jagged_func(
     torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
 )
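A minimal sketch (not part of the diff) exercising the handler registered above by invoking the overload directly on a jagged NestedTensor; the query is answered from the dense values buffer:

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)], layout=torch.jagged
)
print(torch.ops.aten.sym_is_contiguous(nt))  # True: the values buffer is contiguous
print(nt.is_contiguous())                    # the existing query is unchanged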
@@ -834,7 +834,8 @@ class _FlopCounterMode(TorchDispatchMode):
         kwargs = kwargs if kwargs else {}
 
         # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT
-        if func in {torch.ops.aten.is_contiguous.default,
+        if func in {torch.ops.aten.sym_is_contiguous.default,
+                    torch.ops.aten.is_contiguous.default,
                     torch.ops.aten.is_contiguous.memory_format,
                     torch.ops.aten.is_strides_like_format.default,
                     torch.ops.aten.is_non_overlapping_and_dense.default,
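A minimal sketch (not part of the diff) of the situation the skip above covers: metadata queries fired by NJT inputs, now including aten.sym_is_contiguous, must fall through to the subclass rather than be treated as compute by the mode:

import torch
from torch.utils.flop_counter import FlopCounterMode

nt = torch.nested.nested_tensor(
    [torch.randn(2, 8), torch.randn(3, 8)], layout=torch.jagged
)
a, b = torch.randn(4, 8), torch.randn(8, 4)
with FlopCounterMode(display=False) as counter:
    nt.is_contiguous()  # metadata op on an NJT; the mode lets the subclass answer
    torch.mm(a, b)      # ordinary compute, counted as usual
print(counter.get_total_flops())  # 2 * 4 * 8 * 4 = 256, from the mm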
@@ -79,6 +79,7 @@ tensorOptionsT = BaseCppType("at", "TensorOptions")
 typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
 tensorGeometryT = BaseCppType("at", "TensorGeometry")
 SymIntT = BaseCppType("c10", "SymInt")
+SymBoolT = BaseCppType("c10", "SymBool")
 symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
 
 # Types representing template parameters. Technically, we probably shouldn't
@@ -125,6 +126,7 @@ BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
     BaseTy.Storage: storageT,
     BaseTy.Stream: streamT,
     BaseTy.SymInt: SymIntT,
+    BaseTy.SymBool: SymBoolT,
 }
 
 # CTypes encode C++ type structure as needed for translation.