skip non memory deps in memory estimator (#164294)

Differential Revision: [D83601030](https://our.internmc.facebook.com/intern/diff/D83601030)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164294
Approved by: https://github.com/mlazos
This commit is contained in:
eellison
2025-09-30 14:05:38 -07:00
committed by PyTorch MergeBot
parent 10a005e87f
commit e0f118585f

View File

@ -290,7 +290,8 @@ class MemoryEstimator:
)
for dep in rw._reads:
assert isinstance(dep, MemoryDep)
if not isinstance(dep, MemoryDep):
continue
dep = dep.simplify_with_ranges()
if not self.persistent.writes.get(dep.name): # cache miss?
self.persistent.reads[dep.name].add(dep)
@ -308,7 +309,8 @@ class MemoryEstimator:
self.must_keep_buffers.add(dep.name)
for dep in rw._writes:
assert isinstance(dep, MemoryDep)
if not isinstance(dep, MemoryDep):
continue
dep = dep.simplify_with_ranges()
self.store_buffer_names.add(dep.name)
self.persistent.writes[dep.name].add(dep)