mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
1 Commits
v0.5.5
...
torch_dyna
Author | SHA1 | Date | |
---|---|---|---|
617fb893d5 |
@ -787,6 +787,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
"provided. Defaulting to scaling factors of 1.0. "
|
"provided. Defaulting to scaling factors of 1.0. "
|
||||||
"This may lead to less accurate results!")
|
"This may lead to less accurate results!")
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
def backend(gm, input):
|
||||||
|
nonlocal count
|
||||||
|
count += 1
|
||||||
|
print(count)
|
||||||
|
return gm.forward
|
||||||
|
|
||||||
|
self.model = torch.compile(self.model, backend=backend, fullgraph=True)
|
||||||
|
|
||||||
def save_sharded_state(
|
def save_sharded_state(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
|
Reference in New Issue
Block a user