Refactored proxytensor to clean up separate branches (#84325)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84325
Approved by: https://github.com/ezyang
Author: Horace He
Date: 2022-08-31 07:01:37 +00:00
Committed by: PyTorch MergeBot
Parent: 8843f5b986
Commit: a27a4a02fe
6 changed files with 463 additions and 464 deletions
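
The test-expectation changes in the diffs below all reflect one user-visible effect of the ProxyTensor refactor: FX graphs produced through make_fx (the proxy-tensor tracer) now name nodes after an op's base name (view_copy, add, aminmax) instead of appending the overload name (view_copy_default, add_tensor, aminmax_default). The following is a minimal sketch of how such a trace is produced; it assumes the present-day torch.func.functionalize and proxy_tensor make_fx entry points rather than the functorch-era wrappers these tests call, and the function f is illustrative rather than taken from the diff.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize

def f(x):
    # Illustrative view + in-place add pattern, modeled on the tests in this diff.
    tmp = torch.ones(2)
    y = x.view(4, 2)
    y.add_(tmp)
    return y

# Functionalization rewrites the mutation into view_copy / add / copy_ calls,
# and after this refactor the FX nodes carry the short op names (view_copy,
# add, copy_) rather than the op-plus-overload form (view_copy_default,
# add_tensor, copy__default) seen in the removed expectations.
gm = make_fx(functionalize(f, remove="mutations_and_views"))(torch.ones(4, 2))
print(gm.code)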

View File

@ -3177,8 +3177,8 @@ class TestFunctionalize(TestCase):
def forward(self, x_1, indices_1) -> torch.Tensor:
index_tensor = torch.ops.aten.index.Tensor(x_1, [indices_1]); x_1 = indices_1 = None
return index_tensor
index = torch.ops.aten.index.Tensor(x_1, [indices_1]); x_1 = indices_1 = None
return index
""")
# Ensure grad(functionalize(f)) works
@ -3247,11 +3247,11 @@ def forward(self, x_1, indices_1) -> torch.Tensor:
def forward(self, x_1) -> torch.Tensor:
ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False)
view_copy_default = torch.ops.aten.view_copy.default(x_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, ones); view_copy_default = ones = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
copy__default = torch.ops.aten.copy_.default(x_1, view_copy_default_1); x_1 = None
return view_copy_default_1
view_copy = torch.ops.aten.view_copy.default(x_1, [4, 2])
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
copy_ = torch.ops.aten.copy_.default(x_1, view_copy_1); x_1 = None
return view_copy_1
""")
def test_functionalize_fx_transpose_simple(self, device):
@ -3266,8 +3266,8 @@ def forward(self, x_1) -> torch.Tensor:
def forward(self, x_1) -> torch.Tensor:
transpose_copy_int = torch.ops.aten.transpose_copy.int(x_1, 1, 0); x_1 = None
return transpose_copy_int
transpose_copy = torch.ops.aten.transpose_copy.int(x_1, 1, 0); x_1 = None
return transpose_copy
""")
def test_functionalize_fx_out_op(self, device):
@ -3288,12 +3288,12 @@ def forward(self, x_1) -> torch.Tensor:
def forward(self, inpt_1) -> torch.Tensor:
empty = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = 'cpu', pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(inpt_1, inpt_1); inpt_1 = None
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [4])
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4]); add_tensor = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1); view_copy_default_1 = None
view_copy_default_2 = torch.ops.aten.view_copy.default(add_tensor_1, [4]); add_tensor_1 = None
return view_copy_default_2
add = torch.ops.aten.add.Tensor(inpt_1, inpt_1); inpt_1 = None
view_copy = torch.ops.aten.view_copy.default(add, [4])
view_copy_1 = torch.ops.aten.view_copy.default(add, [4]); add = None
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None
view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4]); add_1 = None
return view_copy_2
""")
def test_functionalize_fx_multi_out_op(self, device):
@ -3316,13 +3316,13 @@ def forward(self, inpt_1) -> torch.Tensor:
def forward(self, inpt_1) -> torch.Tensor:
empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = 'cpu', pin_memory = False)
empty_1 = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = 'cpu', pin_memory = False)
view_copy_default = torch.ops.aten.view_copy.default(empty_1, [4]); empty_1 = None
view_copy_default_1 = torch.ops.aten.view_copy.default(inpt_1, [2, 4]); inpt_1 = None
aminmax_default = torch.ops.aten.aminmax.default(view_copy_default_1, dim = 0); view_copy_default_1 = None
getitem = aminmax_default[0]
getitem_1 = aminmax_default[1]; aminmax_default = None
view_copy_default_2 = torch.ops.aten.view_copy.default(getitem_1, [2, 2]); getitem_1 = None
return (view_copy_default_2, getitem)
view_copy = torch.ops.aten.view_copy.default(empty_1, [4]); empty_1 = None
view_copy_1 = torch.ops.aten.view_copy.default(inpt_1, [2, 4]); inpt_1 = None
aminmax = torch.ops.aten.aminmax.default(view_copy_1, dim = 0); view_copy_1 = None
getitem = aminmax[0]
getitem_1 = aminmax[1]; aminmax = None
view_copy_2 = torch.ops.aten.view_copy.default(getitem_1, [2, 2]); getitem_1 = None
return (view_copy_2, getitem)
""")
def test_functionalize_fx_reapply_views_simple(self, device):
@ -3341,11 +3341,11 @@ def forward(self, inpt_1) -> torch.Tensor:
def forward(self, x_1) -> torch.Tensor:
ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False)
view_default = torch.ops.aten.view.default(x_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2]); add_tensor = None
copy__default = torch.ops.aten.copy_.default(x_1, view_default_1); x_1 = None
return view_default_1
view = torch.ops.aten.view.default(x_1, [4, 2])
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = None
return view_1
""")
def test_functionalize_nonfunctional_output(self, device):
@ -3382,8 +3382,8 @@ def forward(self) -> torch.Tensor:
def forward(self, a_1, b_1) -> torch.Tensor:
index_tensor = torch.ops.aten.index.Tensor(a_1, [b_1]); a_1 = b_1 = None
return index_tensor
index = torch.ops.aten.index.Tensor(a_1, [b_1]); a_1 = b_1 = None
return index
""")
def test_functionalize_optional_tensorlist2(self, device):
@ -3400,11 +3400,11 @@ def forward(self, a_1, b_1) -> torch.Tensor:
def forward(self, a_1, b_1) -> torch.Tensor:
unbind_int = torch.ops.aten.unbind.int(b_1); b_1 = None
getitem = unbind_int[0]
getitem_1 = unbind_int[1]; unbind_int = None
index_tensor = torch.ops.aten.index.Tensor(a_1, [getitem, getitem_1]); a_1 = getitem = getitem_1 = None
return index_tensor
unbind = torch.ops.aten.unbind.int(b_1); b_1 = None
getitem = unbind[0]
getitem_1 = unbind[1]; unbind = None
index = torch.ops.aten.index.Tensor(a_1, [getitem, getitem_1]); a_1 = getitem = getitem_1 = None
return index
""")
def test_resize_program_inputs(self, device):
@ -3420,10 +3420,10 @@ def forward(self, a_1, b_1) -> torch.Tensor:
def forward(self, x_1):
resize_default = torch.ops.aten.resize.default(x_1, [10])
fill_scalar = torch.ops.aten.fill.Scalar(resize_default, 2); resize_default = None
resize__default = torch.ops.aten.resize_.default(x_1, [10]); x_1 = None
copy__default = torch.ops.aten.copy_.default(resize__default, fill_scalar); resize__default = fill_scalar = None
resize = torch.ops.aten.resize.default(x_1, [10])
fill = torch.ops.aten.fill.Scalar(resize, 2); resize = None
resize_ = torch.ops.aten.resize_.default(x_1, [10]); x_1 = None
copy_ = torch.ops.aten.copy_.default(resize_, fill); resize_ = fill = None
return None
""")

View File

@ -117,12 +117,12 @@ class TestFunctionalization(TestCase):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, ones); view_copy_default = ones = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_copy_default_1, view_copy_default_1)
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
return add_tensor
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1)
copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None
return add
""")
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
@ -132,12 +132,12 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_default = torch.ops.aten.view.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1)
copy__default = torch.ops.aten.copy_.default(a_1, view_default_1); a_1 = view_default_1 = None
return add_tensor
view = torch.ops.aten.view.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
view_1 = torch.ops.aten.view.default(add, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
copy_ = torch.ops.aten.copy_.default(a_1, view_1); a_1 = view_1 = None
return add
""")
def test_simple_out(self):
@ -157,11 +157,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, ones); view_copy_default = ones = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
return mul_tensor
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
mul = torch.ops.aten.mul.Tensor(add, add); add = None
return mul
""")
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
@ -171,11 +171,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_default = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
view = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
return mul_tensor
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
mul = torch.ops.aten.mul.Tensor(add, add); add = None
return mul
""")
def test_multi_out(self):
@ -195,9 +195,9 @@ def forward(self, a_1):
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax_default[0]
getitem_1 = aminmax_default[1]; aminmax_default = None
aminmax = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax[0]
getitem_1 = aminmax[1]; aminmax = None
return getitem
""")
@ -209,9 +209,9 @@ def forward(self, a_1):
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax_default[0]
getitem_1 = aminmax_default[1]; aminmax_default = None
aminmax = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax[0]
getitem_1 = aminmax[1]; aminmax = None
return getitem
""")
@ -232,11 +232,11 @@ def forward(self, a_1):
def forward(self, a_1):
_tensor_constant0 = self._tensor_constant0
lift_fresh = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view_copy_default = torch.ops.aten.view_copy.default(lift_fresh, [-1]); lift_fresh = None
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [3]); add_tensor = None
return view_copy_default_1
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [3]); add = None
return view_copy_1
""")
reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True)
@ -246,11 +246,11 @@ def forward(self, a_1):
def forward(self, a_1):
_tensor_constant0 = self._tensor_constant0
lift_fresh = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view_default = torch.ops.aten.view.default(lift_fresh, [-1]); lift_fresh = None
add_tensor = torch.ops.aten.add_.Tensor(view_default, 1)
view_default_1 = torch.ops.aten.view.default(view_default, [3]); view_default = None
return view_default_1
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view = torch.ops.aten.view.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
add = torch.ops.aten.add_.Tensor(view, 1)
view_1 = torch.ops.aten.view.default(view, [3]); view = None
return view_1
""")
@ -282,11 +282,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(a_1, ones); ones = None
copy__default = torch.ops.aten.copy_.default(a_1, add_tensor); a_1 = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
return view_copy_default_1
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(a_1, ones); ones = None
copy_ = torch.ops.aten.copy_.default(a_1, add); a_1 = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
return view_copy_1
""")
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
@ -296,11 +296,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_default = torch.ops.aten.view.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(a_1, ones); ones = None
copy__default = torch.ops.aten.copy_.default(a_1, add_tensor); a_1 = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2]); add_tensor = None
return view_default_1
view = torch.ops.aten.view.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(a_1, ones); ones = None
copy_ = torch.ops.aten.copy_.default(a_1, add); a_1 = None
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
return view_1
""")
# Some ops that are mutable are neither inplace nor out= ops.
@ -315,14 +315,14 @@ def forward(self, a_1):
def forward(self, a_1):
_fused_moving_avg_obs_fq_helper_functional_default = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0)
getitem = _fused_moving_avg_obs_fq_helper_functional_default[0]
getitem_1 = _fused_moving_avg_obs_fq_helper_functional_default[1]
getitem_2 = _fused_moving_avg_obs_fq_helper_functional_default[2]
getitem_3 = _fused_moving_avg_obs_fq_helper_functional_default[3]
getitem_4 = _fused_moving_avg_obs_fq_helper_functional_default[4]
getitem_5 = _fused_moving_avg_obs_fq_helper_functional_default[5]; _fused_moving_avg_obs_fq_helper_functional_default = None
copy__default = torch.ops.aten.copy_.default(a_1, getitem_5); a_1 = getitem_5 = None
_fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0)
getitem = _fused_moving_avg_obs_fq_helper_functional[0]
getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1]
getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2]
getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3]
getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4]
getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None
copy_ = torch.ops.aten.copy_.default(a_1, getitem_5); a_1 = getitem_5 = None
return (getitem, getitem_1)
""") # noqa: B950
@ -338,11 +338,11 @@ def forward(self, a_1):
def forward(self, a_1):
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1)
add_tensor = torch.ops.aten.add.Tensor(as_strided_copy_default, 1); as_strided_copy_default = None
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(a_1, add_tensor, [2], [2], 1); add_tensor = None
copy__default = torch.ops.aten.copy_.default(a_1, as_strided_scatter_default); a_1 = None
return as_strided_scatter_default
as_strided_copy = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1)
add = torch.ops.aten.add.Tensor(as_strided_copy, 1); as_strided_copy = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(a_1, add, [2], [2], 1); add = None
copy_ = torch.ops.aten.copy_.default(a_1, as_strided_scatter); a_1 = None
return as_strided_scatter
""")
def test_tensor_list_composite(self):
@ -357,8 +357,8 @@ def forward(self, a_1):
def forward(self, a_1):
block_diag_default = torch.ops.aten.block_diag.default([a_1, a_1]); a_1 = None
return block_diag_default
block_diag = torch.ops.aten.block_diag.default([a_1, a_1]); a_1 = None
return block_diag
""")
def test_cat(self):
@ -374,8 +374,8 @@ def forward(self, a_1):
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat_default
cat = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat
""")
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
@ -385,8 +385,8 @@ def forward(self, a_1):
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat_default
cat = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat
""")
@ -406,11 +406,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
clone_default = torch.ops.aten.clone.default(a_1)
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(clone_default); clone_default = None
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, ones); diagonal_copy_default = ones = None
mul_tensor = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
return mul_tensor
clone = torch.ops.aten.clone.default(a_1)
diagonal_copy = torch.ops.aten.diagonal_copy.default(clone); clone = None
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
mul = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
return mul
""")
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
@ -420,11 +420,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
clone_default = torch.ops.aten.clone.default(a_1)
diagonal_default = torch.ops.aten.diagonal.default(clone_default); clone_default = None
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, ones); diagonal_default = ones = None
mul_tensor = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
return mul_tensor
clone = torch.ops.aten.clone.default(a_1)
diagonal = torch.ops.aten.diagonal.default(clone); clone = None
add = torch.ops.aten.add_.Tensor(diagonal, ones); diagonal = ones = None
mul = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
return mul
""")
def test_diagonal_mutated_input(self):
@ -443,11 +443,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(a_1)
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, ones); diagonal_copy_default = ones = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(a_1, add_tensor); add_tensor = None
copy__default = torch.ops.aten.copy_.default(a_1, diagonal_scatter_default); a_1 = None
return diagonal_scatter_default
diagonal_copy = torch.ops.aten.diagonal_copy.default(a_1)
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(a_1, add); add = None
copy_ = torch.ops.aten.copy_.default(a_1, diagonal_scatter); a_1 = None
return diagonal_scatter
""")
def test_split(self):
@ -467,19 +467,19 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
split_copy_tensor = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem = split_copy_tensor[0]
getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, ones); diagonal_copy_default = ones = None
split_copy_tensor_1 = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem_2 = split_copy_tensor_1[0]
getitem_3 = split_copy_tensor_1[1]; split_copy_tensor_1 = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(getitem_3, add_tensor); getitem_3 = None
slice_scatter_default = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter_default, 0, 2, 4); diagonal_scatter_default = None
mul_tensor = torch.ops.aten.mul.Tensor(slice_scatter_default, slice_scatter_default)
copy__default = torch.ops.aten.copy_.default(a_1, slice_scatter_default); a_1 = slice_scatter_default = None
return add_tensor
split_copy = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem = split_copy[0]
getitem_1 = split_copy[1]; split_copy = None
diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
split_copy_1 = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem_2 = split_copy_1[0]
getitem_3 = split_copy_1[1]; split_copy_1 = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = None
slice_scatter = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
copy_ = torch.ops.aten.copy_.default(a_1, slice_scatter); a_1 = slice_scatter = None
return add
""") # noqa: B950
def test_view_inplace(self):
@ -498,14 +498,14 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
transpose_copy_int = torch.ops.aten.transpose_copy.int(a_1, 1, 0)
select_copy_int = torch.ops.aten.select_copy.int(transpose_copy_int, 0, 0); transpose_copy_int = None
add_tensor = torch.ops.aten.add.Tensor(select_copy_int, ones); select_copy_int = ones = None
transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(a_1, 1, 0); a_1 = None
select_scatter_default = torch.ops.aten.select_scatter.default(transpose_copy_int_1, add_tensor, 0, 0); transpose_copy_int_1 = add_tensor = None
transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(select_scatter_default, 1, 0); select_scatter_default = None
transpose_copy_int_3 = torch.ops.aten.transpose_copy.int(transpose_copy_int_2, 1, 0); transpose_copy_int_2 = None
return transpose_copy_int_3
transpose_copy = torch.ops.aten.transpose_copy.int(a_1, 1, 0)
select_copy = torch.ops.aten.select_copy.int(transpose_copy, 0, 0); transpose_copy = None
add = torch.ops.aten.add.Tensor(select_copy, ones); select_copy = ones = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(a_1, 1, 0); a_1 = None
select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
return transpose_copy_3
""") # noqa: B950
def test_optional_tensor_list(self):
@ -524,13 +524,13 @@ def forward(self, a_1):
def forward(self, a_1):
view_copy_default = torch.ops.aten.view_copy.default(a_1, [8])
view_copy = torch.ops.aten.view_copy.default(a_1, [8])
arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False)
arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
index_put_default = torch.ops.aten.index_put.default(view_copy_default, [arange], arange_1); view_copy_default = arange = arange_1 = None
view_copy_default_1 = torch.ops.aten.view_copy.default(index_put_default, [4, 2])
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
return index_put_default
index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1); view_copy = arange = arange_1 = None
view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2])
copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None
return index_put
""") # noqa: B950
def test_scalars(self):
@ -550,13 +550,13 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, 2)
div_tensor = torch.ops.aten.div.Tensor(mul_tensor, 1); mul_tensor = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
return div_tensor
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
mul = torch.ops.aten.mul.Tensor(add, 2)
div = torch.ops.aten.div.Tensor(mul, 1); mul = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None
return div
""")
@skipIfTorchDynamo("Test does not work with TorchDynamo")
@ -574,10 +574,10 @@ def forward(self, a_1):
def forward(self, a_1):
clone_default = torch.ops.aten.clone.default(a_1); a_1 = None
ge_scalar = torch.ops.aten.ge.Scalar(clone_default, 0); clone_default = None
_to_copy_default = torch.ops.aten._to_copy.default(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
return _to_copy_default
clone = torch.ops.aten.clone.default(a_1); a_1 = None
ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None
_to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None
return _to_copy
""")
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
@ -586,10 +586,10 @@ def forward(self, a_1):
def forward(self, a_1):
clone_default = torch.ops.aten.clone.default(a_1); a_1 = None
ge_scalar = torch.ops.aten.ge.Scalar(clone_default, 0); clone_default = None
_to_copy_default = torch.ops.aten._to_copy.default(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
return _to_copy_default
clone = torch.ops.aten.clone.default(a_1); a_1 = None
ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None
_to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None
return _to_copy
""") # noqa: B950
@skipIfTorchDynamo("Test does not work with TorchDynamo")
@ -622,8 +622,8 @@ def forward(self, a_1):
def forward(self, a_1):
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
return view_copy_default
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
return view_copy
""")
def test_everything(self):
@ -648,35 +648,35 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [8])
_reshape_alias_copy_default = torch.ops.aten._reshape_alias_copy.default(view_copy_default, [2, 4], [4, 1]); view_copy_default = None
transpose_copy_int = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default, 1, 0)
unsqueeze_copy_default = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int, 0); transpose_copy_int = None
squeeze_copy_default = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default); unsqueeze_copy_default = None
split_copy_tensor = torch.ops.aten.split_copy.Tensor(squeeze_copy_default, 2); squeeze_copy_default = None
getitem = split_copy_tensor[0]
getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
add_tensor_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
select_copy_int = torch.ops.aten.select_copy.int(_reshape_alias_copy_default, 0, 0); _reshape_alias_copy_default = None
clone_default = torch.ops.aten.clone.default(add_tensor_1, memory_format = torch.contiguous_format)
_unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [8]); add_tensor = None
_reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_1, [2, 4], [4, 1]); view_copy_default_1 = None
transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default_1, 1, 0); _reshape_alias_copy_default_1 = None
unsqueeze_copy_default_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int_1, 0); transpose_copy_int_1 = None
squeeze_copy_default_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default_1); unsqueeze_copy_default_1 = None
slice_scatter_default = torch.ops.aten.slice_scatter.default(squeeze_copy_default_1, add_tensor_1, 0, 0, 2); squeeze_copy_default_1 = None
unsqueeze_copy_default_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter_default, 0); slice_scatter_default = None
squeeze_copy_dim = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_default_2, 0); unsqueeze_copy_default_2 = None
transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_dim, 1, 0); squeeze_copy_dim = None
_reshape_alias_copy_default_2 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_int_2, [8], [1]); transpose_copy_int_2 = None
view_copy_default_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_default_2, [4, 2]); _reshape_alias_copy_default_2 = None
view_copy_default_3 = torch.ops.aten.view_copy.default(view_copy_default_2, [8]); view_copy_default_2 = None
_reshape_alias_copy_default_3 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_3, [2, 4], [4, 1]); view_copy_default_3 = None
select_copy_int_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_default_3, 0, 0); _reshape_alias_copy_default_3 = None
add_tensor_2 = torch.ops.aten.add.Tensor(select_copy_int_1, _unsafe_view_default); select_copy_int_1 = _unsafe_view_default = None
return add_tensor_1
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view_copy = torch.ops.aten.view_copy.default(add, [8])
_reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(view_copy, [2, 4], [4, 1]); view_copy = None
transpose_copy = torch.ops.aten.transpose_copy.int(_reshape_alias_copy, 1, 0)
unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None
squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None
split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None
getitem = split_copy[0]
getitem_1 = split_copy[1]; split_copy = None
add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
select_copy = torch.ops.aten.select_copy.int(_reshape_alias_copy, 0, 0); _reshape_alias_copy = None
clone = torch.ops.aten.clone.default(add_1, memory_format = torch.contiguous_format)
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [8]); add = None
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_1, 1, 0); _reshape_alias_copy_1 = None
unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None
squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None
slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = None
unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None
squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None
view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_2, [4, 2]); _reshape_alias_copy_2 = None
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None
_reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None
select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_3, 0, 0); _reshape_alias_copy_3 = None
add_2 = torch.ops.aten.add.Tensor(select_copy_1, _unsafe_view); select_copy_1 = _unsafe_view = None
return add_1
""") # noqa: B950
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
@ -686,33 +686,33 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view_default = torch.ops.aten.view.default(add_tensor, [8])
_reshape_alias_default = torch.ops.aten._reshape_alias.default(view_default, [2, 4], [4, 1]); view_default = None
transpose_int = torch.ops.aten.transpose.int(_reshape_alias_default, 1, 0)
unsqueeze_default = torch.ops.aten.unsqueeze.default(transpose_int, 0); transpose_int = None
squeeze_default = torch.ops.aten.squeeze.default(unsqueeze_default); unsqueeze_default = None
split_tensor = torch.ops.aten.split.Tensor(squeeze_default, 2); squeeze_default = None
getitem = split_tensor[0]
getitem_1 = split_tensor[1]; split_tensor = None
add_tensor_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None
select_int = torch.ops.aten.select.int(_reshape_alias_default, 0, 0); _reshape_alias_default = None
clone_default = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format)
_unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [8]); add_tensor = None
_reshape_alias_default_1 = torch.ops.aten._reshape_alias.default(view_default_1, [2, 4], [4, 1]); view_default_1 = None
transpose_int_1 = torch.ops.aten.transpose.int(_reshape_alias_default_1, 1, 0); _reshape_alias_default_1 = None
unsqueeze_default_1 = torch.ops.aten.unsqueeze.default(transpose_int_1, 0); transpose_int_1 = None
squeeze_default_1 = torch.ops.aten.squeeze.default(unsqueeze_default_1); unsqueeze_default_1 = None
unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(squeeze_default_1, 0); squeeze_default_1 = None
squeeze_dim = torch.ops.aten.squeeze.dim(unsqueeze_default_2, 0); unsqueeze_default_2 = None
transpose_int_2 = torch.ops.aten.transpose.int(squeeze_dim, 1, 0); squeeze_dim = None
_reshape_alias_default_2 = torch.ops.aten._reshape_alias.default(transpose_int_2, [8], [1]); transpose_int_2 = None
view_default_2 = torch.ops.aten.view.default(_reshape_alias_default_2, [4, 2]); _reshape_alias_default_2 = None
view_default_3 = torch.ops.aten.view.default(view_default_2, [8]); view_default_2 = None
_reshape_alias_default_3 = torch.ops.aten._reshape_alias.default(view_default_3, [2, 4], [4, 1]); view_default_3 = None
select_int_1 = torch.ops.aten.select.int(_reshape_alias_default_3, 0, 0); _reshape_alias_default_3 = None
add_tensor_2 = torch.ops.aten.add.Tensor(select_int_1, _unsafe_view_default); select_int_1 = _unsafe_view_default = None
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view = torch.ops.aten.view.default(add, [8])
_reshape_alias = torch.ops.aten._reshape_alias.default(view, [2, 4], [4, 1]); view = None
transpose = torch.ops.aten.transpose.int(_reshape_alias, 1, 0)
unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None
squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None
split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None
getitem = split[0]
getitem_1 = split[1]; split = None
add_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None
select = torch.ops.aten.select.int(_reshape_alias, 0, 0); _reshape_alias = None
clone = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format)
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
view_1 = torch.ops.aten.view.default(add, [8]); add = None
_reshape_alias_1 = torch.ops.aten._reshape_alias.default(view_1, [2, 4], [4, 1]); view_1 = None
transpose_1 = torch.ops.aten.transpose.int(_reshape_alias_1, 1, 0); _reshape_alias_1 = None
unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None
squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None
unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None
squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None
transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None
_reshape_alias_2 = torch.ops.aten._reshape_alias.default(transpose_2, [8], [1]); transpose_2 = None
view_2 = torch.ops.aten.view.default(_reshape_alias_2, [4, 2]); _reshape_alias_2 = None
view_3 = torch.ops.aten.view.default(view_2, [8]); view_2 = None
_reshape_alias_3 = torch.ops.aten._reshape_alias.default(view_3, [2, 4], [4, 1]); view_3 = None
select_1 = torch.ops.aten.select.int(_reshape_alias_3, 0, 0); _reshape_alias_3 = None
add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None
return getitem
""")
@ -731,12 +731,12 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_default = torch.ops.aten.view.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1)
copy__default = torch.ops.aten.copy_.default(a_1, view_default_1); a_1 = view_default_1 = None
return add_tensor
view = torch.ops.aten.view.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
view_1 = torch.ops.aten.view.default(add, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
copy_ = torch.ops.aten.copy_.default(a_1, view_1); a_1 = view_1 = None
return add
""")
def test_aliases_maintained_after_pass_when_reapplying_views(self):
@ -781,9 +781,9 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
return add_tensor
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
return add
""")
reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
@ -793,9 +793,9 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
return add_tensor
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
return add
""")
# Test 2: copy_() with same dtype, different shape
@ -807,10 +807,10 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2])
add_tensor = torch.ops.aten.add.Tensor(expand_copy_default, a_1); expand_copy_default = a_1 = None
return add_tensor
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
expand_copy = torch.ops.aten.expand_copy.default(a_1, [2])
add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None
return add
""")
reinplaced_logs = self.get_logs(f, torch.ones(1), reapply_views=True, run_reinplace=True)
@ -820,10 +820,10 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2])
add_tensor = torch.ops.aten.add_.Tensor(expand_copy_default, a_1); a_1 = None
return expand_copy_default
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
expand_copy = torch.ops.aten.expand_copy.default(a_1, [2])
add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None
return expand_copy
""")
# Test 3: copy_() with different dtype, same shape
@ -835,10 +835,10 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(_to_copy_default, a_1); _to_copy_default = a_1 = None
return add_tensor
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(_to_copy, a_1); _to_copy = a_1 = None
return add
""") # noqa: B950
reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True)
@ -848,10 +848,10 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add_.Tensor(_to_copy_default, a_1); a_1 = None
return _to_copy_default
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add_.Tensor(_to_copy, a_1); a_1 = None
return _to_copy
""") # noqa: B950
# Test 4: copy_() with different dtype, different shape
@ -863,11 +863,11 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
expand_copy_default = torch.ops.aten.expand_copy.default(_to_copy_default, [2]); _to_copy_default = None
add_tensor = torch.ops.aten.add.Tensor(expand_copy_default, a_1); expand_copy_default = a_1 = None
return add_tensor
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None
add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None
return add
""") # noqa: B950
reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True)
@ -877,11 +877,11 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
expand_copy_default = torch.ops.aten.expand_copy.default(_to_copy_default, [2]); _to_copy_default = None
add_tensor = torch.ops.aten.add_.Tensor(expand_copy_default, a_1); a_1 = None
return expand_copy_default
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None
add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None
return expand_copy
""") # noqa: B950
def test_expand_symint(self):
@ -897,8 +897,8 @@ def forward(self, a_1):
def forward(self, a_1):
expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2, 2]); a_1 = None
return expand_copy_default
expand_copy = torch.ops.aten.expand_copy.default(a_1, [2, 2]); a_1 = None
return expand_copy
""")
def test_fill_(self):
@ -915,11 +915,11 @@ def forward(self, a_1):
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(add_tensor)
fill_scalar = torch.ops.aten.fill.Scalar(diagonal_copy_default, 0); diagonal_copy_default = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(add_tensor, fill_scalar); add_tensor = fill_scalar = None
return diagonal_scatter_default
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
diagonal_copy = torch.ops.aten.diagonal_copy.default(add)
fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0); diagonal_copy = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill); add = fill = None
return diagonal_scatter
""")
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
@ -928,10 +928,10 @@ def forward(self, a_1):
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
diagonal_default = torch.ops.aten.diagonal.default(add_tensor)
fill_scalar = torch.ops.aten.fill_.Scalar(diagonal_default, 0); diagonal_default = None
return add_tensor
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
diagonal = torch.ops.aten.diagonal.default(add)
fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = None
return add
""")
def test_resize_smaller(self):
@ -952,21 +952,21 @@ def forward(self, a_1):
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [4, 4])
resize_default = torch.ops.aten.resize.default(view_copy_default, [3, 3])
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(view_copy_default, [3, 3], [3, 1]); view_copy_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(as_strided_copy_default, [-1]); as_strided_copy_default = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1); view_copy_default_1 = None
view_copy_default_2 = torch.ops.aten.view_copy.default(add_tensor, [4, 4]); add_tensor = None
as_strided_copy_default_1 = torch.ops.aten.as_strided_copy.default(view_copy_default_2, [3, 3], [3, 1])
view_copy_default_3 = torch.ops.aten.view_copy.default(add_tensor_1, [3, 3]); add_tensor_1 = None
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(view_copy_default_2, view_copy_default_3, [3, 3], [3, 1]); view_copy_default_2 = view_copy_default_3 = None
view_copy_default_4 = torch.ops.aten.view_copy.default(as_strided_scatter_default, [8, 2]); as_strided_scatter_default = None
view_copy_default_5 = torch.ops.aten.view_copy.default(view_copy_default_4, [4, 4]); view_copy_default_4 = None
as_strided_copy_default_2 = torch.ops.aten.as_strided_copy.default(view_copy_default_5, [3, 3], [3, 1]); view_copy_default_5 = None
add_tensor_2 = torch.ops.aten.add.Tensor(as_strided_copy_default_2, 1); as_strided_copy_default_2 = None
return add_tensor_2
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
view_copy = torch.ops.aten.view_copy.default(add, [4, 4])
resize = torch.ops.aten.resize.default(view_copy, [3, 3])
as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(as_strided_copy, [-1]); as_strided_copy = None
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None
view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]); add = None
as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1])
view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None
view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]); view_copy_4 = None
as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None
add_2 = torch.ops.aten.add.Tensor(as_strided_copy_2, 1); as_strided_copy_2 = None
return add_2
""") # noqa: B950
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
@ -975,20 +975,20 @@ def forward(self, a_1):
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
view_default = torch.ops.aten.view.default(add_tensor, [4, 4])
resize_default = torch.ops.aten.resize.default(view_default, [3, 3])
as_strided_default = torch.ops.aten.as_strided.default(view_default, [3, 3], [3, 1]); view_default = None
view_default_1 = torch.ops.aten.view.default(as_strided_default, [-1]); as_strided_default = None
add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_1, 1)
view_default_2 = torch.ops.aten.view.default(add_tensor, [4, 4]); add_tensor = None
as_strided_default_1 = torch.ops.aten.as_strided.default(view_default_2, [3, 3], [3, 1])
view_default_3 = torch.ops.aten.view.default(view_default_1, [3, 3]); view_default_1 = None
view_default_4 = torch.ops.aten.view.default(view_default_2, [8, 2]); view_default_2 = None
view_default_5 = torch.ops.aten.view.default(view_default_4, [4, 4]); view_default_4 = None
as_strided_default_2 = torch.ops.aten.as_strided.default(view_default_5, [3, 3], [3, 1]); view_default_5 = None
add_tensor_2 = torch.ops.aten.add_.Tensor(as_strided_default_2, 1)
return as_strided_default_2
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
view = torch.ops.aten.view.default(add, [4, 4])
resize = torch.ops.aten.resize.default(view, [3, 3])
as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]); view = None
view_1 = torch.ops.aten.view.default(as_strided, [-1]); as_strided = None
add_1 = torch.ops.aten.add_.Tensor(view_1, 1)
view_2 = torch.ops.aten.view.default(add, [4, 4]); add = None
as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1])
view_3 = torch.ops.aten.view.default(view_1, [3, 3]); view_1 = None
view_4 = torch.ops.aten.view.default(view_2, [8, 2]); view_2 = None
view_5 = torch.ops.aten.view.default(view_4, [4, 4]); view_4 = None
as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]); view_5 = None
add_2 = torch.ops.aten.add_.Tensor(as_strided_2, 1)
return as_strided_2
""")
def test_resize_larger_valid(self):
@ -1015,13 +1015,13 @@ def forward(self, a_1):
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
resize_default = torch.ops.aten.resize.default(add_tensor, [5, 5]); add_tensor = None
view_copy_default = torch.ops.aten.view_copy.default(resize_default, [25]); resize_default = None
fill_scalar = torch.ops.aten.fill.Scalar(view_copy_default, 1); view_copy_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(fill_scalar, [5, 5]); fill_scalar = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1)
return (view_copy_default_1, add_tensor_1)
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
resize = torch.ops.aten.resize.default(add, [5, 5]); add = None
view_copy = torch.ops.aten.view_copy.default(resize, [25]); resize = None
fill = torch.ops.aten.fill.Scalar(view_copy, 1); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]); fill = None
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1)
return (view_copy_1, add_1)
""")
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
@ -1030,13 +1030,13 @@ def forward(self, a_1):
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
resize_default = torch.ops.aten.resize_.default(add_tensor, [5, 5])
view_default = torch.ops.aten.view.default(add_tensor, [25]); add_tensor = None
fill_scalar = torch.ops.aten.fill_.Scalar(view_default, 1)
view_default_1 = torch.ops.aten.view.default(view_default, [5, 5]); view_default = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_default_1, 1)
return (view_default_1, add_tensor_1)
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
resize = torch.ops.aten.resize_.default(add, [5, 5])
view = torch.ops.aten.view.default(add, [25]); add = None
fill = torch.ops.aten.fill_.Scalar(view, 1)
view_1 = torch.ops.aten.view.default(view, [5, 5]); view = None
add_1 = torch.ops.aten.add.Tensor(view_1, 1)
return (view_1, add_1)
""")
def test_resize_larger_invalid(self):
@ -1115,10 +1115,10 @@ $3 = torch._ops.aten.add.Tensor($2, 1)""")
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
select_copy_int = torch.ops.aten.select_copy.int(zeros, 0, 5)
fill_scalar = torch.ops.aten.fill.Scalar(select_copy_int, 1); select_copy_int = None
select_scatter_default = torch.ops.aten.select_scatter.default(zeros, fill_scalar, 0, 5); zeros = fill_scalar = None
return select_scatter_default
select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5)
fill = torch.ops.aten.fill.Scalar(select_copy, 1); select_copy = None
select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5); zeros = fill = None
return select_scatter
""") # noqa: B950
reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
@ -1128,8 +1128,8 @@ def forward(self, a_1):
def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
select_int = torch.ops.aten.select.int(zeros, 0, 5)
fill_scalar = torch.ops.aten.fill_.Scalar(select_int, 1); select_int = None
select = torch.ops.aten.select.int(zeros, 0, 5)
fill = torch.ops.aten.fill_.Scalar(select, 1); select = None
return zeros
""")

View File

@ -30,9 +30,9 @@ class TestReinplacePass(TestCase):
def forward(self, x_1):
clone_default = torch.ops.aten.clone.default(x_1); x_1 = None
add_tensor = torch.ops.aten.add_.Tensor(clone_default, 1)
return clone_default
clone = torch.ops.aten.clone.default(x_1); x_1 = None
add = torch.ops.aten.add_.Tensor(clone, 1)
return clone
""")
@ -56,11 +56,11 @@ def forward(self, x_1):
def forward(self, x_1):
clone_default = torch.ops.aten.clone.default(x_1); x_1 = None
view_default = torch.ops.aten.view.default(clone_default, [-1])
add_tensor = torch.ops.aten.add.Tensor(clone_default, 1); clone_default = None
add_tensor_1 = torch.ops.aten.add_.Tensor(view_default, 1)
return view_default
clone = torch.ops.aten.clone.default(x_1); x_1 = None
view = torch.ops.aten.view.default(clone, [-1])
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
add_1 = torch.ops.aten.add_.Tensor(view, 1)
return view
""")
def test_reinplace_different_metadata(self):
@ -82,10 +82,10 @@ def forward(self, x_1):
def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
add_tensor = torch.ops.aten.add.Tensor(clone_default, 1)
ge_tensor = torch.ops.aten.ge.Tensor(add_tensor, clone_default); add_tensor = clone_default = None
return ge_tensor
clone = torch.ops.aten.clone.default(a__1); a__1 = None
add = torch.ops.aten.add.Tensor(clone, 1)
ge = torch.ops.aten.ge.Tensor(add, clone); add = clone = None
return ge
""")
def test_reinplace_overlapping_memory(self):
@ -105,10 +105,10 @@ def forward(self, a__1):
def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
expand_default = torch.ops.aten.expand.default(clone_default, [4, 4]); clone_default = None
add_tensor = torch.ops.aten.add.Tensor(expand_default, 1); expand_default = None
return add_tensor
clone = torch.ops.aten.clone.default(a__1); a__1 = None
expand = torch.ops.aten.expand.default(clone, [4, 4]); clone = None
add = torch.ops.aten.add.Tensor(expand, 1); expand = None
return add
""")
# This test won't actually run in CI, because it requires functionalize() from functorch.
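For reference, a minimal sketch of the functionalize() dependency mentioned above, assuming functorch is installed; the remove= value is functorch's documented option, but this exact wiring is an illustration rather than the test's own code:

import torch
from functorch import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x.view(-1)
    y.add_(1)      # mutate the input through a view
    return x

# with remove='mutations_and_views', mutations are removed and view ops are
# replaced by their *_copy variants
traced = make_fx(functionalize(f, remove='mutations_and_views'))(torch.ones(4, 2))
print(traced.code)  # expected to show view_copy / add / copy_ style ops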
@ -217,13 +217,13 @@ def forward(self, a__1):
def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None
as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None
return as_strided_default
clone = torch.ops.aten.clone.default(a__1); a__1 = None
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
return as_strided
""")
# Test example where we have a scatter op, where the base tensor
@ -253,15 +253,15 @@ def forward(self, a__1):
def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None
select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 0)
copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor); select_int_2 = add_tensor = None
return as_strided_default
clone = torch.ops.aten.clone.default(a__1); a__1 = None
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
select_int = torch.ops.aten.select.int(as_strided, 0, 0)
copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = None
return as_strided
""") # noqa: B950
def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
@ -286,15 +286,15 @@ def forward(self, a__1):
def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 0); clone_default = None
select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 1)
copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor); select_int_2 = add_tensor = None
return as_strided_default
clone = torch.ops.aten.clone.default(a__1); a__1 = None
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0); clone = None
select_int = torch.ops.aten.select.int(as_strided, 0, 1)
copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = None
return as_strided
""") # noqa: B950


@ -327,7 +327,7 @@ class TestPrims(TestCase):
for node in gm.graph.nodes:
if node.op == "call_function":
self.assertTrue(node.name == "add_default")
self.assertTrue(node.name == "add")
self.assertTrue(node.target == torch.ops.nvprims.add.default)
self.assertFalse(node.target == torch.ops.prims.add.default)
self.assertFalse(node.target == torch.ops.aten.add.default)


@ -314,8 +314,8 @@ class TestGenericProxyTensor(TestCase):
def forward(self, x_1):
zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
copy__default = torch.ops.aten.copy_.default(zeros, x_1); zeros = x_1 = None
return copy__default
copy_ = torch.ops.aten.copy_.default(zeros, x_1); zeros = x_1 = None
return copy_
""")
def test_make_fx_reentrant_dispatch(self):
@ -589,13 +589,19 @@ def forward(self, x_1):
)
def test_trace_subclasses(self):
def f(x):
def f1(x):
x = UnwrapTensor(x)
y = x * 2
return y
def f2(x):
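# mixing a plain tensor with the wrapped subclass should exercise the path
# where proxy tracing defers to the subclass (returns NotImplemented) first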
wrapped = UnwrapTensor(x)
y = x * wrapped
return y
inp = [torch.randn(5)]
self._test(f, inp)
self._test(f1, inp)
self._test(f2, inp)
def test_partial_decomp(self):
def f(a, b, c):
@ -616,6 +622,19 @@ def forward(self, x_1):
self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)
def test_decomp_of_capture(self):
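# the decomposition should also fire for ops reached through the captured
# tensor `val`, not just through the explicit input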
val = torch.randn(5)
def f(x):
return x.t() + val.t()
def nop(x):
return x.cos()
traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_amp_cache(self):
layer = torch.nn.Conv2d(3, 3, 3).cuda()
@ -771,9 +790,9 @@ def forward(self, a_1):
mul = sym_size * 2; sym_size = None
empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
sym_size_1 = torch.ops.aten.sym_size(empty, 0)
detach_default = torch.ops.aten.detach.default(empty); empty = None
sym_size_2 = torch.ops.aten.sym_size(detach_default, 0)
return detach_default""")
detach = torch.ops.aten.detach.default(empty); empty = None
sym_size_2 = torch.ops.aten.sym_size(detach, 0)
return detach""")
def test_cat(self):
def f(a, b):


@ -182,24 +182,29 @@ def fetch_symint_proxy(tracer):
def fetch_tensor_proxy(tracer):
return lambda t: get_proxy_slot(t, tracer, t)
HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)
def proxy_call(proxy_mode, func_overload, args, kwargs=None):
if kwargs is None:
kwargs = {}
def proxy_call(proxy_mode, func, args, kwargs):
def can_handle_tensor(x):
return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
func = func_overload.overloadpacket
if func_overload in CURRENT_DECOMPOSITION_TABLE:
# If there are any tensor subclasses, we need to handle those tensor subclasses first
# TODO: we could use types to test this
if not pytree.tree_all_only(torch.Tensor, can_handle_tensor, (args, kwargs)):
return NotImplemented
if func in CURRENT_DECOMPOSITION_TABLE:
with proxy_mode.restore():
r = CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
r = CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs)
if r is not NotImplemented:
return r
# Some of these are not "real" aten ops and will fail if we
# call _dispatch_has_kernel_for_dispatch_key on them.
# This list is probably incomplete
if func_overload not in [torch.ops.aten.size.default]:
if func not in [torch.ops.aten.size.default]:
with proxy_mode.restore():
r = func_overload.decompose(*args, **kwargs)
r = func.decompose(*args, **kwargs)
if r is not NotImplemented:
return r
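As an aside, a minimal sketch of the decomposition path above; the silu decomposition is written for this illustration and is not one shipped with PyTorch. The table passed to make_fx is keyed by OpOverloads, and when proxy_call finds the overload in CURRENT_DECOMPOSITION_TABLE it re-enters tracing through the user-supplied function, so the decomposed ops are what end up in the graph:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def silu_decomp(x):
    # rewrite aten.silu in terms of sigmoid and mul so those ops get traced instead
    return x * torch.sigmoid(x)

def f(x):
    return torch.nn.functional.silu(x)

gm = make_fx(f, decomposition_table={torch.ops.aten.silu.default: silu_decomp})(torch.randn(4))
# gm.graph is expected to contain sigmoid/mul calls and no aten.silu call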
@ -217,14 +222,14 @@ def proxy_call(proxy_mode, func_overload, args, kwargs=None):
and pytree.tree_all_only(SymInt, lambda _: False, (args, kwargs))
)
if torch.Tag.data_dependent_output in func_overload.tags: # type: ignore[attr-defined]
if torch.Tag.data_dependent_output in func.tags: # type: ignore[attr-defined]
# Check if all of the Tensor inputs are constants
if all_constant:
const_args, const_kwargs = pytree.tree_map_only(
_ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
)
with maybe_disable_fake_tensor_mode():
return func_overload(*const_args, **const_kwargs)
return func(*const_args, **const_kwargs)
raise RuntimeError(
"It appears that you're trying to get value out of a tracing tensor - erroring out! "
"It's likely that this is caused by data-dependent control flow or similar."
@ -235,21 +240,59 @@ def proxy_call(proxy_mode, func_overload, args, kwargs=None):
fetch_symint_proxy(proxy_mode.tracer),
pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs))
)
proxy_out = func_overload(*proxy_args, **proxy_kwargs)
# When we trace through a torch.tensor invocation, you never actually
# see a torch.ops.aten.tensor call. Instead, the way this function is
# implemented internally is that we allocate a plain tensor (this is
# *guaranteed* to be a plain tensor, we disable all modes when doing
# so), and then call at::lift_fresh on it (to give modes a chance to do
# their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed
# to be freshly allocated, so we want lift_fresh to be a no-op (directly
# returning the input argument).
#
# Here is the basic problem: when we trace this sequence of executions
# into an FX graph, what happens to this call sequence? Traditionally,
# tensor constants get interned as buffers on the FX GraphModule. But
# this is dangerous. Consider:
#
# x = torch.tensor(1)
# x.add_(2)
#
# Naively, this traces into:
#
# t = self._tensor_constant0 # initialized to torch.tensor(1)
# x = torch.ops.aten.lift_fresh(t)
# x.add_(2)
#
# If lift_fresh returns t directly, the subsequent add_ call will
# modify the tensor constant. Really, the problem is we've violated
# the invariant that the argument to lift is fresh. So we should preserve
# the invariant by replacing lift_fresh with lift_fresh_copy:
#
# t = self._tensor_constant0 # initialized to torch.tensor(1)
# x = torch.ops.aten.lift_fresh_copy(t)
# x.add_(2)
#
# This is what the overload modification does.
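#
# As a rough user-facing illustration (an aside in this comment, not part of
# the tests), tracing the snippet above with make_fx is expected to record
# lift_fresh_copy rather than lift_fresh:
#
#   def g():
#       x = torch.tensor(1)
#       x.add_(2)
#       return x
#
#   gm = make_fx(g)()
#   # gm.graph is expected to show _tensor_constant0 feeding
#   # aten.lift_fresh_copy, followed by the traced add_.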
if func is torch.ops.aten.lift_fresh.default:
func = torch.ops.aten.lift_fresh_copy.default
proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))
# This makes DCE marginally less likely to DCE inplace operations.
# It is not strictly necessary
# Kind of a hacky way to test if an op is in-place or not
if func.__name__[-1] == "_" and func.__name__[0] != "_":
if func.overloadpacket.__name__[-1] == "_" and func.overloadpacket.__name__[0] != "_":
if isinstance(args[0], List):
# e.g., c10d::allreduce_ returns a list of tensors as the first element
# in the output.
for i, a in enumerate(args[0]):
a.proxy = proxy_out[0][i]
else:
# This makes DCE marginally less likely to DCE inplace operations.
# It is not strictly necessary
args[0].proxy = proxy_out
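# A couple of hedged examples of the trailing-underscore heuristic above
# (illustrative, not exhaustive): packet names like "add_" and "fill_" are
# treated as in-place, while "add" is not, and dunder names such as
# "__and__" are excluded because they start with an underscore.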
out = func_overload(*args, **kwargs)
out = func(*args, **kwargs)
# In some circumstances, we will be tracing in a situation where a tensor
# is *statically* known to be a constant (currently, this only happens if
@ -275,18 +318,27 @@ def proxy_call(proxy_mode, func_overload, args, kwargs=None):
any_constant = pytree.tree_any_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))
constant = None
# NB: do NOT include factories as constants
if (
torch.Tag.nondeterministic_seeded not in func_overload.tags # type: ignore[attr-defined]
# If this is a lift, the input tensor is guaranteed to be a
# constant, so we keep a copy of the original argument around so that
# we can query it if we're asked to item() it at some later point
if func is torch.ops.aten.lift_fresh_copy.default and out.numel() <= CONSTANT_NUMEL_LIMIT:
with maybe_disable_fake_tensor_mode():
constant = args[0].clone()
elif (
torch.Tag.nondeterministic_seeded not in func.tags # type: ignore[attr-defined]
and all_constant
and any_constant
and pytree.tree_all_only(torch.Tensor, lambda t: t.numel() <= CONSTANT_NUMEL_LIMIT, out)
):
# NB: do NOT include factories as constants
with maybe_disable_fake_tensor_mode():
const_args, const_kwargs = pytree.tree_map_only(
_ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
)
constant = func_overload(*const_args, **const_kwargs)
constant = func(*const_args, **const_kwargs)
else:
constant = None
track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
return out
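A hedged sketch of why the constant tracking above matters in practice (illustrative function, assuming make_fx from this module): small tensors created inside the traced function are kept around as real constants, so a later data-dependent call such as .item() on them can still be answered during tracing instead of erroring out:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    scale = torch.tensor(2.0)  # lifted via lift_fresh_copy and tracked as a constant
    return x * scale.item()    # answered from the tracked constant at trace time

gm = make_fx(f)(torch.randn(3))  # expected to trace without the data-dependent error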
@ -367,9 +419,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
self.sym_mode = ProxySymDispatchMode(tracer)
self.trace_state = {}
def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
with self.sym_mode.enable(False):
return self.inner_torch_dispatch(func_overload, types, args, kwargs)
return self.inner_torch_dispatch(func, types, args, kwargs)
@contextmanager
def restore(self):
@ -377,90 +429,18 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
with super().restore():
yield
def inner_torch_dispatch(self, func_overload, types, args=(), kwargs=None):
def inner_torch_dispatch(self, func, types, args=(), kwargs=None):
if not self.enable_tracing:
return func_overload(*args, **kwargs)
return func(*args, **kwargs)
if symbolic_shapes.is_symbolic_op(func_overload):
if symbolic_shapes.is_symbolic_op(func):
with self.restore():
return symbolic_shapes.handle_symbolic_op(func_overload, args, kwargs)
return symbolic_shapes.handle_symbolic_op(func, args, kwargs)
func = func_overload.overloadpacket
# We don't want to convert torch.tensor constants into tracing objects.
if func_overload == aten.lift.default:
return args[0]
if func in [prim.device.default]:
return func(*args, **kwargs)
if func in [prim.device]:
return func_overload(*args, **kwargs)
if pytree.tree_any_only(
torch.Tensor,
lambda t: has_proxy_slot(t, self.tracer),
(args, kwargs)
):
out = proxy_call(self, func_overload, args, kwargs)
# When we trace through a torch.tensor invocation, you never actually
# see a torch.ops.aten.tensor call. Instead, the way this function is
# implemented internally is that we allocate a plain tensor (this is
# *guaranteed* to be a plain tensor, we disable all modes when doing
# so), and then call at::lift_fresh on it (to give modes a chance to do
# their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed
# to be freshly allocated, so we want lift_fresh to be a no-op (directly
# returning the input argument).
#
# Here is the basic problem: when we trace this sequence of executions
# into an FX graph, what happens to this call sequence? Traditionally,
# tensor constants get interned as buffers on the FX GraphModule. But
# this is dangerous. Consider:
#
# x = torch.tensor(1)
# x.add_(2)
#
# Naively, this traces into:
#
# t = self._tensor_constant0 # initialized to torch.tensor(1)
# x = torch.ops.aten.lift_fresh(t)
# x.add_(2)
#
# If lift_fresh returns t directly, the subsequent add_ call will
# modify the tensor constant. Really, the problem is we've violated
# the invariant that the argument to lift is fresh. So we should preserve
# the invariant by replacing lift_fresh with lift_fresh_copy:
#
# t = self._tensor_constant0 # initialized to torch.tensor(1)
# x = torch.ops.aten.lift_fresh_copy(t)
# x.add_(2)
#
# This is what the overload modification does.
else:
flat_args = pytree.tree_flatten((args, kwargs))[0]
handled_types = [torch.Tensor, _ProxyTensor, torch.nn.Parameter]
# If there are any tensor subclasses, we need to handle those tensor subclasses first
# TODO: we could use types to test this
if any(isinstance(arg, torch.Tensor) and type(arg) not in handled_types for arg in flat_args):
return NotImplemented
if func_overload is torch.ops.aten.lift_fresh.default:
func_overload = torch.ops.aten.lift_fresh_copy.default
n_args, n_kwargs = pytree.tree_map_only(SymInt, fetch_symint_proxy(self.tracer), (args, kwargs))
proxy_out = self.tracer.create_proxy('call_function', func_overload, n_args, n_kwargs,
name=self.tracer.graph._target_to_str(func.__name__))
out = func_overload(*args, **kwargs)
# If this is a lift, the input tensor is guaranteed to be a
# constant, so we keep a copy of the original argument around so that
# we can query it if we're asked to item() it at some later point
is_lift = func_overload is torch.ops.aten.lift_fresh_copy.default
if is_lift and out.numel() <= CONSTANT_NUMEL_LIMIT:
with maybe_disable_fake_tensor_mode():
constant = args[0].clone()
else:
constant = None
track_tensor_tree(out, proxy_out, constant=constant, tracer=self.tracer)
out = proxy_call(self, func, args, kwargs)
def assert_proxy_tensor(e):
assert has_proxy_slot(e, self.tracer), \