mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-22 06:11:27 +08:00

Compare commits (11 commits):

f2d7f235a6
402b289f3b
a32157c67c
24e7f29099
5b5d269d34
fa88f390a0
fe39c07826
cba195c8ed
16e67be7f1
7afffdf48b
ca45649eb5
@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
-yolov3,pass,2
+yolov3,pass,0

@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
-yolov3,pass,9
+yolov3,fail_accuracy,8

@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0
-pyhpc_isoneutral_mixing,fail_to_run,0
+pyhpc_isoneutral_mixing,pass,0

@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0
-yolov3,fail_to_run,0
+yolov3,pass,0

@ -338,4 +338,4 @@ vision_maskrcnn,pass,28
-yolov3,pass,2
+yolov3,pass,0

@ -338,4 +338,4 @@ vision_maskrcnn,pass,28
-yolov3,pass,2
+yolov3,pass,0

@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0
-pyhpc_isoneutral_mixing,fail_to_run,0
+pyhpc_isoneutral_mixing,pass,0

@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0
-yolov3,fail_to_run,0
+yolov3,pass,0

@ -374,4 +374,4 @@ vision_maskrcnn,pass,17
-yolov3,pass,2
+yolov3,pass,0

@ -282,4 +282,4 @@ vision_maskrcnn,pass,34
-yolov3,pass,9
+yolov3,fail_accuracy,8

@ -298,4 +298,4 @@ vision_maskrcnn,pass,28
-yolov3,pass,2
+yolov3,pass,0

@ -374,4 +374,4 @@ vision_maskrcnn,pass,17
-yolov3,pass,2
+yolov3,pass,0

@ -282,4 +282,4 @@ vision_maskrcnn,pass,34
-yolov3,pass,9
+yolov3,pass,8

@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
-yolov3,pass,2
+yolov3,pass,0

@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
-yolov3,pass,9
+yolov3,pass,8

@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
-yolov3,pass,2
+yolov3,pass,0

@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
-yolov3,pass,9
+yolov3,pass,8
@ -4,12 +4,11 @@ phlippe_densenet,float32,static,default,1.3988316
basic_gnn_gcn,float32,dynamic,default,1.074576405
llama_v2_7b_16h,float32,dynamic,default,1.211740245
resnet50,float32,dynamic,default,1.65984261
timm_efficientnet,float32,static,cpp,2.271561735
#timm_efficientnet,float32,static,cpp,2.1938112
mobilenet_v3_large,float32,static,cpp,2.63375628
timm_resnest,float32,dynamic,cpp,1.67998548
pyhpc_turbulent_kinetic_energy,float32,dynamic,cpp,1.59968463
#hf_GPT2,float32,dynamic,cpp,
hf_GPT2,float32,dynamic,cpp,1.379885175
#hf_GPT2,float32,dynamic,cpp,1.292704418
resnext50_32x4d,amp,static,default,1.461687045
vgg16,amp,static,default,1.267194285
hf_Longformer,amp,dynamic,default,0.997006035

@ -17,6 +16,6 @@ hf_Bert_large,amp,dynamic,default,0.99391146
llama,amp,static,default,1.32950568
timm_regnet,amp,static,cpp,1.157188305
lennard_jones,amp,static,cpp,2.240104485
hf_T5_generate,amp,dynamic,cpp,1.447656135
#hf_T5_generate,amp,dynamic,cpp,1.29339502
timm_vovnet,amp,dynamic,cpp,1.07856471
mobilenet_v2,amp,dynamic,cpp,2.27774577
@ -25,10 +25,6 @@ from torch._dynamo.utils import clone_inputs
 # We are primarily interested in tf32 datatype
 torch.backends.cuda.matmul.allow_tf32 = True
 
-# Enable FX graph caching
-if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
-    torch._inductor.config.fx_graph_cache = True
-
 
 def _reassign_parameters(model):
     # torch_geometric models register parameter as tensors due to
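With that opt-in removed from the harness, the FX graph cache is only enabled when the caller asks for it. A minimal sketch of the two ways to ask, using only the knobs visible in the removed lines (the script name in the comment is illustrative):

# Re-enabling the Inductor FX graph cache after this change: either export the
# environment variable the removed check consulted before launching Python, e.g.
#     TORCHINDUCTOR_FX_GRAPH_CACHE=1 python run_benchmarks.py   # script name is illustrative
# or flip the config flag in code before compiling:
import torch
import torch._inductor.config

torch._inductor.config.fx_graph_cache = True
compiled = torch.compile(lambda x: x * 2)   # toy workload just to exercise the cache
print(compiled(torch.randn(4)))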
@ -827,6 +827,7 @@ libtorch_python_core_sources = [
     "torch/csrc/dynamo/guards.cpp",
     "torch/csrc/dynamo/init.cpp",
     "torch/csrc/functorch/init.cpp",
+    "torch/csrc/fx/node.cpp",
     "torch/csrc/mps/Module.cpp",
     "torch/csrc/mtia/Module.cpp",
     "torch/csrc/inductor/aoti_runner/pybind.cpp",
@ -62,8 +62,8 @@ Overall, the ``pipelining`` package provides the following features:
 application on the Llama model.
 
-Step 1: build ``PipelineStage`` for execution
-*********************************************
+Step 1: build ``PipelineStage``
+*******************************
 
 Before we can use a ``PipelineSchedule``, we need to create ``PipelineStage``
 objects that wrap the part of the model running in that stage. The
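As a reading aid for the step described above, a minimal sketch of how a stage and a schedule fit together; the single-stage setup, the constructor arguments, and the schedule choice are illustrative assumptions, not part of this diff:

# Hedged sketch: wrap the model chunk this rank runs in a PipelineStage, then
# drive it with a pipeline schedule. Real deployments use one process per stage.
import os
import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")   # single-process setup just so
os.environ.setdefault("MASTER_PORT", "29500")       # the sketch can run standalone
dist.init_process_group("gloo", rank=0, world_size=1)

stage_module = torch.nn.Linear(16, 16)              # the slice of the model for this stage
stage = PipelineStage(stage_module, stage_index=0, num_stages=1, device=torch.device("cpu"))
schedule = ScheduleGPipe(stage, n_microbatches=4)   # splits the minibatch into 4 microbatches
out = schedule.step(torch.randn(8, 16))             # the first stage feeds the full minibatch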
@ -304,13 +304,12 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(cnt.frame_count, 0)
 
     def test_torch_guards_stack_frame_register_inlining_disable(self):
-        y = torch.nn.Parameter(torch.tensor([0.25, 0.25]))
         x = torch.tensor([0.5, 0.5])
 
         class encoder(torch.nn.Module):
             def __init__(self, y):
                 super().__init__()
-                self.register_parameter("param", y)
+                self.a = y
 
             @torch._dynamo.disable
             def helper(self, x, y):

@ -318,9 +317,9 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
 
             def forward(self, a, *args):
                 x = a + a
-                return self.helper(x, self.param)
+                return self.helper(x, self.a)
 
-        e = encoder(y)
+        e = encoder(2.0)
 
         seen_frames = []
         import contextlib
@ -56,7 +56,7 @@ class BinaryFoldingTemplate(TestCase):
                 self.use_scalar = scalar
                 tensor_size = [1 for _ in range(self.conv.weight.ndim)]
                 tensor_size[1] = self.conv.weight.size(0)
-                self.tensor = (
+                self.tensor = torch.nn.Parameter(
                     add_tensor
                     if add_tensor is not None
                     else torch.rand(tensor_size).to(device)

@ -136,7 +136,11 @@ class BinaryFoldingTemplate(TestCase):
             nn.Conv2d,
             pytorch_op,
             False,
-            add_tensor=torch.rand(32, 1, 32).to(self.device),
+            add_tensor=torch.rand(
+                32,
+                1,
+                32,
+            ).to(self.device),
             expect_success=False,
         )

@ -156,7 +160,7 @@ class BinaryFoldingTemplate(TestCase):
             nn.Conv2d,
             pytorch_op,
             False,
-            add_tensor=torch.tensor([2]).to(torch.int).to(self.device),
+            add_tensor=torch.tensor([2]).to(torch.float64).to(self.device),
             expect_success=False,
         )
@ -233,6 +233,23 @@ def run_fw_bw_and_get_code(fn):
     return run_and_get_code(run_with_backward)
 
 
+def register_ops_with_aoti_compile(ns, op_set, dispatch_key, torch_compile_op_lib_impl):
+    for _op_name in op_set:
+        qualified_op_name = f"{ns}::{_op_name}"
+        _, overload_names = torch._C._jit_get_operation(qualified_op_name)
+        for overload_name in overload_names:
+            try:
+                reg_op_name = qualified_op_name
+                schema = torch._C._get_schema(qualified_op_name, overload_name)
+                if schema.overload_name:
+                    reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
+                torch_compile_op_lib_impl._impl_with_aoti_compile(  # noqa: F821
+                    reg_op_name, dispatch_key
+                )
+            except Exception as e:
+                continue
+
+
 class TestCase(InductorTestCase):
     @classmethod
     def setUpClass(cls):
@ -751,6 +768,58 @@ class CommonTemplate:
|
||||
),
|
||||
)
|
||||
|
||||
@skipCUDAIf(not SM80OrLater, "Requires sm80")
|
||||
def test_eager_aoti_support_out(self):
|
||||
ns = "aten"
|
||||
op_name = "clamp"
|
||||
dispatch_key = "CPU"
|
||||
device = "cpu"
|
||||
if self.device.lower() == "cuda":
|
||||
dispatch_key = "CUDA"
|
||||
device = "cuda"
|
||||
|
||||
inp_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(1.0)
|
||||
min_tensor = inp_tensor - 0.05
|
||||
max_tensor = inp_tensor + 0.05
|
||||
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
|
||||
ref_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(
|
||||
-1
|
||||
)
|
||||
ref_tensor = torch.clamp(
|
||||
max=max_tensor, min=min_tensor, input=inp_tensor, out=ref_out_tensor
|
||||
)
|
||||
|
||||
ref_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_(
|
||||
-1
|
||||
)
|
||||
ref_tensor1 = torch.clamp(
|
||||
max=max_tensor, out=ref_out_tensor1, min=min_tensor, input=inp_tensor
|
||||
)
|
||||
|
||||
register_ops_with_aoti_compile(
|
||||
ns, [op_name], dispatch_key, torch_compile_op_lib_impl
|
||||
)
|
||||
|
||||
res_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(
|
||||
-1
|
||||
)
|
||||
res_tensor = torch.clamp(
|
||||
max=max_tensor, min=min_tensor, input=inp_tensor, out=res_out_tensor
|
||||
)
|
||||
|
||||
self.assertEqual(ref_tensor, res_tensor)
|
||||
self.assertEqual(ref_out_tensor, res_out_tensor)
|
||||
|
||||
res_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_(
|
||||
-1
|
||||
)
|
||||
res_tensor1 = torch.clamp(
|
||||
max=max_tensor, out=res_out_tensor1, min=min_tensor, input=inp_tensor
|
||||
)
|
||||
|
||||
self.assertEqual(ref_tensor1, res_tensor1)
|
||||
self.assertEqual(ref_out_tensor1, res_out_tensor1)
|
||||
|
||||
@skipCUDAIf(not SM80OrLater, "Requires sm80")
|
||||
def test_eager_aoti_cache_hit(self):
|
||||
ns = "aten"
|
||||
@ -779,24 +848,13 @@ class CommonTemplate:
|
||||
with mock.patch(
|
||||
"torch._inductor.utils.aoti_compile_with_persistent_cache", None
|
||||
):
|
||||
qualified_op_name = f"{ns}::{op_name}"
|
||||
_, overload_names = torch._C._jit_get_operation(qualified_op_name)
|
||||
|
||||
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
|
||||
# Get ref result from eager
|
||||
ref_value = getattr(torch.ops.aten, op_name)(input_tensor)
|
||||
|
||||
for overload_name in overload_names:
|
||||
try:
|
||||
reg_op_name = qualified_op_name
|
||||
schema = torch._C._get_schema(qualified_op_name, overload_name)
|
||||
if schema.overload_name:
|
||||
reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
|
||||
torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821
|
||||
reg_op_name, dispatch_key
|
||||
)
|
||||
except Exception as e:
|
||||
continue
|
||||
register_ops_with_aoti_compile(
|
||||
ns, [op_name], dispatch_key, torch_compile_op_lib_impl
|
||||
)
|
||||
|
||||
# Invoke the pre-compiled kernel and get result.
|
||||
res_value = getattr(torch.ops.aten, op_name)(input_tensor)
|
||||
@ -804,7 +862,7 @@ class CommonTemplate:
|
||||
self.assertEqual(ref_value, res_value)
|
||||
|
||||
@skipCUDAIf(not SM80OrLater, "Requires sm80")
|
||||
def test_aoti_compile_with_persistent_cache(self):
|
||||
def test_eager_aoti_with_persistent_cache(self):
|
||||
def fn(a):
|
||||
return torch.abs(a)
|
||||
|
||||
@ -906,19 +964,9 @@ class CommonTemplate:
|
||||
for scalar_value in scalar_values:
|
||||
ref_values.append(torch.add(a, b, alpha=scalar_value))
|
||||
|
||||
qualified_op_name = f"{namespace_name}::{op_name}"
|
||||
_, overload_names = torch._C._jit_get_operation(qualified_op_name)
|
||||
for overload_name in overload_names:
|
||||
try:
|
||||
reg_op_name = qualified_op_name
|
||||
schema = torch._C._get_schema(reg_op_name, overload_name)
|
||||
if schema.overload_name:
|
||||
reg_op_name = f"{reg_op_name}.{schema.overload_name}"
|
||||
torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821
|
||||
reg_op_name, dispatch_key
|
||||
)
|
||||
except Exception as e:
|
||||
continue
|
||||
register_ops_with_aoti_compile(
|
||||
namespace_name, [op_name], dispatch_key, torch_compile_op_lib_impl
|
||||
)
|
||||
|
||||
res_values = []
|
||||
for scalar_value in scalar_values:
|
||||
@ -928,8 +976,7 @@ class CommonTemplate:
|
||||
self.assertEqual(ref_values, res_values)
|
||||
|
||||
@skipCUDAIf(not SM80OrLater, "Requires sm80")
|
||||
def test_torch_compile_override_registration(self):
|
||||
dynamic = False
|
||||
def test_eager_aoti_override_registration(self):
|
||||
namespace_name = "aten"
|
||||
dispatch_key = "CPU"
|
||||
device = torch.device("cpu")
|
||||
@ -951,24 +998,10 @@ class CommonTemplate:
|
||||
ref = opt_fn(x)
|
||||
ref_array.append(ref)
|
||||
|
||||
def register_ops(op_set, dispatch_key, torch_compile_op_lib_impl):
|
||||
for _op_name in op_set:
|
||||
qualified_op_name = f"{namespace_name}::{_op_name}"
|
||||
_, overload_names = torch._C._jit_get_operation(qualified_op_name)
|
||||
for overload_name in overload_names:
|
||||
try:
|
||||
reg_op_name = qualified_op_name
|
||||
schema = torch._C._get_schema(qualified_op_name, overload_name)
|
||||
if schema.overload_name:
|
||||
reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
|
||||
torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821
|
||||
reg_op_name, dispatch_key
|
||||
)
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
|
||||
register_ops(unary_op_set, dispatch_key, torch_compile_op_lib_impl)
|
||||
register_ops_with_aoti_compile(
|
||||
namespace_name, unary_op_set, dispatch_key, torch_compile_op_lib_impl
|
||||
)
|
||||
|
||||
res_array = []
|
||||
for unary_op_name in unary_op_set:
|
||||
@ -985,7 +1018,9 @@ class CommonTemplate:
|
||||
ref_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)
|
||||
|
||||
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
|
||||
register_ops(["clamp"], dispatch_key, torch_compile_op_lib_impl)
|
||||
register_ops_with_aoti_compile(
|
||||
namespace_name, ["clamp"], dispatch_key, torch_compile_op_lib_impl
|
||||
)
|
||||
res_with_min = torch.ops.aten.clamp(a, min_tensor)
|
||||
res_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)
|
||||
self.assertEqual(ref_with_min, res_with_min)
|
||||
@ -7816,6 +7851,95 @@ class CommonTemplate:
|
||||
)
|
||||
assertGeneratedKernelCountEqual(self, 0)
|
||||
|
||||
def test_avg_pool3d_backward(self):
|
||||
def fn(a, b):
|
||||
return aten.avg_pool3d_backward(
|
||||
a,
|
||||
b,
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[0, 0, 0],
|
||||
True,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
self.common(
|
||||
fn,
|
||||
[
|
||||
torch.randn([2, 4, 7, 7, 7]),
|
||||
torch.randn([2, 4, 14, 14, 14]),
|
||||
],
|
||||
)
|
||||
|
||||
def test_avg_pool3d_backward2(self):
|
||||
def fn(a, b):
|
||||
return aten.avg_pool3d_backward(
|
||||
a,
|
||||
b,
|
||||
[3, 3, 3],
|
||||
[1, 1, 1],
|
||||
[1, 1, 1],
|
||||
True,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
self.common(
|
||||
fn,
|
||||
[
|
||||
torch.randn([1, 1, 20, 20, 15]),
|
||||
torch.randn([1, 1, 20, 20, 15]),
|
||||
],
|
||||
)
|
||||
|
||||
def test_avg_pool3d_backward3(self):
|
||||
def fn(a, b):
|
||||
return aten.avg_pool3d_backward(
|
||||
a,
|
||||
b,
|
||||
[1, 1, 1],
|
||||
[2, 2, 2],
|
||||
[0, 0, 0],
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
torch._inductor.metrics.generated_kernel_count = 0
|
||||
self.common(
|
||||
fn,
|
||||
[
|
||||
torch.randn([1, 2016, 11, 11, 11]),
|
||||
torch.randn([1, 2016, 21, 21, 21]),
|
||||
],
|
||||
)
|
||||
assertGeneratedKernelCountEqual(self, 1)
|
||||
|
||||
def test_avg_pool3d_backward4(self):
|
||||
def fn(a, b):
|
||||
return aten.avg_pool3d_backward(
|
||||
a,
|
||||
b,
|
||||
[13, 13, 13],
|
||||
[1, 1, 1],
|
||||
[0, 0, 0],
|
||||
True,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
torch._inductor.metrics.generated_kernel_count = 0
|
||||
self.common(
|
||||
fn,
|
||||
[
|
||||
torch.randn([1, 16, 12, 12, 12]),
|
||||
torch.randn([1, 16, 24, 24, 24]),
|
||||
],
|
||||
check_lowp=False,
|
||||
)
|
||||
assertGeneratedKernelCountEqual(self, 0)
|
||||
|
||||
@config.patch(search_autotune_cache=False)
|
||||
def test_mm_views(self):
|
||||
def fn(a, b):
|
||||
|
@ -146,6 +146,7 @@ test_failures = {
     "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda")),
     "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda")),
     "test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")),
+    "test_avg_pool3d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")),
     "test_baddbmm_dynamic_shapes": TestFailure(("cpu", "cuda")),
     "test_bmm2_dynamic_shapes": TestFailure(("cpu", "cuda")),
     "test_both_scalars_dynamic_shapes": TestFailure(("cpu", "cuda")),
@ -8,6 +8,7 @@ import os
 import sys
 import unittest
 from functools import partial
+from typing import List
 
 import torch
 import torch.library
@ -369,6 +370,47 @@ class TestInductorDynamic(TestCase):
         arg = torch.tensor(5, device=device)
         self.assertEqual(f(arg), cf(arg))
 
+    @torch._dynamo.config.patch(
+        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
+    )
+    @torch._inductor.config.patch(implicit_fallbacks=True)
+    def test_unbacked_save_for_backwards(self, device) -> None:
+        @torch.library.custom_op("_test::_cat", mutates_args=())
+        def _cat(t: torch.Tensor, ds: List[int]) -> torch.Tensor:
+            return t * t.new_ones([sum(ds)])
+
+        @torch.library.register_fake("_test::_cat")
+        def _cat_fake(t: torch.Tensor, ds: List[int]) -> torch.Tensor:
+            [torch._check_is_size(d) for d in ds]
+            return t.new_empty([sum(ds)])
+
+        def _cat_setup_context(ctx, inputs, output):
+            pass
+
+        def _cat_backward(ctx, grad):
+            return grad.sum(), None
+
+        torch.library.register_autograd(
+            "_test::_cat",
+            _cat_backward,
+            setup_context=_cat_setup_context,
+        )
+
+        def fn(t, sizes):
+            r = torch.ops._test._cat(t, sizes.tolist())
+            return r * t
+
+        t = torch.randn((), requires_grad=True, device=device)
+        sizes = torch.tensor([4, 8], dtype=torch.int64, device="cpu")
+        out = fn(t, sizes)
+        out.sum().backward()
+        expect = t.grad
+        t.grad = None
+        torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)(
+            t, sizes
+        ).sum().backward()
+        self.assertEqual(t.grad, expect)
+
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_unbacked_reduction(self, device):
         expect_fail = device == "cpu" and not IS_ARM64
@ -2333,3 +2333,14 @@ def _save_pickle(obj: Any) -> bytes: ...
 # Defined in torch/csrc/jit/runtime/static/init.cpp
 def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ...
 def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ...
+
+# Defined in torch/csrc/fx/node.cpp
+class _NodeBase:
+    _erased: _bool
+    _prev: "_NodeBase"
+    _next: "_NodeBase"
+
+class _NodeIter(Iterator):
+    def __init__(self, root: _NodeBase, reversed: _bool) -> None: ...
+    def __iter__(self) -> Iterator[_NodeBase]: ...
+    def __next__(self) -> _NodeBase: ...
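For readers of the stub above, a small pure-Python stand-in for the semantics these bindings expose; it mirrors the linked-list walk that the graph.py hunk further down removes, and is illustrative rather than the C implementation:

# Each node sits in a circular doubly linked list rooted at a sentinel node;
# iteration starts after the root, skips nodes flagged _erased, and stops when
# the walk comes back around to the root.
class _PyNodeBase:
    def __init__(self):
        self._erased = False
        self._prev = self
        self._next = self

def _py_node_iter(root, reversed=False):
    cur = root._prev if reversed else root._next
    while cur is not root:
        if not cur._erased:
            yield cur
        cur = cur._prev if reversed else cur._next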
@ -752,7 +752,13 @@ class OutputGraph:
         **options,
     ):
         if is_dynamic_nn_module(target, self.root_tx.export):
-            return variables.UnspecializedNNModuleVariable(target, **options)
+            result = variables.UnspecializedNNModuleVariable(target, **options)
+            if not SideEffects.cls_supports_mutation_side_effects(type(target)):
+                # don't allow STORE_ATTR mutation with custom __setattr__
+                return result
+            return self.root_tx.output.side_effects.track_object_existing(
+                target, result
+            )
 
         options = dict(options)
         assert "source" in options
@ -1128,6 +1128,19 @@ class VariableBuilder:
         if mutation_guard.is_dynamic_nn_module(value, self.tx.export):
             # created dynamically, don't specialize on it
             self.install_guards(GuardBuilder.TYPE_MATCH)
+            if (
+                torch._dynamo.config.inline_inbuilt_nn_modules
+                and torch._inductor.config.freezing
+                and not torch.is_grad_enabled()
+            ):
+                from ..decorators import mark_static_address
+
+                for p in value.parameters():
+                    mark_static_address(p)
+
+                for b in value.buffers():
+                    mark_static_address(b)
+
             result = UnspecializedNNModuleVariable(value, source=self.source)
             if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                 # don't allow STORE_ATTR mutation with custom __setattr__
@ -201,11 +201,19 @@ def _unlift_graph(mod, gm, graph_signature):
 
     outputs = list(gm.graph.nodes)[-1].args[0]
     mutated_outputs = []
-    for out in outputs:
-        if out.name in graph_signature.buffers_to_mutate:
-            mutated_outputs.append(graph_signature.buffers_to_mutate[out.name])
-        else:
-            mutated_outputs.append(None)
+    buffer_mutations = graph_signature.buffers_to_mutate
+    user_input_mutations = graph_signature.user_inputs_to_mutate
+    output_tokens = graph_signature.output_tokens
+    for idx, out in enumerate(outputs):
+        value = None
+
+        if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
+            if out.name in buffer_mutations:
+                value = buffer_mutations[out.name]
+            elif out.name in user_input_mutations:
+                value = user_input_mutations[out.name]
+
+        mutated_outputs.append(value)
 
     unlifted_gm = _unlift(
         gm,
@ -2155,7 +2155,6 @@ make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
 
 
 # 4) Backwards (try py_impl'ing them) when fwd is written as a decomp
-make_fallback(aten.avg_pool3d_backward)
 make_fallback(aten.max_pool3d_with_indices_backward)
 make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
 make_fallback(aten._adaptive_avg_pool3d_backward)
@ -4034,11 +4033,32 @@ def pad_adaptive_loader(x, pad_val=0.0):
|
||||
return load
|
||||
|
||||
|
||||
def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
|
||||
h_start_index_fn, w_start_index_fn = start_index_fns
|
||||
h_end_index_fn, w_end_index_fn = end_index_fns
|
||||
def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out):
|
||||
h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
|
||||
h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
|
||||
|
||||
def fn_sum(idx, loader):
|
||||
w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
|
||||
w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
|
||||
|
||||
return h_start_index, h_end_index, w_start_index, w_end_index
|
||||
|
||||
|
||||
def _adaptive_pooling_fn(
|
||||
start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
|
||||
):
|
||||
h_in, w_in = in_sizes
|
||||
h_out, w_out = out_sizes
|
||||
|
||||
(
|
||||
h_start_index_fn,
|
||||
h_end_index_fn,
|
||||
w_start_index_fn,
|
||||
w_end_index_fn,
|
||||
) = compute_indices_adaptive_pooling(
|
||||
start_index, end_index, h_in, w_in, h_out, w_out
|
||||
)
|
||||
|
||||
def fn(idx, loader):
|
||||
*prefix, bh, bw = idx
|
||||
|
||||
h_start_index = h_start_index_fn(bh)
|
||||
@ -4047,7 +4067,7 @@ def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
|
||||
w_start_index = w_start_index_fn(bw)
|
||||
w_end_index = w_end_index_fn(bw)
|
||||
|
||||
total = None
|
||||
result = None
|
||||
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
|
||||
val = loader(
|
||||
prefix,
|
||||
@ -4055,13 +4075,66 @@ def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
|
||||
[h_start_index, w_start_index],
|
||||
[h_end_index, w_end_index],
|
||||
)
|
||||
if total is None:
|
||||
total = val
|
||||
if result is None:
|
||||
result = val
|
||||
else:
|
||||
total = ops.add(val, total)
|
||||
return total
|
||||
result = pooling_fn(val, result)
|
||||
return result
|
||||
|
||||
return fn_sum
|
||||
return fn
|
||||
|
||||
|
||||
def _adaptive_pooling_fn_with_idx(
|
||||
start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
|
||||
):
|
||||
h_in, w_in = in_sizes
|
||||
h_out, w_out = out_sizes
|
||||
|
||||
(
|
||||
h_start_index_fn,
|
||||
h_end_index_fn,
|
||||
w_start_index_fn,
|
||||
w_end_index_fn,
|
||||
) = compute_indices_adaptive_pooling(
|
||||
start_index, end_index, h_in, w_in, h_out, w_out
|
||||
)
|
||||
|
||||
def fn(idx, loader):
|
||||
*prefix, bh, bw = idx
|
||||
|
||||
h_start_index = h_start_index_fn(bh)
|
||||
h_end_index = h_end_index_fn(bh)
|
||||
|
||||
w_start_index = w_start_index_fn(bw)
|
||||
w_end_index = w_end_index_fn(bw)
|
||||
|
||||
maxval = None
|
||||
maxindex = None
|
||||
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
|
||||
val = loader(
|
||||
prefix,
|
||||
[ih, iw],
|
||||
[h_start_index, w_start_index],
|
||||
[h_end_index, w_end_index],
|
||||
)
|
||||
|
||||
index = ops.index_expr(
|
||||
(h_start_index + ih) * w_in + w_start_index + iw, torch.int64
|
||||
)
|
||||
|
||||
if maxindex is None:
|
||||
maxindex = index
|
||||
else:
|
||||
maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
|
||||
|
||||
if maxval is None:
|
||||
maxval = val
|
||||
else:
|
||||
maxval = pooling_fn(val, maxval)
|
||||
|
||||
return maxindex
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
fallback_adaptive_avg_pool2d = fallback_handler(
|
||||
@ -4099,27 +4172,24 @@ def _adaptive_avg_pool2d(x, output_size):
|
||||
new_size = list(batch) + [h_out, w_out]
|
||||
dtype = x.get_dtype()
|
||||
|
||||
window_size = h_kernel_max * w_kernel_max
|
||||
if window_size > 25:
|
||||
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
|
||||
return fallback_adaptive_avg_pool2d(x, output_size)
|
||||
|
||||
def start_index(index, out_dim, inp_dim):
|
||||
return FloorDiv((index * inp_dim), out_dim)
|
||||
|
||||
def end_index(index, out_dim, inp_dim):
|
||||
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
|
||||
|
||||
h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
|
||||
h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
|
||||
|
||||
w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
|
||||
w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
|
||||
|
||||
window_size = h_kernel_max * w_kernel_max
|
||||
if window_size > 25:
|
||||
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
|
||||
return fallback_adaptive_avg_pool2d(x, output_size)
|
||||
|
||||
fn_sum = _adaptive_pooling_idx_sum(
|
||||
[h_kernel_max, w_kernel_max],
|
||||
[h_start_index, w_start_index],
|
||||
[h_end_index, w_end_index],
|
||||
fn_sum = _adaptive_pooling_fn(
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
kernel_maxes=[h_kernel_max, w_kernel_max],
|
||||
in_sizes=[h_in, w_in],
|
||||
out_sizes=[h_out, w_out],
|
||||
pooling_fn=ops.add,
|
||||
)
|
||||
|
||||
ones_loader = pad_adaptive_loader(ones_like(x))
|
||||
@ -4139,60 +4209,6 @@ def _adaptive_avg_pool2d(x, output_size):
|
||||
return rv
|
||||
|
||||
|
||||
def _adaptive_pooling_idx_max(kernel_maxes, in_sizes, out_sizes, return_index, loader):
|
||||
# NOTE: There is some duplication between this and addaptive_avg_pool2d and max_pool2d
|
||||
# Look into refactoring/deduplication after #116418 is merged.
|
||||
h_in, w_in = in_sizes
|
||||
h_out, w_out = out_sizes
|
||||
|
||||
def start_index(index, out_dim, inp_dim):
|
||||
return FloorDiv((index * inp_dim), out_dim)
|
||||
|
||||
def end_index(index, out_dim, inp_dim):
|
||||
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
|
||||
|
||||
h_start_index_fn = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
|
||||
h_end_index_fn = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
|
||||
w_start_index_fn = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
|
||||
w_end_index_fn = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
|
||||
|
||||
def fn_max(idx):
|
||||
*prefix, bh, bw = idx
|
||||
|
||||
h_start_index = h_start_index_fn(bh)
|
||||
h_end_index = h_end_index_fn(bh)
|
||||
|
||||
w_start_index = w_start_index_fn(bw)
|
||||
w_end_index = w_end_index_fn(bw)
|
||||
maxval = None
|
||||
maxindex = None
|
||||
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
|
||||
val = loader(
|
||||
prefix,
|
||||
[ih, iw],
|
||||
[h_start_index, w_start_index],
|
||||
[h_end_index, w_end_index],
|
||||
)
|
||||
index = ops.index_expr(
|
||||
(h_start_index + ih) * w_in + w_start_index + iw, torch.int64
|
||||
)
|
||||
if return_index:
|
||||
if maxindex is None:
|
||||
maxindex = index
|
||||
else:
|
||||
maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
|
||||
if maxval is None:
|
||||
maxval = val
|
||||
else:
|
||||
maxval = ops.maximum(val, maxval)
|
||||
if return_index:
|
||||
return maxindex
|
||||
else:
|
||||
return maxval
|
||||
|
||||
return fn_max
|
||||
|
||||
|
||||
fallback_adaptive_max_pool2d = fallback_handler(
|
||||
aten.adaptive_max_pool2d.default, add_to_fallback_set=False
|
||||
)
|
||||
@ -4245,32 +4261,46 @@ def adaptive_max_pool2d(x, output_size):
|
||||
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
|
||||
return fallback_adaptive_max_pool2d(x, output_size)
|
||||
|
||||
inner_func_max_val = _adaptive_pooling_idx_max(
|
||||
def start_index(index, out_dim, inp_dim):
|
||||
return FloorDiv((index * inp_dim), out_dim)
|
||||
|
||||
def end_index(index, out_dim, inp_dim):
|
||||
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
|
||||
|
||||
inner_func_max_val = _adaptive_pooling_fn(
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
kernel_maxes=[h_kernel_max, w_kernel_max],
|
||||
in_sizes=[h_in, w_in],
|
||||
out_sizes=[h_out, w_out],
|
||||
return_index=False,
|
||||
loader=pad_adaptive_loader(x, float("-inf")),
|
||||
pooling_fn=ops.maximum,
|
||||
)
|
||||
|
||||
inner_func_max_idx = _adaptive_pooling_idx_max(
|
||||
inner_func_max_idx = _adaptive_pooling_fn_with_idx(
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
kernel_maxes=[h_kernel_max, w_kernel_max],
|
||||
in_sizes=[h_in, w_in],
|
||||
out_sizes=[h_out, w_out],
|
||||
return_index=True,
|
||||
loader=pad_adaptive_loader(x, float("-inf")),
|
||||
pooling_fn=ops.maximum,
|
||||
)
|
||||
|
||||
def inner_fn_max_val(idx):
|
||||
return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf")))
|
||||
|
||||
def inner_fn_max_idx(idx):
|
||||
return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf")))
|
||||
|
||||
rv = Pointwise.create(
|
||||
device=x.get_device(),
|
||||
dtype=dtype,
|
||||
inner_fn=inner_func_max_val,
|
||||
inner_fn=inner_fn_max_val,
|
||||
ranges=new_size,
|
||||
)
|
||||
ri = Pointwise.create(
|
||||
device=x.get_device(),
|
||||
dtype=torch.int64,
|
||||
inner_fn=inner_func_max_idx,
|
||||
inner_fn=inner_fn_max_idx,
|
||||
ranges=new_size,
|
||||
)
|
||||
return rv, ri
|
||||
@ -4400,16 +4430,13 @@ def upsample_nearest2d_backward(
|
||||
def end_index(index, out_dim, inp_dim):
|
||||
return start_index((index + 1), out_dim, inp_dim)
|
||||
|
||||
h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h)
|
||||
h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h)
|
||||
|
||||
w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w)
|
||||
w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w)
|
||||
|
||||
fn_sum = _adaptive_pooling_idx_sum(
|
||||
[h_kernel_max, w_kernel_max],
|
||||
[h_start_index, w_start_index],
|
||||
[h_end_index, w_end_index],
|
||||
fn_sum = _adaptive_pooling_fn(
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
kernel_maxes=[h_kernel_max, w_kernel_max],
|
||||
in_sizes=[inp_h, inp_w],
|
||||
out_sizes=[out_h, out_w],
|
||||
pooling_fn=ops.add,
|
||||
)
|
||||
|
||||
def fn(idx):
|
||||
@ -4761,6 +4788,207 @@ def avg_pool2d_backward(
|
||||
return rv
|
||||
|
||||
|
||||
fallback_avg_pool3d_backward = fallback_handler(
|
||||
aten.avg_pool3d_backward.default, add_to_fallback_set=False
|
||||
)
|
||||
|
||||
|
||||
@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None)
|
||||
def avg_pool3d_backward(
|
||||
grad_output,
|
||||
x,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override=None,
|
||||
):
|
||||
assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
|
||||
if not stride:
|
||||
stride = kernel_size
|
||||
if not padding:
|
||||
padding = [0, 0, 0]
|
||||
|
||||
assert isinstance(grad_output, TensorBox)
|
||||
assert isinstance(x, TensorBox)
|
||||
assert len(kernel_size) == 3
|
||||
assert len(stride) == 3
|
||||
assert len(padding) == 3
|
||||
assert len(x.get_size()) in (4, 5)
|
||||
|
||||
grad_output.realize_hint()
|
||||
|
||||
*batch, depth, height, width = x.get_size()
|
||||
|
||||
d_out, ceil_mode_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode)
|
||||
h_out, ceil_mode_h = pooling_size(
|
||||
height, 1, kernel_size, stride, padding, ceil_mode
|
||||
)
|
||||
w_out, ceil_mode_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode)
|
||||
|
||||
grad_loader = grad_output.make_loader()
|
||||
had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w
|
||||
|
||||
*_, pooled_depth, pooled_height, pooled_width = grad_output.get_size()
|
||||
new_size = list(x.get_size())
|
||||
dtype = x.get_dtype()
|
||||
|
||||
d_window_size, h_window_size, w_window_size = (
|
||||
max(
|
||||
max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1)
|
||||
for d in range(kernel_size[i] * 2)
|
||||
)
|
||||
for i in range(3)
|
||||
)
|
||||
|
||||
window_size = d_window_size * h_window_size * w_window_size
|
||||
if window_size > 125:
|
||||
# Kernel size too big. Results in hard-to-optimize Triton code.
|
||||
return fallback_avg_pool3d_backward(
|
||||
grad_output,
|
||||
x,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override,
|
||||
)
|
||||
|
||||
def compute_pool_size_without_padding(pd, ph, pw):
|
||||
stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride)
|
||||
pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding)
|
||||
kernel_d, kernel_h, kernel_w = (
|
||||
ops.constant(k, torch.int32) for k in kernel_size
|
||||
)
|
||||
|
||||
dstart, hstart, wstart = (
|
||||
ops.sub(ops.mul(p, s), pad)
|
||||
for p, s, pad in zip(
|
||||
[pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w]
|
||||
)
|
||||
)
|
||||
dend, hend, wend = (
|
||||
ops.minimum(
|
||||
ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad)
|
||||
)
|
||||
for start, k, dim, pad in zip(
|
||||
[dstart, hstart, wstart],
|
||||
[kernel_d, kernel_h, kernel_w],
|
||||
[depth, height, width],
|
||||
[pad_d, pad_h, pad_w],
|
||||
)
|
||||
)
|
||||
dstart, hstart, wstart = (
|
||||
ops.maximum(start, ops.constant(0, torch.int32))
|
||||
for start in [dstart, hstart, wstart]
|
||||
)
|
||||
dend, hend, wend = (
|
||||
ops.minimum(end, ops.index_expr(dim, torch.int32))
|
||||
for end, dim in zip([dend, hend, wend], [depth, height, width])
|
||||
)
|
||||
divide_factor = ops.mul(
|
||||
ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart)
|
||||
)
|
||||
return divide_factor
|
||||
|
||||
def fn(idx):
|
||||
*prefix, d, h, w = idx
|
||||
d, h, w = (v + pad for v, pad in zip([d, h, w], padding))
|
||||
|
||||
pdstart, phstart, pwstart = (
|
||||
ops.index_expr(FloorDiv(v - k + s, s), torch.int32)
|
||||
for v, k, s in zip([d, h, w], kernel_size, stride)
|
||||
)
|
||||
|
||||
pdend, phend, pwend = (
|
||||
ops.index_expr(FloorDiv(v, s) + 1, torch.int32)
|
||||
for v, s in zip([d, h, w], stride)
|
||||
)
|
||||
|
||||
pdstart, phstart, pwstart = (
|
||||
ops.maximum(pstart, ops.constant(0, torch.int32))
|
||||
for pstart in [pdstart, phstart, pwstart]
|
||||
)
|
||||
pdend, phend, pwend = (
|
||||
ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32))
|
||||
for pend, pooled_dim in zip(
|
||||
[pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width]
|
||||
)
|
||||
)
|
||||
|
||||
gradient = None
|
||||
# Iterate over the 3D region to accumulate gradients
|
||||
for pd_ in range(d_window_size):
|
||||
for ph_ in range(h_window_size):
|
||||
for pw_ in range(w_window_size):
|
||||
pd, ph, pw = (
|
||||
ops.add(pstart, ops.constant(p_, torch.int32))
|
||||
for pstart, p_ in zip(
|
||||
[pdstart, phstart, pwstart], [pd_, ph_, pw_]
|
||||
)
|
||||
)
|
||||
|
||||
if divisor_override is not None:
|
||||
scale = divisor_override
|
||||
elif count_include_pad or not had_padding:
|
||||
scale = kernel_size[0] * kernel_size[1] * kernel_size[2]
|
||||
else:
|
||||
scale = compute_pool_size_without_padding(pd, ph, pw)
|
||||
|
||||
part = ops.truediv(
|
||||
grad_loader(
|
||||
[
|
||||
*prefix,
|
||||
ops.indirect_indexing(
|
||||
ops.minimum(
|
||||
pd, ops.sub(pdend, ops.constant(1, torch.int32))
|
||||
),
|
||||
pooled_depth,
|
||||
check=False,
|
||||
),
|
||||
ops.indirect_indexing(
|
||||
ops.minimum(
|
||||
ph, ops.sub(phend, ops.constant(1, torch.int32))
|
||||
),
|
||||
pooled_height,
|
||||
check=False,
|
||||
),
|
||||
ops.indirect_indexing(
|
||||
ops.minimum(
|
||||
pw, ops.sub(pwend, ops.constant(1, torch.int32))
|
||||
),
|
||||
pooled_width,
|
||||
check=False,
|
||||
),
|
||||
]
|
||||
),
|
||||
scale,
|
||||
)
|
||||
|
||||
mask = ops.and_(
|
||||
ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)),
|
||||
ops.lt(pw, pwend),
|
||||
)
|
||||
if gradient is None:
|
||||
gradient = ops.where(
|
||||
mask, part, ops.constant(0.0, torch.float32)
|
||||
)
|
||||
else:
|
||||
gradient = ops.where(mask, ops.add(gradient, part), gradient)
|
||||
assert gradient is not None
|
||||
return gradient
|
||||
|
||||
rv = Pointwise.create(
|
||||
device=grad_output.get_device(),
|
||||
dtype=dtype,
|
||||
inner_fn=fn,
|
||||
ranges=new_size,
|
||||
)
|
||||
return rv
|
||||
|
||||
|
||||
def _validate_reduction_axis(x, axis):
|
||||
size = x.get_size()
|
||||
if isinstance(axis, int):
|
||||
|
@ -67,6 +67,7 @@
 #include <torch/csrc/cpu/Module.h>
 #include <torch/csrc/dynamo/init.h>
 #include <torch/csrc/functorch/init.h>
+#include <torch/csrc/fx/node.h>
 #include <torch/csrc/inductor/aoti_runner/pybind.h>
 #include <torch/csrc/jit/python/init.h>
 #include <torch/csrc/jit/python/python_ir.h>

@ -1602,6 +1603,8 @@ PyObject* initModule() {
   THPDevice_init(module);
   THPStream_init(module);
   THPEvent_init(module);
+  NodeBase_init(module);
+  NodeIter_init(module);
   ASSERT_TRUE(THPVariable_initModule(module));
   ASSERT_TRUE(THPFunction_initModule(module));
   ASSERT_TRUE(THPEngine_initModule(module));
torch/csrc/fx/node.cpp (new file, 257 lines)
@ -0,0 +1,257 @@
|
||||
#include <torch/csrc/fx/node.h>
|
||||
|
||||
#include <structmember.h>
|
||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||
|
||||
////////////////////////////////
|
||||
// NodeBase
|
||||
///////////////////////////////
|
||||
|
||||
struct NodeBase {
|
||||
PyObject_HEAD bool _erased;
|
||||
NodeBase* _prev;
|
||||
NodeBase* _next;
|
||||
};
|
||||
|
||||
static PyObject* NodeBase_new(
|
||||
PyTypeObject* type,
|
||||
PyObject* args,
|
||||
PyObject* kwds) {
|
||||
PyObject* self = type->tp_alloc(type, 0);
|
||||
if (!self)
|
||||
return nullptr;
|
||||
return self;
|
||||
}
|
||||
|
||||
static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
|
||||
self->_erased = false;
|
||||
Py_INCREF(self);
|
||||
self->_prev = self;
|
||||
Py_INCREF(self);
|
||||
self->_next = self;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
static struct PyMemberDef NodeBase_members[] = {
|
||||
{"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
|
||||
{"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
|
||||
{"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
|
||||
{nullptr} /* Sentinel */
|
||||
};
|
||||
|
||||
static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
|
||||
Py_VISIT(self->_prev);
|
||||
Py_VISIT(self->_next);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int NodeBase_clear(NodeBase* self) {
|
||||
Py_CLEAR(self->_prev);
|
||||
Py_CLEAR(self->_next);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void NodeBase_dealloc(PyObject* self) {
|
||||
PyObject_GC_UnTrack(self);
|
||||
(void)NodeBase_clear((NodeBase*)self);
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
}
|
||||
|
||||
static PyTypeObject NodeBaseType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */
|
||||
sizeof(NodeBase), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
(destructor)NodeBase_dealloc, /* tp_dealloc */
|
||||
0, /* tp_vectorcall_offset */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
nullptr, /* tp_reserved */
|
||||
nullptr, /* tp_repr */
|
||||
nullptr, /* tp_as_number */
|
||||
nullptr, /* tp_as_sequence */
|
||||
nullptr, /* tp_as_mapping */
|
||||
nullptr, /* tp_hash */
|
||||
nullptr, /* tp_call */
|
||||
nullptr, /* tp_str */
|
||||
nullptr, /* tp_getattro */
|
||||
nullptr, /* tp_setattro */
|
||||
nullptr, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
|
||||
Py_TPFLAGS_HAVE_GC, /* tp_flags */
|
||||
nullptr, /* tp_doc */
|
||||
(traverseproc)NodeBase_traverse, /* tp_traverse */
|
||||
(inquiry)NodeBase_clear, /* tp_clear */
|
||||
nullptr, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
nullptr, /* tp_iternext */
|
||||
nullptr, /* tp_methods */
|
||||
NodeBase_members, /* tp_members */
|
||||
nullptr, /* tp_getset */
|
||||
nullptr, /* tp_base */
|
||||
nullptr, /* tp_dict */
|
||||
nullptr, /* tp_descr_get */
|
||||
nullptr, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
(initproc)NodeBase_init_fn, /* tp_init */
|
||||
nullptr, /* tp_alloc */
|
||||
NodeBase_new, /* tp_new */
|
||||
};
|
||||
|
||||
bool NodeBase_init(PyObject* module) {
|
||||
if (PyModule_AddType(module, &NodeBaseType) < 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
////////////////////////////////
|
||||
// NodeIter
|
||||
////////////////////////////////
|
||||
|
||||
struct NodeIter {
|
||||
PyObject_HEAD bool _reversed;
|
||||
NodeBase* _root;
|
||||
NodeBase* _cur;
|
||||
};
|
||||
|
||||
static PyObject* NodeIter_new(
|
||||
PyTypeObject* type,
|
||||
PyObject* args,
|
||||
PyObject* kwds) {
|
||||
PyObject* self = type->tp_alloc(type, 0);
|
||||
if (!self)
|
||||
return nullptr;
|
||||
return self;
|
||||
}
|
||||
|
||||
static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) {
|
||||
NodeBase* root = nullptr;
|
||||
bool reversed = false;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
constexpr const char* keywords[] = {"root", "reversed", nullptr};
|
||||
if (!PyArg_ParseTupleAndKeywords(
|
||||
args,
|
||||
kwargs,
|
||||
"Ob|",
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<char**>(keywords),
|
||||
&root,
|
||||
&reversed)) {
|
||||
return -1;
|
||||
}
|
||||
self->_reversed = reversed;
|
||||
Py_INCREF(root);
|
||||
self->_root = root;
|
||||
Py_INCREF(root);
|
||||
self->_cur = root;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <bool reversed>
|
||||
PyObject* NodeIter_iternext_helper(NodeIter* self) {
|
||||
// It should be possible to relax the ref counting here
|
||||
// but in practice, we do not have that many _erased Nodes,
|
||||
// so probably not worth it.
|
||||
if constexpr (reversed) {
|
||||
NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
|
||||
Py_CLEAR(self->_cur);
|
||||
self->_cur = prev;
|
||||
} else {
|
||||
NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
|
||||
Py_CLEAR(self->_cur);
|
||||
self->_cur = next;
|
||||
}
|
||||
while (self->_cur != self->_root) {
|
||||
if (!self->_cur->_erased) {
|
||||
Py_INCREF(self->_cur);
|
||||
return (PyObject*)self->_cur;
|
||||
}
|
||||
if constexpr (reversed) {
|
||||
NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
|
||||
Py_CLEAR(self->_cur);
|
||||
self->_cur = prev;
|
||||
} else {
|
||||
NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
|
||||
Py_CLEAR(self->_cur);
|
||||
self->_cur = next;
|
||||
}
|
||||
}
|
||||
PyErr_SetNone(PyExc_StopIteration);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyObject* NodeIter_iternext(PyObject* _self) {
|
||||
NodeIter* self = (NodeIter*)_self;
|
||||
if (self->_reversed) {
|
||||
return NodeIter_iternext_helper<true>(self);
|
||||
} else {
|
||||
return NodeIter_iternext_helper<false>(self);
|
||||
}
|
||||
}
|
||||
|
||||
static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) {
|
||||
Py_VISIT(self->_root);
|
||||
Py_VISIT(self->_cur);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int NodeIter_clear(NodeIter* self) {
|
||||
Py_CLEAR(self->_root);
|
||||
Py_CLEAR(self->_cur);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void NodeIter_dealloc(PyObject* self) {
|
||||
PyObject_GC_UnTrack(self);
|
||||
(void)NodeIter_clear((NodeIter*)self);
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
}
|
||||
|
||||
static PyTypeObject NodeIterType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeIter", /* tp_name */
|
||||
sizeof(NodeIter), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
(destructor)NodeIter_dealloc, /* tp_dealloc */
|
||||
0, /* tp_vectorcall_offset */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
nullptr, /* tp_reserved */
|
||||
nullptr, /* tp_repr */
|
||||
nullptr, /* tp_as_number */
|
||||
nullptr, /* tp_as_sequence */
|
||||
nullptr, /* tp_as_mapping */
|
||||
nullptr, /* tp_hash */
|
||||
nullptr, /* tp_call */
|
||||
nullptr, /* tp_str */
|
||||
nullptr, /* tp_getattro */
|
||||
nullptr, /* tp_setattro */
|
||||
nullptr, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
|
||||
nullptr, /* tp_doc */
|
||||
(traverseproc)NodeIter_traverse, /* tp_traverse */
|
||||
(inquiry)NodeIter_clear, /* tp_clear */
|
||||
nullptr, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
PyObject_SelfIter, /* tp_iter */
|
||||
NodeIter_iternext, /* tp_iternext */
|
||||
nullptr, /* tp_methods */
|
||||
nullptr, /* tp_members */
|
||||
nullptr, /* tp_getset */
|
||||
nullptr, /* tp_base */
|
||||
nullptr, /* tp_dict */
|
||||
nullptr, /* tp_descr_get */
|
||||
nullptr, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
(initproc)NodeIter_init_fn, /* tp_init */
|
||||
nullptr, /* tp_alloc */
|
||||
NodeIter_new, /* tp_new */
|
||||
};
|
||||
|
||||
bool NodeIter_init(PyObject* module) {
|
||||
if (PyModule_AddType(module, &NodeIterType) < 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
torch/csrc/fx/node.h (new file, 6 lines)
@ -0,0 +1,6 @@
+#pragma once
+
+#include <torch/csrc/python_headers.h>
+
+bool NodeBase_init(PyObject* module);
+bool NodeIter_init(PyObject* module);
@ -86,6 +86,7 @@ def _insert_copy_for_mutations(
     assert len(outputs) == len(mutated_outputs)
 
     user_output_nodes = []
+    return_nodes_to_copy = {}
     for return_node, mutated_node_name in zip(outputs, mutated_outputs):
         if mutated_node_name is None:
             user_output_nodes.append(return_node)

@ -101,13 +102,18 @@ def _insert_copy_for_mutations(
         )
 
         with gm.graph.inserting_before(output_node):
-            _ = gm.graph.call_function(
+            copy_node = gm.graph.call_function(
                 torch.ops.aten.copy_.default, (mutated_node, return_node)
             )
+            return_nodes_to_copy[return_node] = copy_node
+
+    output_args = [
+        return_nodes_to_copy[node] if node in return_nodes_to_copy else node
+        for node in user_output_nodes
+    ]
     with gm.graph.inserting_before(output_node):
         # Only return user outputs
-        new_output = gm.graph.output(tuple(user_output_nodes))
+        new_output = gm.graph.output(tuple(output_args))
         output_node.replace_all_uses_with(new_output)
         gm.graph.erase_node(output_node)
@ -453,20 +453,21 @@ def free_unbacked_symbols(x):
 # setup!
 def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]:
     if (
-        node.op == "placeholder" and
         "val" in node.meta and
         isinstance(node.meta["val"], torch.SymInt) and
-        isinstance(node.meta["val"].node.expr, sympy.Symbol)
+        isinstance(node.meta["val"].node.expr, sympy.Symbol) and
+        (node.op == "placeholder" or free_unbacked_symbols(node.meta["val"].node.expr))
     ):
         return node.meta["val"].node.expr
     return None
 
 def find_symbol_binding_fx_nodes(graph):
-    return {
-        node.meta["val"].node.expr: node
-        for node in graph.nodes
-        if is_symbol_binding_fx_node(node)
-    }
+    r = {}
+    # NB: Prefer first occurrence of symbol
+    for node in graph.nodes:
+        if is_symbol_binding_fx_node(node) and node.meta["val"].node.expr not in r:
+            r[node.meta["val"].node.expr] = node
+    return r
 
 
 # Analogous to ConvertIntSource
@ -4,6 +4,7 @@ from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
 import torch.utils._pytree as pytree
 from . import _pytree as fx_pytree
 from ._compatibility import compatibility
+from torch._C import _NodeIter
 
 import os
 import contextlib

@ -271,20 +272,8 @@ class _node_list:
         return self.graph._len
 
     def __iter__(self):
-        root = self.graph._root
-        if self.direction == "_next":
-            cur = root._next
-            while cur is not root:
-                if not cur._erased:
-                    yield cur
-                cur = cur._next
-        else:
-            assert self.direction == "_prev"
-            cur = root._prev
-            while cur is not root:
-                if not cur._erased:
-                    yield cur
-                cur = cur._prev
+        assert self.direction == "_prev" or self.direction == "_next"
+        yield from _NodeIter(self.graph._root, self.direction == "_prev")
 
     def __reversed__(self):
         return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
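A short usage note (illustrative, not from the diff): callers see the same iteration order as before; after this change, forward and reverse traversal of ``graph.nodes`` both bottom out in the C iterator.

import torch
from torch import fx

def f(x):
    return x.relu() + 1

gm = fx.symbolic_trace(f)
forward = [n.name for n in gm.graph.nodes]              # delegates to _NodeIter(root, False)
backward = [n.name for n in reversed(gm.graph.nodes)]   # delegates to _NodeIter(root, True)
assert forward == list(reversed(backward))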
@ -11,6 +11,7 @@ import inspect
 import warnings
 from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
 from .._ops import ops as _ops
+from torch._C import _NodeBase
 
 if TYPE_CHECKING:
     from .graph import Graph

@ -139,7 +140,7 @@ def _format_arg(arg, max_list_len=float('inf')) -> str:
     return str(arg)
 
 @compatibility(is_backward_compatible=True)
-class Node:
+class Node(_NodeBase):
     """
     ``Node`` is the data structure that represents individual operations within
     a ``Graph``. For the most part, Nodes represent callsites to various entities,

@ -197,6 +198,7 @@ class Node:
         annotation of values in the generated code or for other types
         of analyses.
         """
+        super().__init__()
         self.graph = graph
         self.name = name  # unique name of value being created
         assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']

@ -235,9 +237,6 @@ class Node:
         # does not produce a value, it's more of a notation. Thus, this value
         # describes the type of args[0] in the ``return`` node.
         self.type : Optional[Any] = return_type
-        self._prev = self
-        self._next = self
-        self._erased = False
         self._sort_key: Any = ()
 
         # If set, use this fn to print this node

@ -247,6 +246,22 @@ class Node:
         # transformations. This metadata is preserved across node copies
         self.meta : Dict[str, Any] = {}
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["_erased"] = self._erased
+        state["_prev"] = self._prev
+        state["_next"] = self._next
+        return state
+
+    def __setstate__(self, state):
+        _erased = state.pop("_erased")
+        _prev = state.pop("_prev")
+        _next = state.pop("_next")
+        self.__dict__.update(state)
+        self._erased = _erased
+        self._prev = _prev
+        self._next = _next
+
     @property
     def next(self) -> 'Node':
         """