mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
skip non memory deps in memory estimator (#164294)
Differential Revision: [D83601030](https://our.internmc.facebook.com/intern/diff/D83601030) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164294 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
10a005e87f
commit
e0f118585f
@@ -290,7 +290,8 @@ class MemoryEstimator:
         )

         for dep in rw._reads:
-            assert isinstance(dep, MemoryDep)
+            if not isinstance(dep, MemoryDep):
+                continue
             dep = dep.simplify_with_ranges()
             if not self.persistent.writes.get(dep.name):  # cache miss?
                 self.persistent.reads[dep.name].add(dep)
@@ -308,7 +309,8 @@ class MemoryEstimator:
             self.must_keep_buffers.add(dep.name)

         for dep in rw._writes:
-            assert isinstance(dep, MemoryDep)
+            if not isinstance(dep, MemoryDep):
+                continue
             dep = dep.simplify_with_ranges()
             self.store_buffer_names.add(dep.name)
             self.persistent.writes[dep.name].add(dep)
Reference in New Issue
Block a user