
STARTING_BATCH_SIZE = 512

+# Thread local param
+torch.set_grad_enabled(False)
+

class TransformersModelConfig(ModelConfig):
    """Configuration class for HuggingFace Transformers models.
@@ -218,12 +221,6 @@ def __init__(
        if config.model_parallel is False and self.config.dtype not in ["4bit", "8bit"]:
            logger.info(f"Using Data Parallelism, putting model on device {self._device}")
            self.model = self.model.to(self._device)
-        if config.compile:
-            try:
-                logger.info("Compiling the model")
-                self.model.model.compile()
-            except AttributeError as e:
-                logger.warning("Could not compile the model because: ", e)

        self.model_name = _simplify_name(config.model_name)

@@ -410,7 +407,7 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
        )
        # model.to(self.device)
        model.eval()
-        torch.set_grad_enabled(False)
+
        if self.continuous_batching:
            generation_config = GenerationConfig(
                **self.generation_config_dict,
@@ -497,9 +494,6 @@ def _check_continuations_start_space(self, continuation: str) -> str:
                continuation = continuation.lstrip()
        return continuation

-    def _model_call(self, inputs: torch.Tensor) -> torch.Tensor:
-        return self.model(inputs).logits
-
    def _get_batch_size(self, max_input_length: int, override_bs: int | None, starting_batch_size: int = 512) -> int:
        if override_bs is not None:
            return override_bs
@@ -509,10 +503,18 @@ def _get_batch_size(self, max_input_length: int, override_bs: int | None, starti
            starting_batch_size=starting_batch_size
        )  # if OOM, then halves batch_size and tries again
        def forward_batch(batch_size):
-            test_batch = torch.ones(
-                (batch_size + int(0.1 * batch_size), max_input_length), device=self.device
-            ).long()  # We add 10% for marging :)
-            F.log_softmax(self._model_call(test_batch).float(), dim=-1).cpu()
+            fake_batch, fake_output = None, None
+            with torch.no_grad():
+                try:
+                    fake_batch = torch.ones((batch_size, max_input_length), device=self.device).int()
+                    fake_output = F.log_softmax(self.model(fake_batch).logits, dim=-1).cpu()
+                except Exception as e:
+                    for fake_item in [fake_batch, fake_output]:
+                        if fake_item is not None:
+                            fake_item.detach()
+                        del fake_item
+
+                    raise e
            return batch_size

        batch_size = forward_batch()
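
For context on the probe above: the "# if OOM, then halves batch_size and tries again" comment describes Accelerate's find_executable_batch_size decorator, which re-invokes the wrapped function with a halved batch size whenever it raises a CUDA out-of-memory error. A minimal sketch of the pattern outside this class, assuming a generic model callable that returns an object with a .logits attribute:

import torch
import torch.nn.functional as F
from accelerate.utils import find_executable_batch_size

def probe_batch_size(model, max_input_length: int, device: str, starting_batch_size: int = 512) -> int:
    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def forward_batch(batch_size):
        # Dummy forward pass at full sequence length: an OOM here makes the
        # decorator halve batch_size and call forward_batch again.
        with torch.no_grad():
            fake_batch = torch.ones((batch_size, max_input_length), device=device, dtype=torch.long)
            F.log_softmax(model(fake_batch).logits, dim=-1).cpu()
        return batch_size

    # Called with no argument: the decorator injects the current batch_size.
    return forward_batch()
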
@@ -645,10 +647,14 @@ def _padded_greedy_until(
            position=0,
            disable=self.disable_tqdm,
        ):
-            if split[0].generation_size is None:
+            if self.generation_config_dict.get("max_new_tokens", None) is not None:
+                # The user forces a specific generation size
+                max_context_continuation_size_allowed = self.generation_config_dict["max_new_tokens"]
+            elif split[0].generation_size is None:
                # No constraints on the generation size: max length allowed is the max model context
                max_context_continuation_size_allowed = self.max_length
            else:
+                # The task forces a specific generation size
                context = self.prompt_manager.prepare_prompt(split[0])
                tokenized_context = self.tokenizer(context)

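
The reordered branches above give a user-supplied max_new_tokens (from the generation config) precedence over the per-task generation_size, with the model's full context length as the fallback. A simplified standalone sketch of that precedence, using illustrative names; note the real else branch additionally tokenizes the prompt (the hunk is truncated there):

def resolve_generation_budget(generation_config: dict, task_generation_size: int | None, model_max_length: int) -> int:
    """Pick the generation budget: user config first, then task setting, then model context."""
    if generation_config.get("max_new_tokens") is not None:
        # The user forces a specific generation size
        return generation_config["max_new_tokens"]
    if task_generation_size is None:
        # No constraint on the generation size: use the full model context
        return model_max_length
    # The task forces a specific generation size
    return task_generation_size

assert resolve_generation_budget({"max_new_tokens": 256}, 64, 4096) == 256
assert resolve_generation_budget({}, None, 4096) == 4096
assert resolve_generation_budget({}, 64, 4096) == 64
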
@@ -953,7 +959,7 @@ def _loglikelihood_tokens( # noqa: C901
                max_context=None,  # computed as model max length in the function
            )

-            model_output = self._model_call(prepared_batch.input_ids)
+            model_output = self.model(prepared_batch.input_ids).logits
            logits = F.log_softmax(model_output, dim=-1)  # [batch, sequence_length, vocab]
            flat_index = 0

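In the last hunk the removed _model_call helper is inlined: logits come directly from self.model(...).logits and are normalized with F.log_softmax into per-token log-probabilities. A self-contained sketch (illustrative names, not the library's API) of how such log-probabilities score a continuation:

import torch
import torch.nn.functional as F

def continuation_logprob(model, input_ids: torch.Tensor, continuation_len: int) -> torch.Tensor:
    """Sum the log-probabilities of the last `continuation_len` tokens of `input_ids`."""
    with torch.no_grad():
        logits = model(input_ids).logits              # [batch, sequence_length, vocab]
        logprobs = F.log_softmax(logits, dim=-1)      # normalize over the vocabulary
    # The token at position i is predicted by the logits at position i - 1, hence the shift.
    target_ids = input_ids[:, -continuation_len:]                 # [batch, continuation_len]
    shifted = logprobs[:, -continuation_len - 1:-1, :]            # [batch, continuation_len, vocab]
    token_logprobs = torch.gather(shifted, 2, target_ids.unsqueeze(-1)).squeeze(-1)
    return token_logprobs.sum(dim=-1)                 # [batch]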