Compare commits

...

15 Commits

SHA1 Message Date
9c4b4efbbd try 1 2024-11-23 22:42:07 +01:00
4c86fb37ff try 1 2024-11-23 22:32:34 +01:00
0baed61a93 try 1 2024-11-23 21:50:07 +01:00
ca8e79ffc7 try 1 2024-11-23 20:59:55 +01:00
2681146a3b try 1 2024-11-23 20:44:09 +01:00
4652d9c364 try 1 2024-11-23 20:43:05 +01:00
2b961c724a try 1 2024-11-23 20:33:25 +01:00
51e8a980bc try 1 2024-11-23 20:10:48 +01:00
169e6c1cf9 try 1 2024-11-23 20:04:44 +01:00
2d885f574e try 1 2024-11-23 19:57:47 +01:00
514cd4170e try 1 2024-11-23 19:46:50 +01:00
d9b44a9f1a try 1 2024-11-23 19:33:34 +01:00
c74edfdd4c try 1 2024-11-23 19:28:17 +01:00
132b635a10 try 1 2024-11-23 19:06:17 +01:00
09cc04a36a try 1 2024-11-23 18:06:42 +01:00


@@ -117,6 +117,41 @@ if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module

o = dict()
import queue
q = queue.Queue()  # main thread -> worker: (shared dict, model, model_inputs, put_output)
p = queue.Queue()  # worker -> main thread: the shared dict with 'outputs' filled in

def model_forward_2(model, *args, **kwargs):
    with torch.no_grad():
        return model.forward(*args, **kwargs)

my_model = None  # unused placeholder

def foo():
    # Background worker: compiles the forward once, then serves compiled calls.
    while True:
        item = q.get()
        o, model, model_inputs, put_output = item
        if o['model_forward'] is None:
            # if isinstance(model_kwargs.get("past_key_values"), StaticCache):
            if model.device.type == "cuda":
                logger.warning_once("Using `torch.compile`.")
                os.environ["TOKENIZERS_PARALLELISM"] = "0"
            model_forward_3 = torch.compile(model_forward_2, mode="reduce-overhead", fullgraph=True)
            # use the model passed through the queue (`my_model` is never assigned)
            outputs = model_forward_3(model, return_dict=True, **model_inputs)
            o['model_forward'] = model_forward_3
        else:
            outputs = o['model_forward'](model, return_dict=True, **model_inputs)
        o['outputs'] = outputs
        # only put if necessary!
        if put_output:
            p.put(o)

@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
    """
@@ -3225,14 +3260,17 @@ class GenerationMixin:
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        def model_forward(model, *args, **kwargs):
            return model.forward(*args, **kwargs)

        # o = dict()
        if 'model_forward' not in o:
            o['model_forward'] = None
        if isinstance(model_kwargs.get("past_key_values"), StaticCache):
            if self.device.type == "cuda":
                logger.warning_once("Using `torch.compile`.")
                os.environ["TOKENIZERS_PARALLELISM"] = "0"
                model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
        # q.task_done()
        import threading
        # start the background compile worker; daemon=True so it cannot block interpreter shutdown
        t = threading.Thread(target=foo, daemon=True)
        t.start()
        i = 0
        while self._has_unfinished_sequences(
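
The StaticCache/CUDA guard above mirrors the stock transformers fast path; only the threading lines are new. In isolation, the compile call it reuses looks like this (toy module for illustration; runnable on CPU too, where the first call is slow and later calls are fast):

import torch

lin = torch.nn.Linear(4, 4)

def model_forward(model, *args, **kwargs):
    return model.forward(*args, **kwargs)

# mode="reduce-overhead" targets small graphs launched over and over (CUDA
# graphs on GPU); fullgraph=True raises on graph breaks instead of silently
# falling back to eager execution.
compiled = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)

x = torch.randn(2, 4)
print(compiled(lin, x).shape)  # torch.Size([2, 4])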
@@ -3246,10 +3284,40 @@ class GenerationMixin:
            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

            if i == 0:
                already_compiled = False
                if o['model_forward'] is not None:
                    # a previous generate() call already compiled the forward
                    already_compiled = True
                outputs = self(**model_inputs, return_dict=True)
                i += 1
            else:
                # outputs = model_forward(self, return_dict=True, **model_inputs)
                if not already_compiled:
                    # compile is still warming up: hand the job to the worker
                    # without waiting, and keep decoding eagerly on this thread
                    q.put((o, self, model_inputs, False))
                    # use self
                    outputs = self(**model_inputs, return_dict=True)
                else:
                    # directly call (??)
                    # outputs = o['model_forward'](self, return_dict=True, **model_inputs)
                    q.put((o, self, model_inputs, True))
                    item = p.get()  # block until the worker returns the compiled outputs
                    outputs = item['outputs']
                # if i == 1 and o['model_forward'] is None:
                #     # don't join
                #     # just compile
                #     q.put((o, self, model_inputs))
                # # when compiled is done
                # if o['model_forward'] is not None:
                #     import datetime
                #     s = datetime.datetime.now()
                #     q.put((o, self, model_inputs))
                #     item = p.get()
                #     outputs = item['outputs']
                #     d = (datetime.datetime.now() - s).total_seconds()
                #     # print(d)
                # else:
                #     outputs = self(**model_inputs, return_dict=True)
                i += 1

            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
            model_kwargs = self._update_model_kwargs_for_generation(
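
Putting the three hunks together, the apparent intent of this try is: the first generate() call decodes eagerly end to end while the worker compiles in the background, and subsequent calls, which find o['model_forward'] already populated, fetch compiled outputs from the worker after the first token. A hedged usage sketch under that reading (checkpoint name is a placeholder; assumes a CUDA device and a transformers checkout with this patch applied):

import time
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
inputs = tok("Hello", return_tensors="pt").to("cuda")

# Cold call: eager decoding; compilation happens off the main thread.
t0 = time.perf_counter()
model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
print("cold:", time.perf_counter() - t0)

# Warm call: o['model_forward'] is set, so decode steps after the first
# are served by the worker's compiled forward.
t0 = time.perf_counter()
model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
print("warm:", time.perf_counter() - t0)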