Compare commits

...

1 Commits

Author SHA1 Message Date
617fb893d5 add compile 2024-07-26 19:29:36 -07:00

View File

@ -787,6 +787,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
count = 0
def backend(gm, input):
nonlocal count
count += 1
print(count)
return gm.forward
self.model = torch.compile(self.model, backend=backend, fullgraph=True)
def save_sharded_state(
self,
path: str,