add compile

This commit is contained in:
youkaichao
2024-07-26 19:29:36 -07:00
parent 55712941e5
commit 617fb893d5

View File

@ -787,6 +787,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
count = 0
def backend(gm, input):
nonlocal count
count += 1
print(count)
return gm.forward
self.model = torch.compile(self.model, backend=backend, fullgraph=True)
def save_sharded_state(
self,
path: str,