Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

Refactored proxytensor to clean up separate branches (#84325)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84325
Approved by: https://github.com/ezyang

This commit is contained in:
committed by PyTorch MergeBot
parent 8843f5b986
commit a27a4a02fe
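The hunks below only update expected FX graph strings in the functionalization tests: traced aten calls are now named after the op alone ("add", "view_copy") instead of the op plus its overload ("add_tensor", "view_copy_default"). A minimal sketch of the behavior being asserted, not taken from the PR itself and assuming a PyTorch build that includes this change:

```python
# Minimal sketch: trace an aten call with make_fx and inspect node naming.
# torch.fx.experimental.proxy_tensor is the module this PR refactors.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # traces to torch.ops.aten.add.Tensor
    return torch.add(x, x)

gm = make_fx(f)(torch.ones(2))
# Before this change the node was named after the overload ("add_tensor");
# after it, nodes are named after the op alone ("add"), which is what the
# updated expected strings in the diff below check for.
print(gm.code)
```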
@@ -3177,8 +3177,8 @@ class TestFunctionalize(TestCase):
def forward(self, x_1, indices_1) -> torch.Tensor:
index_tensor = torch.ops.aten.index.Tensor(x_1, [indices_1]); x_1 = indices_1 = None
return index_tensor
index = torch.ops.aten.index.Tensor(x_1, [indices_1]); x_1 = indices_1 = None
return index
""")
# Ensure grad(functionalize(f)) works
@@ -3247,11 +3247,11 @@ def forward(self, x_1, indices_1) -> torch.Tensor:
def forward(self, x_1) -> torch.Tensor:
ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False)
view_copy_default = torch.ops.aten.view_copy.default(x_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, ones); view_copy_default = ones = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
copy__default = torch.ops.aten.copy_.default(x_1, view_copy_default_1); x_1 = None
return view_copy_default_1
view_copy = torch.ops.aten.view_copy.default(x_1, [4, 2])
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
copy_ = torch.ops.aten.copy_.default(x_1, view_copy_1); x_1 = None
return view_copy_1
""")
def test_functionalize_fx_transpose_simple(self, device):
@@ -3266,8 +3266,8 @@ def forward(self, x_1) -> torch.Tensor:
def forward(self, x_1) -> torch.Tensor:
transpose_copy_int = torch.ops.aten.transpose_copy.int(x_1, 1, 0); x_1 = None
return transpose_copy_int
transpose_copy = torch.ops.aten.transpose_copy.int(x_1, 1, 0); x_1 = None
return transpose_copy
""")
def test_functionalize_fx_out_op(self, device):
@@ -3288,12 +3288,12 @@ def forward(self, x_1) -> torch.Tensor:
def forward(self, inpt_1) -> torch.Tensor:
empty = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = 'cpu', pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(inpt_1, inpt_1); inpt_1 = None
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [4])
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4]); add_tensor = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1); view_copy_default_1 = None
view_copy_default_2 = torch.ops.aten.view_copy.default(add_tensor_1, [4]); add_tensor_1 = None
return view_copy_default_2
add = torch.ops.aten.add.Tensor(inpt_1, inpt_1); inpt_1 = None
view_copy = torch.ops.aten.view_copy.default(add, [4])
view_copy_1 = torch.ops.aten.view_copy.default(add, [4]); add = None
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None
view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4]); add_1 = None
return view_copy_2
""")
def test_functionalize_fx_multi_out_op(self, device):
@@ -3316,13 +3316,13 @@ def forward(self, inpt_1) -> torch.Tensor:
def forward(self, inpt_1) -> torch.Tensor:
empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = 'cpu', pin_memory = False)
empty_1 = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = 'cpu', pin_memory = False)
view_copy_default = torch.ops.aten.view_copy.default(empty_1, [4]); empty_1 = None
view_copy_default_1 = torch.ops.aten.view_copy.default(inpt_1, [2, 4]); inpt_1 = None
aminmax_default = torch.ops.aten.aminmax.default(view_copy_default_1, dim = 0); view_copy_default_1 = None
getitem = aminmax_default[0]
getitem_1 = aminmax_default[1]; aminmax_default = None
view_copy_default_2 = torch.ops.aten.view_copy.default(getitem_1, [2, 2]); getitem_1 = None
return (view_copy_default_2, getitem)
view_copy = torch.ops.aten.view_copy.default(empty_1, [4]); empty_1 = None
view_copy_1 = torch.ops.aten.view_copy.default(inpt_1, [2, 4]); inpt_1 = None
aminmax = torch.ops.aten.aminmax.default(view_copy_1, dim = 0); view_copy_1 = None
getitem = aminmax[0]
getitem_1 = aminmax[1]; aminmax = None
view_copy_2 = torch.ops.aten.view_copy.default(getitem_1, [2, 2]); getitem_1 = None
return (view_copy_2, getitem)
""")
def test_functionalize_fx_reapply_views_simple(self, device):
@@ -3341,11 +3341,11 @@ def forward(self, inpt_1) -> torch.Tensor:
def forward(self, x_1) -> torch.Tensor:
ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False)
view_default = torch.ops.aten.view.default(x_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2]); add_tensor = None
copy__default = torch.ops.aten.copy_.default(x_1, view_default_1); x_1 = None
return view_default_1
view = torch.ops.aten.view.default(x_1, [4, 2])
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = None
return view_1
""")
def test_functionalize_nonfunctional_output(self, device):
@@ -3382,8 +3382,8 @@ def forward(self) -> torch.Tensor:
def forward(self, a_1, b_1) -> torch.Tensor:
index_tensor = torch.ops.aten.index.Tensor(a_1, [b_1]); a_1 = b_1 = None
return index_tensor
index = torch.ops.aten.index.Tensor(a_1, [b_1]); a_1 = b_1 = None
return index
""")
def test_functionalize_optional_tensorlist2(self, device):
@@ -3400,11 +3400,11 @@ def forward(self, a_1, b_1) -> torch.Tensor:
def forward(self, a_1, b_1) -> torch.Tensor:
unbind_int = torch.ops.aten.unbind.int(b_1); b_1 = None
getitem = unbind_int[0]
getitem_1 = unbind_int[1]; unbind_int = None
index_tensor = torch.ops.aten.index.Tensor(a_1, [getitem, getitem_1]); a_1 = getitem = getitem_1 = None
return index_tensor
unbind = torch.ops.aten.unbind.int(b_1); b_1 = None
getitem = unbind[0]
getitem_1 = unbind[1]; unbind = None
index = torch.ops.aten.index.Tensor(a_1, [getitem, getitem_1]); a_1 = getitem = getitem_1 = None
return index
""")
def test_resize_program_inputs(self, device):
@@ -3420,10 +3420,10 @@ def forward(self, a_1, b_1) -> torch.Tensor:
def forward(self, x_1):
resize_default = torch.ops.aten.resize.default(x_1, [10])
fill_scalar = torch.ops.aten.fill.Scalar(resize_default, 2); resize_default = None
resize__default = torch.ops.aten.resize_.default(x_1, [10]); x_1 = None
copy__default = torch.ops.aten.copy_.default(resize__default, fill_scalar); resize__default = fill_scalar = None
resize = torch.ops.aten.resize.default(x_1, [10])
fill = torch.ops.aten.fill.Scalar(resize, 2); resize = None
resize_ = torch.ops.aten.resize_.default(x_1, [10]); x_1 = None
copy_ = torch.ops.aten.copy_.default(resize_, fill); resize_ = fill = None
return None
""")
@@ -117,12 +117,12 @@ class TestFunctionalization(TestCase):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, ones); view_copy_default = ones = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_copy_default_1, view_copy_default_1)
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
return add_tensor
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1)
copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None
return add
""")
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
@@ -132,12 +132,12 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_default = torch.ops.aten.view.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1)
copy__default = torch.ops.aten.copy_.default(a_1, view_default_1); a_1 = view_default_1 = None
return add_tensor
view = torch.ops.aten.view.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
view_1 = torch.ops.aten.view.default(add, [4, 2])
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
copy_ = torch.ops.aten.copy_.default(a_1, view_1); a_1 = view_1 = None
return add
""")
def test_simple_out(self):
@@ -157,11 +157,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, ones); view_copy_default = ones = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
return mul_tensor
add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None
mul = torch.ops.aten.mul.Tensor(add, add); add = None
return mul
""")
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
@@ -171,11 +171,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_default = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
view = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
return mul_tensor
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
mul = torch.ops.aten.mul.Tensor(add, add); add = None
return mul
""")
def test_multi_out(self):
@@ -195,9 +195,9 @@ def forward(self, a_1):
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax_default[0]
getitem_1 = aminmax_default[1]; aminmax_default = None
aminmax = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax[0]
getitem_1 = aminmax[1]; aminmax = None
return getitem
""")
@@ -209,9 +209,9 @@ def forward(self, a_1):
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False)
aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax_default[0]
getitem_1 = aminmax_default[1]; aminmax_default = None
aminmax = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax[0]
getitem_1 = aminmax[1]; aminmax = None
return getitem
""")
@@ -232,11 +232,11 @@ def forward(self, a_1):
def forward(self, a_1):
_tensor_constant0 = self._tensor_constant0
lift_fresh = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view_copy_default = torch.ops.aten.view_copy.default(lift_fresh, [-1]); lift_fresh = None
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [3]); add_tensor = None
return view_copy_default_1
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [3]); add = None
return view_copy_1
""")
reinplaced_logs = self.get_logs(f, inpt, reapply_views=True, run_reinplace=True)
@@ -246,11 +246,11 @@ def forward(self, a_1):
def forward(self, a_1):
_tensor_constant0 = self._tensor_constant0
lift_fresh = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view_default = torch.ops.aten.view.default(lift_fresh, [-1]); lift_fresh = None
add_tensor = torch.ops.aten.add_.Tensor(view_default, 1)
view_default_1 = torch.ops.aten.view.default(view_default, [3]); view_default = None
return view_default_1
lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
view = torch.ops.aten.view.default(lift_fresh_copy, [-1]); lift_fresh_copy = None
add = torch.ops.aten.add_.Tensor(view, 1)
view_1 = torch.ops.aten.view.default(view, [3]); view = None
return view_1
""")
@@ -282,11 +282,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(a_1, ones); ones = None
copy__default = torch.ops.aten.copy_.default(a_1, add_tensor); a_1 = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
return view_copy_default_1
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(a_1, ones); ones = None
copy_ = torch.ops.aten.copy_.default(a_1, add); a_1 = None
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
return view_copy_1
""")
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
@@ -296,11 +296,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
view_default = torch.ops.aten.view.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(a_1, ones); ones = None
copy__default = torch.ops.aten.copy_.default(a_1, add_tensor); a_1 = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2]); add_tensor = None
return view_default_1
view = torch.ops.aten.view.default(a_1, [4, 2])
add = torch.ops.aten.add.Tensor(a_1, ones); ones = None
copy_ = torch.ops.aten.copy_.default(a_1, add); a_1 = None
view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None
return view_1
""")
# Some ops that are mutable are neither inplace nor out= ops.
@@ -315,14 +315,14 @@ def forward(self, a_1):
def forward(self, a_1):
_fused_moving_avg_obs_fq_helper_functional_default = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0)
getitem = _fused_moving_avg_obs_fq_helper_functional_default[0]
getitem_1 = _fused_moving_avg_obs_fq_helper_functional_default[1]
getitem_2 = _fused_moving_avg_obs_fq_helper_functional_default[2]
getitem_3 = _fused_moving_avg_obs_fq_helper_functional_default[3]
getitem_4 = _fused_moving_avg_obs_fq_helper_functional_default[4]
getitem_5 = _fused_moving_avg_obs_fq_helper_functional_default[5]; _fused_moving_avg_obs_fq_helper_functional_default = None
copy__default = torch.ops.aten.copy_.default(a_1, getitem_5); a_1 = getitem_5 = None
_fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0)
getitem = _fused_moving_avg_obs_fq_helper_functional[0]
getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1]
getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2]
getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3]
getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4]
getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None
copy_ = torch.ops.aten.copy_.default(a_1, getitem_5); a_1 = getitem_5 = None
return (getitem, getitem_1)
""") # noqa: B950
@@ -338,11 +338,11 @@ def forward(self, a_1):
def forward(self, a_1):
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1)
add_tensor = torch.ops.aten.add.Tensor(as_strided_copy_default, 1); as_strided_copy_default = None
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(a_1, add_tensor, [2], [2], 1); add_tensor = None
copy__default = torch.ops.aten.copy_.default(a_1, as_strided_scatter_default); a_1 = None
return as_strided_scatter_default
as_strided_copy = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1)
add = torch.ops.aten.add.Tensor(as_strided_copy, 1); as_strided_copy = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(a_1, add, [2], [2], 1); add = None
copy_ = torch.ops.aten.copy_.default(a_1, as_strided_scatter); a_1 = None
return as_strided_scatter
""")
def test_tensor_list_composite(self):
@@ -357,8 +357,8 @@ def forward(self, a_1):
def forward(self, a_1):
block_diag_default = torch.ops.aten.block_diag.default([a_1, a_1]); a_1 = None
return block_diag_default
block_diag = torch.ops.aten.block_diag.default([a_1, a_1]); a_1 = None
return block_diag
""")
def test_cat(self):
@@ -374,8 +374,8 @@ def forward(self, a_1):
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat_default
cat = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat
""")
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
@@ -385,8 +385,8 @@ def forward(self, a_1):
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False)
cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat_default
cat = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat
""")
@@ -406,11 +406,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
clone_default = torch.ops.aten.clone.default(a_1)
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(clone_default); clone_default = None
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, ones); diagonal_copy_default = ones = None
mul_tensor = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
return mul_tensor
clone = torch.ops.aten.clone.default(a_1)
diagonal_copy = torch.ops.aten.diagonal_copy.default(clone); clone = None
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
mul = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
return mul
""")
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
@@ -420,11 +420,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
clone_default = torch.ops.aten.clone.default(a_1)
diagonal_default = torch.ops.aten.diagonal.default(clone_default); clone_default = None
add_tensor = torch.ops.aten.add_.Tensor(diagonal_default, ones); diagonal_default = ones = None
mul_tensor = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
return mul_tensor
clone = torch.ops.aten.clone.default(a_1)
diagonal = torch.ops.aten.diagonal.default(clone); clone = None
add = torch.ops.aten.add_.Tensor(diagonal, ones); diagonal = ones = None
mul = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None
return mul
""")
def test_diagonal_mutated_input(self):
@@ -443,11 +443,11 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(a_1)
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, ones); diagonal_copy_default = ones = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(a_1, add_tensor); add_tensor = None
copy__default = torch.ops.aten.copy_.default(a_1, diagonal_scatter_default); a_1 = None
return diagonal_scatter_default
diagonal_copy = torch.ops.aten.diagonal_copy.default(a_1)
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(a_1, add); add = None
copy_ = torch.ops.aten.copy_.default(a_1, diagonal_scatter); a_1 = None
return diagonal_scatter
""")
def test_split(self):
@@ -467,19 +467,19 @@ def forward(self, a_1):
def forward(self, a_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
split_copy_tensor = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem = split_copy_tensor[0]
getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, ones); diagonal_copy_default = ones = None
split_copy_tensor_1 = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem_2 = split_copy_tensor_1[0]
getitem_3 = split_copy_tensor_1[1]; split_copy_tensor_1 = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(getitem_3, add_tensor); getitem_3 = None
slice_scatter_default = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter_default, 0, 2, 4); diagonal_scatter_default = None
mul_tensor = torch.ops.aten.mul.Tensor(slice_scatter_default, slice_scatter_default)
copy__default = torch.ops.aten.copy_.default(a_1, slice_scatter_default); a_1 = slice_scatter_default = None
return add_tensor
split_copy = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem = split_copy[0]
getitem_1 = split_copy[1]; split_copy = None
diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
split_copy_1 = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem_2 = split_copy_1[0]
getitem_3 = split_copy_1[1]; split_copy_1 = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = None
slice_scatter = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
copy_ = torch.ops.aten.copy_.default(a_1, slice_scatter); a_1 = slice_scatter = None
return add
""") # noqa: B950
def test_view_inplace(self):
@@ -498,14 +498,14 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
|
||||
transpose_copy_int = torch.ops.aten.transpose_copy.int(a_1, 1, 0)
|
||||
select_copy_int = torch.ops.aten.select_copy.int(transpose_copy_int, 0, 0); transpose_copy_int = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(select_copy_int, ones); select_copy_int = ones = None
|
||||
transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(a_1, 1, 0); a_1 = None
|
||||
select_scatter_default = torch.ops.aten.select_scatter.default(transpose_copy_int_1, add_tensor, 0, 0); transpose_copy_int_1 = add_tensor = None
|
||||
transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(select_scatter_default, 1, 0); select_scatter_default = None
|
||||
transpose_copy_int_3 = torch.ops.aten.transpose_copy.int(transpose_copy_int_2, 1, 0); transpose_copy_int_2 = None
|
||||
return transpose_copy_int_3
|
||||
transpose_copy = torch.ops.aten.transpose_copy.int(a_1, 1, 0)
|
||||
select_copy = torch.ops.aten.select_copy.int(transpose_copy, 0, 0); transpose_copy = None
|
||||
add = torch.ops.aten.add.Tensor(select_copy, ones); select_copy = ones = None
|
||||
transpose_copy_1 = torch.ops.aten.transpose_copy.int(a_1, 1, 0); a_1 = None
|
||||
select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None
|
||||
transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None
|
||||
transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
|
||||
return transpose_copy_3
|
||||
""") # noqa: B950
|
||||
|
||||
def test_optional_tensor_list(self):
|
||||
@@ -524,13 +524,13 @@ def forward(self, a_1):
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [8])
|
||||
view_copy = torch.ops.aten.view_copy.default(a_1, [8])
|
||||
arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False)
|
||||
arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
|
||||
index_put_default = torch.ops.aten.index_put.default(view_copy_default, [arange], arange_1); view_copy_default = arange = arange_1 = None
|
||||
view_copy_default_1 = torch.ops.aten.view_copy.default(index_put_default, [4, 2])
|
||||
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
|
||||
return index_put_default
|
||||
index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1); view_copy = arange = arange_1 = None
|
||||
view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2])
|
||||
copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None
|
||||
return index_put
|
||||
""") # noqa: B950
|
||||
|
||||
def test_scalars(self):
|
||||
@@ -550,13 +550,13 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
|
||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
|
||||
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
|
||||
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, 2)
|
||||
div_tensor = torch.ops.aten.div.Tensor(mul_tensor, 1); mul_tensor = None
|
||||
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
|
||||
copy__default = torch.ops.aten.copy_.default(a_1, view_copy_default_1); a_1 = view_copy_default_1 = None
|
||||
return div_tensor
|
||||
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2])
|
||||
add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None
|
||||
mul = torch.ops.aten.mul.Tensor(add, 2)
|
||||
div = torch.ops.aten.div.Tensor(mul, 1); mul = None
|
||||
view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None
|
||||
copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None
|
||||
return div
|
||||
""")
|
||||
|
||||
@skipIfTorchDynamo("Test does not work with TorchDynamo")
|
||||
@@ -574,10 +574,10 @@ def forward(self, a_1):
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
clone_default = torch.ops.aten.clone.default(a_1); a_1 = None
|
||||
ge_scalar = torch.ops.aten.ge.Scalar(clone_default, 0); clone_default = None
|
||||
_to_copy_default = torch.ops.aten._to_copy.default(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
|
||||
return _to_copy_default
|
||||
clone = torch.ops.aten.clone.default(a_1); a_1 = None
|
||||
ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None
|
||||
_to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None
|
||||
return _to_copy
|
||||
""")
|
||||
|
||||
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
|
||||
@@ -586,10 +586,10 @@ def forward(self, a_1):
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
clone_default = torch.ops.aten.clone.default(a_1); a_1 = None
|
||||
ge_scalar = torch.ops.aten.ge.Scalar(clone_default, 0); clone_default = None
|
||||
_to_copy_default = torch.ops.aten._to_copy.default(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
|
||||
return _to_copy_default
|
||||
clone = torch.ops.aten.clone.default(a_1); a_1 = None
|
||||
ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None
|
||||
_to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None
|
||||
return _to_copy
|
||||
""") # noqa: B950
|
||||
|
||||
@skipIfTorchDynamo("Test does not work with TorchDynamo")
|
||||
@@ -622,8 +622,8 @@ def forward(self, a_1):
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
|
||||
return view_copy_default
|
||||
view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
|
||||
return view_copy
|
||||
""")
|
||||
|
||||
def test_everything(self):
|
||||
@@ -648,35 +648,35 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [8])
|
||||
_reshape_alias_copy_default = torch.ops.aten._reshape_alias_copy.default(view_copy_default, [2, 4], [4, 1]); view_copy_default = None
|
||||
transpose_copy_int = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default, 1, 0)
|
||||
unsqueeze_copy_default = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int, 0); transpose_copy_int = None
|
||||
squeeze_copy_default = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default); unsqueeze_copy_default = None
|
||||
split_copy_tensor = torch.ops.aten.split_copy.Tensor(squeeze_copy_default, 2); squeeze_copy_default = None
|
||||
getitem = split_copy_tensor[0]
|
||||
getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
|
||||
add_tensor_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
|
||||
select_copy_int = torch.ops.aten.select_copy.int(_reshape_alias_copy_default, 0, 0); _reshape_alias_copy_default = None
|
||||
clone_default = torch.ops.aten.clone.default(add_tensor_1, memory_format = torch.contiguous_format)
|
||||
_unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None
|
||||
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [8]); add_tensor = None
|
||||
_reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_1, [2, 4], [4, 1]); view_copy_default_1 = None
|
||||
transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default_1, 1, 0); _reshape_alias_copy_default_1 = None
|
||||
unsqueeze_copy_default_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int_1, 0); transpose_copy_int_1 = None
|
||||
squeeze_copy_default_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default_1); unsqueeze_copy_default_1 = None
|
||||
slice_scatter_default = torch.ops.aten.slice_scatter.default(squeeze_copy_default_1, add_tensor_1, 0, 0, 2); squeeze_copy_default_1 = None
|
||||
unsqueeze_copy_default_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter_default, 0); slice_scatter_default = None
|
||||
squeeze_copy_dim = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_default_2, 0); unsqueeze_copy_default_2 = None
|
||||
transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_dim, 1, 0); squeeze_copy_dim = None
|
||||
_reshape_alias_copy_default_2 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_int_2, [8], [1]); transpose_copy_int_2 = None
|
||||
view_copy_default_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_default_2, [4, 2]); _reshape_alias_copy_default_2 = None
|
||||
view_copy_default_3 = torch.ops.aten.view_copy.default(view_copy_default_2, [8]); view_copy_default_2 = None
|
||||
_reshape_alias_copy_default_3 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_3, [2, 4], [4, 1]); view_copy_default_3 = None
|
||||
select_copy_int_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_default_3, 0, 0); _reshape_alias_copy_default_3 = None
|
||||
add_tensor_2 = torch.ops.aten.add.Tensor(select_copy_int_1, _unsafe_view_default); select_copy_int_1 = _unsafe_view_default = None
|
||||
return add_tensor_1
|
||||
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
view_copy = torch.ops.aten.view_copy.default(add, [8])
|
||||
_reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(view_copy, [2, 4], [4, 1]); view_copy = None
|
||||
transpose_copy = torch.ops.aten.transpose_copy.int(_reshape_alias_copy, 1, 0)
|
||||
unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None
|
||||
squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None
|
||||
split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None
|
||||
getitem = split_copy[0]
|
||||
getitem_1 = split_copy[1]; split_copy = None
|
||||
add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
|
||||
select_copy = torch.ops.aten.select_copy.int(_reshape_alias_copy, 0, 0); _reshape_alias_copy = None
|
||||
clone = torch.ops.aten.clone.default(add_1, memory_format = torch.contiguous_format)
|
||||
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
|
||||
view_copy_1 = torch.ops.aten.view_copy.default(add, [8]); add = None
|
||||
_reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None
|
||||
transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_1, 1, 0); _reshape_alias_copy_1 = None
|
||||
unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None
|
||||
squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None
|
||||
slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = None
|
||||
unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None
|
||||
squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None
|
||||
transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None
|
||||
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None
|
||||
view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_2, [4, 2]); _reshape_alias_copy_2 = None
|
||||
view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None
|
||||
_reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None
|
||||
select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_3, 0, 0); _reshape_alias_copy_3 = None
|
||||
add_2 = torch.ops.aten.add.Tensor(select_copy_1, _unsafe_view); select_copy_1 = _unsafe_view = None
|
||||
return add_1
|
||||
""") # noqa: B950
|
||||
|
||||
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=True)
|
||||
@@ -686,33 +686,33 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
view_default = torch.ops.aten.view.default(add_tensor, [8])
|
||||
_reshape_alias_default = torch.ops.aten._reshape_alias.default(view_default, [2, 4], [4, 1]); view_default = None
|
||||
transpose_int = torch.ops.aten.transpose.int(_reshape_alias_default, 1, 0)
|
||||
unsqueeze_default = torch.ops.aten.unsqueeze.default(transpose_int, 0); transpose_int = None
|
||||
squeeze_default = torch.ops.aten.squeeze.default(unsqueeze_default); unsqueeze_default = None
|
||||
split_tensor = torch.ops.aten.split.Tensor(squeeze_default, 2); squeeze_default = None
|
||||
getitem = split_tensor[0]
|
||||
getitem_1 = split_tensor[1]; split_tensor = None
|
||||
add_tensor_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None
|
||||
select_int = torch.ops.aten.select.int(_reshape_alias_default, 0, 0); _reshape_alias_default = None
|
||||
clone_default = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format)
|
||||
_unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None
|
||||
view_default_1 = torch.ops.aten.view.default(add_tensor, [8]); add_tensor = None
|
||||
_reshape_alias_default_1 = torch.ops.aten._reshape_alias.default(view_default_1, [2, 4], [4, 1]); view_default_1 = None
|
||||
transpose_int_1 = torch.ops.aten.transpose.int(_reshape_alias_default_1, 1, 0); _reshape_alias_default_1 = None
|
||||
unsqueeze_default_1 = torch.ops.aten.unsqueeze.default(transpose_int_1, 0); transpose_int_1 = None
|
||||
squeeze_default_1 = torch.ops.aten.squeeze.default(unsqueeze_default_1); unsqueeze_default_1 = None
|
||||
unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(squeeze_default_1, 0); squeeze_default_1 = None
|
||||
squeeze_dim = torch.ops.aten.squeeze.dim(unsqueeze_default_2, 0); unsqueeze_default_2 = None
|
||||
transpose_int_2 = torch.ops.aten.transpose.int(squeeze_dim, 1, 0); squeeze_dim = None
|
||||
_reshape_alias_default_2 = torch.ops.aten._reshape_alias.default(transpose_int_2, [8], [1]); transpose_int_2 = None
|
||||
view_default_2 = torch.ops.aten.view.default(_reshape_alias_default_2, [4, 2]); _reshape_alias_default_2 = None
|
||||
view_default_3 = torch.ops.aten.view.default(view_default_2, [8]); view_default_2 = None
|
||||
_reshape_alias_default_3 = torch.ops.aten._reshape_alias.default(view_default_3, [2, 4], [4, 1]); view_default_3 = None
|
||||
select_int_1 = torch.ops.aten.select.int(_reshape_alias_default_3, 0, 0); _reshape_alias_default_3 = None
|
||||
add_tensor_2 = torch.ops.aten.add.Tensor(select_int_1, _unsafe_view_default); select_int_1 = _unsafe_view_default = None
|
||||
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
view = torch.ops.aten.view.default(add, [8])
|
||||
_reshape_alias = torch.ops.aten._reshape_alias.default(view, [2, 4], [4, 1]); view = None
|
||||
transpose = torch.ops.aten.transpose.int(_reshape_alias, 1, 0)
|
||||
unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None
|
||||
squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None
|
||||
split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None
|
||||
getitem = split[0]
|
||||
getitem_1 = split[1]; split = None
|
||||
add_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None
|
||||
select = torch.ops.aten.select.int(_reshape_alias, 0, 0); _reshape_alias = None
|
||||
clone = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format)
|
||||
_unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None
|
||||
view_1 = torch.ops.aten.view.default(add, [8]); add = None
|
||||
_reshape_alias_1 = torch.ops.aten._reshape_alias.default(view_1, [2, 4], [4, 1]); view_1 = None
|
||||
transpose_1 = torch.ops.aten.transpose.int(_reshape_alias_1, 1, 0); _reshape_alias_1 = None
|
||||
unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None
|
||||
squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None
|
||||
unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None
|
||||
squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None
|
||||
transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None
|
||||
_reshape_alias_2 = torch.ops.aten._reshape_alias.default(transpose_2, [8], [1]); transpose_2 = None
|
||||
view_2 = torch.ops.aten.view.default(_reshape_alias_2, [4, 2]); _reshape_alias_2 = None
|
||||
view_3 = torch.ops.aten.view.default(view_2, [8]); view_2 = None
|
||||
_reshape_alias_3 = torch.ops.aten._reshape_alias.default(view_3, [2, 4], [4, 1]); view_3 = None
|
||||
select_1 = torch.ops.aten.select.int(_reshape_alias_3, 0, 0); _reshape_alias_3 = None
|
||||
add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None
|
||||
return getitem
|
||||
""")
|
||||
|
||||
@@ -731,12 +731,12 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False)
|
||||
view_default = torch.ops.aten.view.default(a_1, [4, 2])
|
||||
add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None
|
||||
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
|
||||
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1)
|
||||
copy__default = torch.ops.aten.copy_.default(a_1, view_default_1); a_1 = view_default_1 = None
|
||||
return add_tensor
|
||||
view = torch.ops.aten.view.default(a_1, [4, 2])
|
||||
add = torch.ops.aten.add.Tensor(view, ones); view = ones = None
|
||||
view_1 = torch.ops.aten.view.default(add, [4, 2])
|
||||
mul = torch.ops.aten.mul.Tensor(view_1, view_1)
|
||||
copy_ = torch.ops.aten.copy_.default(a_1, view_1); a_1 = view_1 = None
|
||||
return add
|
||||
""")
|
||||
|
||||
def test_aliases_maintained_after_pass_when_reapplying_views(self):
|
||||
@@ -781,9 +781,9 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
return add_tensor
|
||||
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
return add
|
||||
""")
|
||||
|
||||
reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
|
||||
@@ -793,9 +793,9 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
return add_tensor
|
||||
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
return add
|
||||
""")
|
||||
|
||||
# Test 2: copy_() with same dtype, different shape
|
||||
@@ -807,10 +807,10 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2])
|
||||
add_tensor = torch.ops.aten.add.Tensor(expand_copy_default, a_1); expand_copy_default = a_1 = None
|
||||
return add_tensor
|
||||
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
expand_copy = torch.ops.aten.expand_copy.default(a_1, [2])
|
||||
add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None
|
||||
return add
|
||||
""")
|
||||
|
||||
reinplaced_logs = self.get_logs(f, torch.ones(1), reapply_views=True, run_reinplace=True)
|
||||
@@ -820,10 +820,10 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2])
|
||||
add_tensor = torch.ops.aten.add_.Tensor(expand_copy_default, a_1); a_1 = None
|
||||
return expand_copy_default
|
||||
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
expand_copy = torch.ops.aten.expand_copy.default(a_1, [2])
|
||||
add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None
|
||||
return expand_copy
|
||||
""")
|
||||
|
||||
# Test 3: copy_() with different dtype, same shape
|
||||
@@ -835,10 +835,10 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
add_tensor = torch.ops.aten.add.Tensor(_to_copy_default, a_1); _to_copy_default = a_1 = None
|
||||
return add_tensor
|
||||
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
add = torch.ops.aten.add.Tensor(_to_copy, a_1); _to_copy = a_1 = None
|
||||
return add
|
||||
""") # noqa: B950
|
||||
|
||||
reinplaced_logs = self.get_logs(f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True)
|
||||
@@ -848,10 +848,10 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
add_tensor = torch.ops.aten.add_.Tensor(_to_copy_default, a_1); a_1 = None
|
||||
return _to_copy_default
|
||||
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
add = torch.ops.aten.add_.Tensor(_to_copy, a_1); a_1 = None
|
||||
return _to_copy
|
||||
""") # noqa: B950
|
||||
|
||||
# Test 4: copy_() with different dtype, different shape
|
||||
@@ -863,11 +863,11 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
expand_copy_default = torch.ops.aten.expand_copy.default(_to_copy_default, [2]); _to_copy_default = None
|
||||
add_tensor = torch.ops.aten.add.Tensor(expand_copy_default, a_1); expand_copy_default = a_1 = None
|
||||
return add_tensor
|
||||
diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
|
||||
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None
|
||||
add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None
|
||||
return add
|
||||
""") # noqa: B950
|
||||
|
||||
reinplaced_logs = self.get_logs(f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True)
|
||||
@@ -877,11 +877,11 @@ def forward(self, a_1):
|
||||
|
||||
def forward(self, a_1):
|
||||
zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False)
|
||||
diagonal_default = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
_to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
expand_copy_default = torch.ops.aten.expand_copy.default(_to_copy_default, [2]); _to_copy_default = None
|
||||
add_tensor = torch.ops.aten.add_.Tensor(expand_copy_default, a_1); a_1 = None
|
||||
return expand_copy_default
|
||||
diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None
|
||||
_to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
|
||||
expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None
|
||||
add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None
|
||||
return expand_copy
|
||||
""") # noqa: B950
|
||||
|
||||
def test_expand_symint(self):
|
||||
@@ -897,8 +897,8 @@ def forward(self, a_1):
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2, 2]); a_1 = None
|
||||
return expand_copy_default
|
||||
expand_copy = torch.ops.aten.expand_copy.default(a_1, [2, 2]); a_1 = None
|
||||
return expand_copy
|
||||
""")
|
||||
|
||||
def test_fill_(self):
|
||||
@@ -915,11 +915,11 @@ def forward(self, a_1):
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(add_tensor)
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(diagonal_copy_default, 0); diagonal_copy_default = None
|
||||
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(add_tensor, fill_scalar); add_tensor = fill_scalar = None
|
||||
return diagonal_scatter_default
|
||||
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
diagonal_copy = torch.ops.aten.diagonal_copy.default(add)
|
||||
fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0); diagonal_copy = None
|
||||
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill); add = fill = None
|
||||
return diagonal_scatter
|
||||
""")
|
||||
|
||||
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=True)
|
||||
@@ -928,10 +928,10 @@ def forward(self, a_1):
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
diagonal_default = torch.ops.aten.diagonal.default(add_tensor)
|
||||
fill_scalar = torch.ops.aten.fill_.Scalar(diagonal_default, 0); diagonal_default = None
|
||||
return add_tensor
|
||||
add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
|
||||
diagonal = torch.ops.aten.diagonal.default(add)
|
||||
fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = None
|
||||
return add
|
||||
""")
|
||||
|
||||
def test_resize_smaller(self):
|
||||
@@ -952,21 +952,21 @@ def forward(self, a_1):
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
|
||||
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [4, 4])
|
||||
resize_default = torch.ops.aten.resize.default(view_copy_default, [3, 3])
|
||||
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(view_copy_default, [3, 3], [3, 1]); view_copy_default = None
|
||||
view_copy_default_1 = torch.ops.aten.view_copy.default(as_strided_copy_default, [-1]); as_strided_copy_default = None
|
||||
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1); view_copy_default_1 = None
|
||||
view_copy_default_2 = torch.ops.aten.view_copy.default(add_tensor, [4, 4]); add_tensor = None
|
||||
as_strided_copy_default_1 = torch.ops.aten.as_strided_copy.default(view_copy_default_2, [3, 3], [3, 1])
|
||||
view_copy_default_3 = torch.ops.aten.view_copy.default(add_tensor_1, [3, 3]); add_tensor_1 = None
|
||||
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(view_copy_default_2, view_copy_default_3, [3, 3], [3, 1]); view_copy_default_2 = view_copy_default_3 = None
|
||||
view_copy_default_4 = torch.ops.aten.view_copy.default(as_strided_scatter_default, [8, 2]); as_strided_scatter_default = None
|
||||
view_copy_default_5 = torch.ops.aten.view_copy.default(view_copy_default_4, [4, 4]); view_copy_default_4 = None
|
||||
as_strided_copy_default_2 = torch.ops.aten.as_strided_copy.default(view_copy_default_5, [3, 3], [3, 1]); view_copy_default_5 = None
|
||||
add_tensor_2 = torch.ops.aten.add.Tensor(as_strided_copy_default_2, 1); as_strided_copy_default_2 = None
|
||||
return add_tensor_2
|
||||
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
|
||||
view_copy = torch.ops.aten.view_copy.default(add, [4, 4])
|
||||
resize = torch.ops.aten.resize.default(view_copy, [3, 3])
|
||||
as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]); view_copy = None
|
||||
view_copy_1 = torch.ops.aten.view_copy.default(as_strided_copy, [-1]); as_strided_copy = None
|
||||
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None
|
||||
view_copy_2 = torch.ops.aten.view_copy.default(add, [4, 4]); add = None
|
||||
as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(view_copy_2, [3, 3], [3, 1])
|
||||
view_copy_3 = torch.ops.aten.view_copy.default(add_1, [3, 3]); add_1 = None
|
||||
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(view_copy_2, view_copy_3, [3, 3], [3, 1]); view_copy_2 = view_copy_3 = None
|
||||
view_copy_4 = torch.ops.aten.view_copy.default(as_strided_scatter, [8, 2]); as_strided_scatter = None
|
||||
view_copy_5 = torch.ops.aten.view_copy.default(view_copy_4, [4, 4]); view_copy_4 = None
|
||||
as_strided_copy_2 = torch.ops.aten.as_strided_copy.default(view_copy_5, [3, 3], [3, 1]); view_copy_5 = None
|
||||
add_2 = torch.ops.aten.add.Tensor(as_strided_copy_2, 1); as_strided_copy_2 = None
|
||||
return add_2
|
||||
""") # noqa: B950
|
||||
|
||||
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
|
||||
@@ -975,20 +975,20 @@ def forward(self, a_1):
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
|
||||
view_default = torch.ops.aten.view.default(add_tensor, [4, 4])
|
||||
resize_default = torch.ops.aten.resize.default(view_default, [3, 3])
|
||||
as_strided_default = torch.ops.aten.as_strided.default(view_default, [3, 3], [3, 1]); view_default = None
|
||||
view_default_1 = torch.ops.aten.view.default(as_strided_default, [-1]); as_strided_default = None
|
||||
add_tensor_1 = torch.ops.aten.add_.Tensor(view_default_1, 1)
|
||||
view_default_2 = torch.ops.aten.view.default(add_tensor, [4, 4]); add_tensor = None
|
||||
as_strided_default_1 = torch.ops.aten.as_strided.default(view_default_2, [3, 3], [3, 1])
|
||||
view_default_3 = torch.ops.aten.view.default(view_default_1, [3, 3]); view_default_1 = None
|
||||
view_default_4 = torch.ops.aten.view.default(view_default_2, [8, 2]); view_default_2 = None
|
||||
view_default_5 = torch.ops.aten.view.default(view_default_4, [4, 4]); view_default_4 = None
|
||||
as_strided_default_2 = torch.ops.aten.as_strided.default(view_default_5, [3, 3], [3, 1]); view_default_5 = None
|
||||
add_tensor_2 = torch.ops.aten.add_.Tensor(as_strided_default_2, 1)
|
||||
return as_strided_default_2
|
||||
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
|
||||
view = torch.ops.aten.view.default(add, [4, 4])
|
||||
resize = torch.ops.aten.resize.default(view, [3, 3])
|
||||
as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]); view = None
|
||||
view_1 = torch.ops.aten.view.default(as_strided, [-1]); as_strided = None
|
||||
add_1 = torch.ops.aten.add_.Tensor(view_1, 1)
|
||||
view_2 = torch.ops.aten.view.default(add, [4, 4]); add = None
|
||||
as_strided_1 = torch.ops.aten.as_strided.default(view_2, [3, 3], [3, 1])
|
||||
view_3 = torch.ops.aten.view.default(view_1, [3, 3]); view_1 = None
|
||||
view_4 = torch.ops.aten.view.default(view_2, [8, 2]); view_2 = None
|
||||
view_5 = torch.ops.aten.view.default(view_4, [4, 4]); view_4 = None
|
||||
as_strided_2 = torch.ops.aten.as_strided.default(view_5, [3, 3], [3, 1]); view_5 = None
|
||||
add_2 = torch.ops.aten.add_.Tensor(as_strided_2, 1)
|
||||
return as_strided_2
|
||||
""")
|
||||
|
||||
def test_resize_larger_valid(self):
|
||||
@ -1015,13 +1015,13 @@ def forward(self, a_1):
|
||||
|
||||
|
||||
def forward(self, a_1):
|
||||
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
|
||||
resize_default = torch.ops.aten.resize.default(add_tensor, [5, 5]); add_tensor = None
|
||||
view_copy_default = torch.ops.aten.view_copy.default(resize_default, [25]); resize_default = None
|
||||
fill_scalar = torch.ops.aten.fill.Scalar(view_copy_default, 1); view_copy_default = None
|
||||
view_copy_default_1 = torch.ops.aten.view_copy.default(fill_scalar, [5, 5]); fill_scalar = None
|
||||
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1)
|
||||
return (view_copy_default_1, add_tensor_1)
|
||||
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
|
||||
resize = torch.ops.aten.resize.default(add, [5, 5]); add = None
|
||||
view_copy = torch.ops.aten.view_copy.default(resize, [25]); resize = None
|
||||
fill = torch.ops.aten.fill.Scalar(view_copy, 1); view_copy = None
|
||||
view_copy_1 = torch.ops.aten.view_copy.default(fill, [5, 5]); fill = None
|
||||
add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1)
|
||||
return (view_copy_1, add_1)
|
||||
""")
|
||||
|
||||
reinplaced_logs = self.get_logs(f, torch.ones(8, 2), reapply_views=True, run_reinplace=True)
|
||||
@ -1030,13 +1030,13 @@ def forward(self, a_1):


def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
resize_default = torch.ops.aten.resize_.default(add_tensor, [5, 5])
view_default = torch.ops.aten.view.default(add_tensor, [25]); add_tensor = None
fill_scalar = torch.ops.aten.fill_.Scalar(view_default, 1)
view_default_1 = torch.ops.aten.view.default(view_default, [5, 5]); view_default = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_default_1, 1)
return (view_default_1, add_tensor_1)
add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
resize = torch.ops.aten.resize_.default(add, [5, 5])
view = torch.ops.aten.view.default(add, [25]); add = None
fill = torch.ops.aten.fill_.Scalar(view, 1)
view_1 = torch.ops.aten.view.default(view, [5, 5]); view = None
add_1 = torch.ops.aten.add.Tensor(view_1, 1)
return (view_1, add_1)
""")

def test_resize_larger_invalid(self):
@ -1115,10 +1115,10 @@ $3 = torch._ops.aten.add.Tensor($2, 1)""")

def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
select_copy_int = torch.ops.aten.select_copy.int(zeros, 0, 5)
fill_scalar = torch.ops.aten.fill.Scalar(select_copy_int, 1); select_copy_int = None
select_scatter_default = torch.ops.aten.select_scatter.default(zeros, fill_scalar, 0, 5); zeros = fill_scalar = None
return select_scatter_default
select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5)
fill = torch.ops.aten.fill.Scalar(select_copy, 1); select_copy = None
select_scatter = torch.ops.aten.select_scatter.default(zeros, fill, 0, 5); zeros = fill = None
return select_scatter
""") # noqa: B950

reinplaced_logs = self.get_logs(f, torch.ones(2), reapply_views=True, run_reinplace=True)
@ -1128,8 +1128,8 @@ def forward(self, a_1):

def forward(self, a_1):
zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False)
select_int = torch.ops.aten.select.int(zeros, 0, 5)
fill_scalar = torch.ops.aten.fill_.Scalar(select_int, 1); select_int = None
select = torch.ops.aten.select.int(zeros, 0, 5)
fill = torch.ops.aten.fill_.Scalar(select, 1); select = None
return zeros
""")


@ -30,9 +30,9 @@ class TestReinplacePass(TestCase):


def forward(self, x_1):
clone_default = torch.ops.aten.clone.default(x_1); x_1 = None
add_tensor = torch.ops.aten.add_.Tensor(clone_default, 1)
return clone_default
clone = torch.ops.aten.clone.default(x_1); x_1 = None
add = torch.ops.aten.add_.Tensor(clone, 1)
return clone
""")


@ -56,11 +56,11 @@ def forward(self, x_1):


def forward(self, x_1):
clone_default = torch.ops.aten.clone.default(x_1); x_1 = None
view_default = torch.ops.aten.view.default(clone_default, [-1])
add_tensor = torch.ops.aten.add.Tensor(clone_default, 1); clone_default = None
add_tensor_1 = torch.ops.aten.add_.Tensor(view_default, 1)
return view_default
clone = torch.ops.aten.clone.default(x_1); x_1 = None
view = torch.ops.aten.view.default(clone, [-1])
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
add_1 = torch.ops.aten.add_.Tensor(view, 1)
return view
""")

def test_reinplace_different_metadata(self):
@ -82,10 +82,10 @@ def forward(self, x_1):


def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
add_tensor = torch.ops.aten.add.Tensor(clone_default, 1)
ge_tensor = torch.ops.aten.ge.Tensor(add_tensor, clone_default); add_tensor = clone_default = None
return ge_tensor
clone = torch.ops.aten.clone.default(a__1); a__1 = None
add = torch.ops.aten.add.Tensor(clone, 1)
ge = torch.ops.aten.ge.Tensor(add, clone); add = clone = None
return ge
""")

def test_reinplace_overlapping_memory(self):
@ -105,10 +105,10 @@ def forward(self, a__1):


def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
expand_default = torch.ops.aten.expand.default(clone_default, [4, 4]); clone_default = None
add_tensor = torch.ops.aten.add.Tensor(expand_default, 1); expand_default = None
return add_tensor
clone = torch.ops.aten.clone.default(a__1); a__1 = None
expand = torch.ops.aten.expand.default(clone, [4, 4]); clone = None
add = torch.ops.aten.add.Tensor(expand, 1); expand = None
return add
""")
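For context on the expected graph above: the reinplacing pass cannot turn the add into add_ here because the expand output aliases the same storage with a zero stride, so an in-place write would touch each element more than once. A small standalone illustration of that property (plain PyTorch, not code from this diff):

import torch

a = torch.ones(4, 1)
b = a.expand(4, 4)          # no copy: all columns alias the single column of a
print(b.stride())           # (1, 0)
try:
    b.add_(1)               # in-place write to overlapping memory is rejected
except RuntimeError as e:
    print(e)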

# This test won't actually run in CI, because it requires functionalize() from functorch.
@ -217,13 +217,13 @@ def forward(self, a__1):


def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add_.Tensor(select_int_1, 1); select_int_1 = None
as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None
return as_strided_default
clone = torch.ops.aten.clone.default(a__1); a__1 = None
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
return as_strided
""")

# Test example where we have a scatter op, where the base tensor
@ -253,15 +253,15 @@ def forward(self, a__1):


def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 1); clone_default = None
select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 0)
copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor); select_int_2 = add_tensor = None
return as_strided_default
clone = torch.ops.aten.clone.default(a__1); a__1 = None
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
select_int = torch.ops.aten.select.int(as_strided, 0, 0)
copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = None
return as_strided
""") # noqa: B950

def test_reinplace_scatter_twice_with_different_view_op_invalid2(self):
@ -286,15 +286,15 @@ def forward(self, a__1):


def forward(self, a__1):
clone_default = torch.ops.aten.clone.default(a__1); a__1 = None
slice_tensor = torch.ops.aten.slice.Tensor(clone_default, 0, 0, 9223372036854775807)
select_int = torch.ops.aten.select.int(slice_tensor, 1, 1); slice_tensor = None
select_int_1 = torch.ops.aten.select.int(select_int, 0, 1); select_int = None
add_tensor = torch.ops.aten.add.Tensor(select_int_1, 1); select_int_1 = None
as_strided_default = torch.ops.aten.as_strided.default(clone_default, [4], [4], 0); clone_default = None
select_int_2 = torch.ops.aten.select.int(as_strided_default, 0, 1)
copy__default = torch.ops.aten.copy_.default(select_int_2, add_tensor); select_int_2 = add_tensor = None
return as_strided_default
clone = torch.ops.aten.clone.default(a__1); a__1 = None
slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0); clone = None
select_int = torch.ops.aten.select.int(as_strided, 0, 1)
copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = None
return as_strided
""") # noqa: B950


@ -327,7 +327,7 @@ class TestPrims(TestCase):

for node in gm.graph.nodes:
if node.op == "call_function":
self.assertTrue(node.name == "add_default")
self.assertTrue(node.name == "add")
self.assertTrue(node.target == torch.ops.nvprims.add.default)
self.assertFalse(node.target == torch.ops.prims.add.default)
self.assertFalse(node.target == torch.ops.aten.add.default)

@ -314,8 +314,8 @@ class TestGenericProxyTensor(TestCase):

def forward(self, x_1):
zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
copy__default = torch.ops.aten.copy_.default(zeros, x_1); zeros = x_1 = None
return copy__default
copy_ = torch.ops.aten.copy_.default(zeros, x_1); zeros = x_1 = None
return copy_
""")

def test_make_fx_reentrant_dispatch(self):
@ -589,13 +589,19 @@ def forward(self, x_1):
)

def test_trace_subclasses(self):
def f(x):
def f1(x):
x = UnwrapTensor(x)
y = x * 2
return y

def f2(x):
wrapped = UnwrapTensor(x)
y = x * wrapped
return y

inp = [torch.randn(5)]
self._test(f, inp)
self._test(f1, inp)
self._test(f2, inp)

def test_partial_decomp(self):
def f(a, b, c):
@ -616,6 +622,19 @@ def forward(self, x_1):
self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)

def test_decomp_of_capture(self):
val = torch.randn(5)

def f(x):
return x.t() + val.t()

def nop(x):
return x.cos()

traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)


@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_amp_cache(self):
layer = torch.nn.Conv2d(3, 3, 3).cuda()
@ -771,9 +790,9 @@ def forward(self, a_1):
mul = sym_size * 2; sym_size = None
empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
sym_size_1 = torch.ops.aten.sym_size(empty, 0)
detach_default = torch.ops.aten.detach.default(empty); empty = None
sym_size_2 = torch.ops.aten.sym_size(detach_default, 0)
return detach_default""")
detach = torch.ops.aten.detach.default(empty); empty = None
sym_size_2 = torch.ops.aten.sym_size(detach, 0)
return detach""")
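The sym_size hunk above comes from a symbolic-shape tracing test; only the detach node's name changes. A rough sketch of the kind of trace involved (the function f is hypothetical, and this assumes the build's make_fx already accepts tracing_mode="symbolic"):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a):
    return torch.empty(a.shape[0] * 2).detach()

g = make_fx(f, tracing_mode="symbolic")(torch.randn(4))
print(g.code)  # records sym_size / empty / detach calls similar to the expected graph above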

def test_cat(self):
def f(a, b):

@ -182,24 +182,29 @@ def fetch_symint_proxy(tracer):
def fetch_tensor_proxy(tracer):
return lambda t: get_proxy_slot(t, tracer, t)

HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)

def proxy_call(proxy_mode, func_overload, args, kwargs=None):
if kwargs is None:
kwargs = {}
def proxy_call(proxy_mode, func, args, kwargs):
def can_handle_tensor(x):
return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)

func = func_overload.overloadpacket
if func_overload in CURRENT_DECOMPOSITION_TABLE:
# If there are any tensor subclasses, we need to handle those tensor subclasses first
# TODO: we could use types to test this
if not pytree.tree_all_only(torch.Tensor, can_handle_tensor, (args, kwargs)):
return NotImplemented

if func in CURRENT_DECOMPOSITION_TABLE:
with proxy_mode.restore():
r = CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
r = CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs)
if r is not NotImplemented:
return r

# Some of these are not "real" aten ops and will fail if we
# call _dispatch_has_kernel_for_dispatch_key on them.
# This list is probably incomplete
if func_overload not in [torch.ops.aten.size.default]:
if func not in [torch.ops.aten.size.default]:
with proxy_mode.restore():
r = func_overload.decompose(*args, **kwargs)
r = func.decompose(*args, **kwargs)
if r is not NotImplemented:
return r
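With this signature change, proxy_call receives the OpOverload itself as func and consults CURRENT_DECOMPOSITION_TABLE (and then func.decompose()) keyed on that overload. A short usage sketch of the table it serves, with an arbitrary decomposition that is not part of this PR:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def silu_decomp(x):
    return x * torch.sigmoid(x)

g = make_fx(
    lambda x: torch.nn.functional.silu(x),
    decomposition_table={torch.ops.aten.silu.default: silu_decomp},  # keyed on the OpOverload
)(torch.randn(4))
print(any(n.target == torch.ops.aten.silu.default for n in g.graph.nodes))  # False: decomposed away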

@ -217,14 +222,14 @@ def proxy_call(proxy_mode, func_overload, args, kwargs=None):
and pytree.tree_all_only(SymInt, lambda _: False, (args, kwargs))
)

if torch.Tag.data_dependent_output in func_overload.tags: # type: ignore[attr-defined]
if torch.Tag.data_dependent_output in func.tags: # type: ignore[attr-defined]
# Check if all of the Tensor inputs are constants
if all_constant:
const_args, const_kwargs = pytree.tree_map_only(
_ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
)
with maybe_disable_fake_tensor_mode():
return func_overload(*const_args, **const_kwargs)
return func(*const_args, **const_kwargs)
raise RuntimeError(
"It appears that you're trying to get value out of a tracing tensor - erroring out! "
"It's likely that this is caused by data-dependent control flow or similar."
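This is the error a user sees when a data-dependent op (e.g. item()) runs on a traced tensor that has no tracked constant value. A minimal repro sketch, assuming make_fx:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x.sum().item()   # aten._local_scalar_dense is tagged data_dependent_output

try:
    make_fx(f)(torch.randn(4))
except RuntimeError as e:
    print(e)   # "...trying to get value out of a tracing tensor - erroring out!..."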
@ -235,21 +240,59 @@ def proxy_call(proxy_mode, func_overload, args, kwargs=None):
fetch_symint_proxy(proxy_mode.tracer),
pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs))
)
proxy_out = func_overload(*proxy_args, **proxy_kwargs)

# When we trace through a torch.tensor invocation, you never actually
# see a torch.ops.aten.tensor call. Instead, the way this function is
# implemented internally is that we allocate a plain tensor (this is
# *guaranteed* to be a plain tensor, we disable all modes when doing
# so), and then call at::lift_fresh on it (to give modes a chance to do
# their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed
# to be freshly allocated, so we want lift_fresh to be a no-op (directly
# returning the input argument).
#
# Here is the basic problem: when we trace this sequence of executions
# into an FX graph, what happens to this call sequence? Traditionally,
# tensor constants get interned as buffers on the FX GraphModule. But
# this is dangerous. Consider:
#
# x = torch.tensor(1)
# x.add_(2)
#
# Naively, this traces into:
#
# t = self._tensor_constant0 # initialized to torch.tensor(1)
# x = torch.ops.aten.lift_fresh(t)
# x.add_(2)
#
# If lift_fresh returns t directly, the subsequent add_ call will
# modify the tensor constant. Really, the problem is we've violated
# the invariant that the argument to lift is fresh. So we should
# preserve the invariant by replacing lift_fresh with lift_fresh_copy:
#
# t = self._tensor_constant0 # initialized to torch.tensor(1)
# x = torch.ops.aten.lift_fresh_copy(t)
# x.add_(2)
#
# This is what the overload modification does.
if func is torch.ops.aten.lift_fresh.default:
func = torch.ops.aten.lift_fresh_copy.default
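# A minimal illustration (not part of this diff) of what the lift_fresh ->
# lift_fresh_copy swap above protects, assuming make_fx:
#
#   import torch
#   from torch.fx.experimental.proxy_tensor import make_fx
#
#   def f():
#       x = torch.tensor(1)   # traced as lift_fresh_copy(_tensor_constant0)
#       x.add_(2)             # mutates the lifted copy, not the interned constant
#       return x
#
#   g = make_fx(f)()
#   assert g._tensor_constant0.item() == 1   # the graph's constant is untouched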

proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))

# This makes DCE marginally less likely to DCE inplace operations.
# It is not strictly necessary
# Kind of a hacky way to test if an op is in-place or not
if func.__name__[-1] == "_" and func.__name__[0] != "_":
if func.overloadpacket.__name__[-1] == "_" and func.overloadpacket.__name__[0] != "_":
if isinstance(args[0], List):
# e.g., c10d::allreduce_ returns a list of tensors as the first element
# in the output.
for i, a in enumerate(args[0]):
a.proxy = proxy_out[0][i]
else:
# This makes DCE marginally less likely to DCE inplace operations.
# It is not strictly necessary
args[0].proxy = proxy_out

out = func_overload(*args, **kwargs)
out = func(*args, **kwargs)

# In some circumstances, we will be tracing in a situation where a tensor
# is *statically* known to be a constant (currently, this only happens if
@ -275,18 +318,27 @@ def proxy_call(proxy_mode, func_overload, args, kwargs=None):
any_constant = pytree.tree_any_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))

constant = None
# NB: do NOT include factories as constants
if (
torch.Tag.nondeterministic_seeded not in func_overload.tags # type: ignore[attr-defined]

# If this is a lift, the input tensor is guaranteed to be a
# constant, so we keep a copy of the original argument along so
# we can query it if we're asked to item() it at some later point
if func is torch.ops.aten.lift_fresh_copy.default and out.numel() <= CONSTANT_NUMEL_LIMIT:
with maybe_disable_fake_tensor_mode():
constant = args[0].clone()
elif (
torch.Tag.nondeterministic_seeded not in func.tags # type: ignore[attr-defined]
and all_constant
and any_constant
and pytree.tree_all_only(torch.Tensor, lambda t: t.numel() <= CONSTANT_NUMEL_LIMIT, out)
):
# NB: do NOT include factories as constants
with maybe_disable_fake_tensor_mode():
const_args, const_kwargs = pytree.tree_map_only(
_ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
)
constant = func_overload(*const_args, **const_kwargs)
constant = func(*const_args, **const_kwargs)
else:
constant = None

track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
return out
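The constant tracking above is what lets data-dependent calls succeed on small tensor constants during tracing (the counterpart of the error branch earlier in proxy_call). A hedged sketch, assuming make_fx and the CONSTANT_NUMEL_LIMIT behavior shown here:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    scale = torch.tensor(2.0)   # lifted via lift_fresh_copy; numel() == 1, so its value is tracked
    return x * scale.item()     # item() is answered from the tracked constant instead of erroring

g = make_fx(f)(torch.randn(3))
print(g.code)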
@ -367,9 +419,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
self.sym_mode = ProxySymDispatchMode(tracer)
self.trace_state = {}

def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
with self.sym_mode.enable(False):
return self.inner_torch_dispatch(func_overload, types, args, kwargs)
return self.inner_torch_dispatch(func, types, args, kwargs)

@contextmanager
def restore(self):
@ -377,90 +429,18 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
with super().restore():
yield

def inner_torch_dispatch(self, func_overload, types, args=(), kwargs=None):
def inner_torch_dispatch(self, func, types, args=(), kwargs=None):
if not self.enable_tracing:
return func_overload(*args, **kwargs)
return func(*args, **kwargs)

if symbolic_shapes.is_symbolic_op(func_overload):
if symbolic_shapes.is_symbolic_op(func):
with self.restore():
return symbolic_shapes.handle_symbolic_op(func_overload, args, kwargs)
return symbolic_shapes.handle_symbolic_op(func, args, kwargs)

func = func_overload.overloadpacket
# We don't want to convert torch.tensor constants into tracing objects.
if func_overload == aten.lift.default:
return args[0]
if func in [prim.device.default]:
return func(*args, **kwargs)

if func in [prim.device]:
return func_overload(*args, **kwargs)

if pytree.tree_any_only(
torch.Tensor,
lambda t: has_proxy_slot(t, self.tracer),
(args, kwargs)
):
out = proxy_call(self, func_overload, args, kwargs)
# When we trace through a torch.tensor invocation, you never actually
# see a torch.ops.aten.tensor call. Instead, the way this function is
# implemented internally is that we allocate a plain tensor (this is
# *guaranteed* to be a plain tensor, we disable all modes when doing
# so), and then call at::lift_fresh on it (to give modes a chance to do
# their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed
# to be freshly allocated, so we want lift_fresh to be a no-op (directly
# returning the input argument).
#
# Here is the basic problem: when we trace this sequence of executions
# into an FX graph, what happens to this call sequence? Traditionally,
# tensor constants get interned as buffers on the FX GraphModule. But
# this is dangerous. Consider:
#
# x = torch.tensor(1)
# x.add_(2)
#
# Naively, this traces into:
#
# t = self._tensor_constant0 # initialized to torch.tensor(1)
# x = torch.ops.aten.lift_fresh(t)
# x.add_(2)
#
# If lift_fresh returns t directly, the subsequent add_ call will
# modify the tensor constant. Really, the problem is we've violated
# the invariant that the argument to lift is fresh. So we should
# preserve the invariant by replacing lift_fresh with lift_fresh_copy:
#
# t = self._tensor_constant0 # initialized to torch.tensor(1)
# x = torch.ops.aten.lift_fresh_copy(t)
# x.add_(2)
#
# This is what the overload modification does.
else:
flat_args = pytree.tree_flatten((args, kwargs))[0]
handled_types = [torch.Tensor, _ProxyTensor, torch.nn.Parameter]

# If there are any tensor subclasses, we need to handle those tensor subclasses first
# TODO: we could use types to test this
if any(isinstance(arg, torch.Tensor) and type(arg) not in handled_types for arg in flat_args):
return NotImplemented

if func_overload is torch.ops.aten.lift_fresh.default:
func_overload = torch.ops.aten.lift_fresh_copy.default

n_args, n_kwargs = pytree.tree_map_only(SymInt, fetch_symint_proxy(self.tracer), (args, kwargs))

proxy_out = self.tracer.create_proxy('call_function', func_overload, n_args, n_kwargs,
name=self.tracer.graph._target_to_str(func.__name__))

out = func_overload(*args, **kwargs)

# If this is a lift, the input tensor is guaranteed to be a
# constant, so we keep a copy of the original argument along so
# we can query it if we're asked to item() it at some later point
is_lift = func_overload is torch.ops.aten.lift_fresh_copy.default
if is_lift and out.numel() <= CONSTANT_NUMEL_LIMIT:
with maybe_disable_fake_tensor_mode():
constant = args[0].clone()
else:
constant = None
track_tensor_tree(out, proxy_out, constant=constant, tracer=self.tracer)
out = proxy_call(self, func, args, kwargs)

def assert_proxy_tensor(e):
assert has_proxy_slot(e, self.tracer), \