add compile
This commit is contained in:
@ -787,6 +787,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
"provided. Defaulting to scaling factors of 1.0. "
|
||||
"This may lead to less accurate results!")
|
||||
|
||||
count = 0
|
||||
|
||||
def backend(gm, input):
|
||||
nonlocal count
|
||||
count += 1
|
||||
print(count)
|
||||
return gm.forward
|
||||
|
||||
self.model = torch.compile(self.model, backend=backend, fullgraph=True)
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
|
Reference in New Issue
Block a user