Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[PT2] Resolve PT2 compatibility issue in slice and diff (#133740)
Summary:

# context

* When running an IG FM training with PT2, we found a few graph breaks due to the torch.diff call in [jagged_tensor.py](https://fburl.com/code/cwssxabc):

```
_length: List[int] = (
    _length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key)
    if variable_stride_per_key
    else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist()
)
```

* Looking into the failure, we found that the TORCH_CHECK in diff should be TORCH_SYM_CHECK.
* slice_forward error: df3d7729e, [tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpxXZ2em/index.html)

```
RestartAnalysis
Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time. You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs.
Could not guard on data-dependent expression ((5*u37 + u38)//(u37 + u38)) < 0 (unhinted: ((5*u37 + u38)//(u37 + u38)) < 0). (Size-like symbols: u38, u37)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False. Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Potential framework code culprit (scroll up for full backtrace):
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_decomp/decompositions.py", line 771, in slice_forward
    if end_val < 0:
```

* After this diff: [tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpAhv2Sh/failures_and_restarts.html)

Test Plan:

# command

* Run the model:

```
TORCH_SHOW_CPP_STACKTRACES=1 TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 TORCH_LOGS="+graph_code,output_code,dynamic,aot,guards,verbose_guards,recompiles,graph_breaks" TORCH_TRACE=/var/tmp/tt buck2 run fbcode//mode/opt fbcode//aps_models/ads/icvr:icvr_launcher_live -- mode=fmc/local_ig_fm_v4_mini training.pipeline_type=pt2
```

* Generate the tlparse report:

```
tlparse `ls -t /var/tmp/tt/* | head -1`
```

Reviewed By: ezyang

Differential Revision: D56339251

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133740
Approved by: https://github.com/ezyang
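For context, the quoted jagged_tensor.py path computes per-key lengths from jagged offsets. The short eager sketch below reproduces the non-variable-stride branch; the offsets and stride values are made-up example data, not taken from the model:

```python
import torch

# Made-up jagged offsets: 6 segments packed back to back, 2 segments per key.
offsets = torch.tensor([0, 2, 5, 9, 12, 14, 20])
stride = 2

# torch.diff(offsets) yields per-segment lengths; grouping by `stride` and
# summing gives per-key lengths, matching the non-variable-stride branch above.
_length = torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist()
print(_length)  # [5, 7, 8]
```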
Committed by: PyTorch MergeBot
Parent: cd89bf77c8
Commit: d5f6d68d68
```diff
@@ -892,8 +892,11 @@ static inline void diff_check_compatible_shape(const Tensor& self, const std::op
         "diff expects prepend or append to be the same dimension as input");
 
     for (const auto i : c10::irange(other.value().dim())) {
-      TORCH_CHECK(
-          other.value().sym_size(i) == self.sym_size(i) || i == wrapped_dim,
+      if (i == wrapped_dim) {
+        continue;
+      }
+      TORCH_SYM_CHECK(
+          other.value().sym_size(i).sym_eq(self.sym_size(i)),
           "diff expects the shape of tensor to prepend or append to match that of"
           " input except along the differencing dimension;"
           " input.size(", i, ") = ", self.sym_size(i), ", but got"
```
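The hunk above swaps a TORCH_CHECK on a symbolic size comparison for TORCH_SYM_CHECK, which lets the SymInt comparison be deferred rather than evaluated to a concrete boolean at trace time. A rough Python-side analogue, assuming only the public `torch._check` API; the `check_diff_compatible` helper and its messages are illustrative, not part of the commit:

```python
import torch

def check_diff_compatible(self: torch.Tensor, other: torch.Tensor, wrapped_dim: int) -> None:
    # Illustrative only: mirror the C++ loop, skipping the differencing dim and
    # recording the size equality as a checkable condition instead of branching
    # on it, similar in spirit to TORCH_SYM_CHECK vs TORCH_CHECK.
    for i in range(other.dim()):
        if i == wrapped_dim:
            continue
        torch._check(
            other.size(i) == self.size(i),
            lambda i=i: f"diff expects prepend/append sizes to match input "
                        f"except along dim {wrapped_dim}; mismatch at dim {i}",
        )

# Eager usage: matching shapes pass silently, a mismatch raises.
check_diff_compatible(torch.zeros(3, 4), torch.zeros(3, 1), wrapped_dim=1)
```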
```diff
@@ -765,18 +765,18 @@ def slice_forward(
     start_val = start if start is not None else 0
     end_val = end if end is not None else sys.maxsize  # 2^63 - 1
 
-    if start_val < 0:
+    if guard_size_oblivious(start_val < 0):
         start_val += sizes[dim]
 
-    if end_val < 0:
+    if guard_size_oblivious(end_val < 0):
         end_val += sizes[dim]
 
-    if start_val < 0:
+    if guard_size_oblivious(start_val < 0):
         start_val = 0
-    elif start_val > sizes[dim]:
+    elif guard_size_oblivious(start_val > sizes[dim]):
         start_val = sizes[dim]
 
-    if end_val < start_val:
+    if guard_size_oblivious(end_val < start_val):
         end_val = start_val
     elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(
         end_val > sizes[dim]
```
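As a runnable illustration of the clamping logic after this change, the sketch below assumes only that `guard_size_oblivious` and `statically_known_true` act as pass-throughs on plain Python booleans (which lets it run eagerly); the `clamp_slice_bounds` helper is illustrative, not part of the commit. Under PT2 tracing, the same calls let these branches be resolved for size-like unbacked symbols without emitting data-dependent guards, which is what the error message in the summary asks for.

```python
import sys

from torch.fx.experimental.symbolic_shapes import (
    guard_size_oblivious,
    statically_known_true,
)

def clamp_slice_bounds(start_val, end_val, dim_size):
    # Same normalization as the decomposition above: wrap negative bounds,
    # clamp to [0, dim_size], and treat end == sys.maxsize as "slice to the end".
    if guard_size_oblivious(start_val < 0):
        start_val += dim_size
    if guard_size_oblivious(end_val < 0):
        end_val += dim_size
    if guard_size_oblivious(start_val < 0):
        start_val = 0
    elif guard_size_oblivious(start_val > dim_size):
        start_val = dim_size
    if guard_size_oblivious(end_val < start_val):
        end_val = start_val
    elif statically_known_true(end_val == sys.maxsize) or guard_size_oblivious(
        end_val > dim_size
    ):
        end_val = dim_size
    return start_val, end_val

# Eager sanity check with concrete ints (x[-3:] on a length-8 dim -> bounds [5, 8)):
print(clamp_slice_bounds(-3, sys.maxsize, 8))  # (5, 8)
```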