9292from .constants import ONNX_DECODER_MERGED_NAME , ONNX_DECODER_NAME , ONNX_DECODER_WITH_PAST_NAME
9393from .model_patcher import (
9494 CLIPModelPatcher ,
95- FalconModelPatcher ,
9695 MgpstrModelPatcher ,
97- MistralModelPatcher ,
9896 MusicgenModelPatcher ,
9997 Qwen3MoeModelPatcher ,
10098 SAMModelPatcher ,
@@ -409,20 +407,12 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
409407 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
410408
411409
412- # OPT does not take position_ids as input for transfomers < v4.46, needs it for transformers >= v4.46
413- if is_transformers_version (">=" , "4.46.0" ):
414-
415- @register_tasks_manager_onnx ("opt" , * COMMON_TEXT_GENERATION_TASKS + ["text-classification" , "question-answering" ])
416- class OPTOnnxConfig (TextDecoderWithPositionIdsOnnxConfig ):
417- DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
418- NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
419-
420- else :
421-
422- @register_tasks_manager_onnx ("opt" , * COMMON_TEXT_GENERATION_TASKS + ["text-classification" , "question-answering" ])
423- class OPTOnnxConfig (TextDecoderOnnxConfig ):
424- DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
425- NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
410+ @register_tasks_manager_onnx ("opt" , * COMMON_TEXT_GENERATION_TASKS + ["text-classification" , "question-answering" ])
411+ class OPTOnnxConfig (
412+ TextDecoderWithPositionIdsOnnxConfig if is_transformers_version (">=" , "4.46.0" ) else TextDecoderOnnxConfig
413+ ):
414+ DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
415+ NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
426416
427417
428418@register_tasks_manager_onnx ("llama" , * COMMON_TEXT_GENERATION_TASKS + ["text-classification" ])
@@ -477,7 +467,6 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
477467@register_tasks_manager_onnx ("granite" , * COMMON_TEXT_GENERATION_TASKS )
478468class GraniteOnnxConfig (LlamaOnnxConfig ):
479469 MIN_TRANSFORMERS_VERSION = version .parse ("4.45.0" )
480- MIN_TORCH_VERSION = version .parse ("2.5.0" )
481470
482471
483472@register_tasks_manager_onnx ("phi" , * COMMON_TEXT_GENERATION_TASKS + ["text-classification" ])
@@ -502,17 +491,11 @@ class InternLM2OnnxConfig(LlamaOnnxConfig):
502491
503492@register_tasks_manager_onnx ("mistral" , * COMMON_TEXT_GENERATION_TASKS + ["text-classification" ])
504493class MistralOnnxConfig (TextDecoderWithPositionIdsOnnxConfig ):
505- # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
506- MIN_TRANSFORMERS_VERSION = version .parse ("4.35.0" )
507-
508494 # The ONNX export of this architecture needs the Trilu operator support, available since opset 14
509495 DEFAULT_ONNX_OPSET = 14
510- DUMMY_INPUT_GENERATOR_CLASSES = (
511- MistralDummyPastKeyValuesGenerator ,
512- ) + TextDecoderOnnxConfig .DUMMY_INPUT_GENERATOR_CLASSES
513496 DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
497+ DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator , MistralDummyPastKeyValuesGenerator )
514498 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig .with_args (num_key_value_heads = "num_key_value_heads" , allow_new = True )
515- _MODEL_PATCHER = MistralModelPatcher
516499
517500
518501@register_tasks_manager_onnx ("mpt" , * COMMON_TEXT_GENERATION_TASKS + ["text-classification" , "token-classification" ])
@@ -556,9 +539,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
556539 "gpt_bigcode" , * COMMON_TEXT_GENERATION_TASKS + ["text-classification" , "token-classification" ]
557540)
558541class GPTBigCodeOnnxConfig (TextDecoderWithPositionIdsOnnxConfig ):
559- DUMMY_INPUT_GENERATOR_CLASSES = (
560- GPTBigCodeDummyPastKeyValuesGenerator ,
561- ) + TextDecoderOnnxConfig .DUMMY_INPUT_GENERATOR_CLASSES
542+ DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator , GPTBigCodeDummyPastKeyValuesGenerator )
562543 DEFAULT_ONNX_OPSET = 14 # GPT BigCode now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
563544 DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator
564545 NORMALIZED_CONFIG_CLASS = NormalizedConfigManager .get_normalized_config_class ("gpt_bigcode" )
@@ -571,36 +552,29 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
571552 decoder_sequence_name = "past_sequence_length"
572553 name = "past_key_values"
573554 else :
574- decoder_sequence_name = "past_sequence_length + 1 "
555+ decoder_sequence_name = "past_sequence_length + sequence_length "
575556 name = "present"
576557
577558 for i in range (self ._normalized_config .num_layers ):
578- # No dim for `n_head` when using multi-query attention
579- inputs_or_outputs [ f" { name } . { i } .key_value" ] = {
580- 0 : "batch_size" ,
581- 1 : decoder_sequence_name ,
582- }
559+ if self . _normalized_config . multi_query :
560+ # No dim for `n_head` when using multi-query attention
561+ inputs_or_outputs [ f" { name } . { i } .key_value" ] = { 0 : "batch_size" , 1 : decoder_sequence_name }
562+ else :
563+ inputs_or_outputs [ f" { name } . { i } .key_value" ] = { 0 : "batch_size" , 2 : decoder_sequence_name }
583564
584565 def flatten_past_key_values (self , flattened_output , name , idx , t ):
585566 flattened_output [f"{ name } .{ idx } .key_value" ] = t
586567
587568
588569@register_tasks_manager_onnx ("falcon" , * COMMON_TEXT_GENERATION_TASKS + ["question-answering" , "token-classification" ])
589- class FalconOnnxConfig (TextDecoderOnnxConfig ):
590- # This is due to the cache refactoring for Falcon in 4.36
591- MIN_TRANSFORMERS_VERSION = version .parse ("4.35.99" )
570+ class FalconOnnxConfig (TextDecoderWithPositionIdsOnnxConfig ):
571+ MIN_TRANSFORMERS_VERSION = version .parse ("4.36.0" )
592572
593- DUMMY_INPUT_GENERATOR_CLASSES = (
594- FalconDummyPastKeyValuesGenerator ,
595- ) + TextDecoderOnnxConfig .DUMMY_INPUT_GENERATOR_CLASSES
573+ DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator , FalconDummyPastKeyValuesGenerator )
596574 DEFAULT_ONNX_OPSET = 14 # Falcon uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention
597575 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
598576 DUMMY_PKV_GENERATOR_CLASS = FalconDummyPastKeyValuesGenerator
599577
600- # we need to set output_attentions=True in the model input to avoid calling
601- # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
602- _MODEL_PATCHER = FalconModelPatcher
603-
604578 def __init__ (
605579 self ,
606580 config : "PretrainedConfig" ,
@@ -634,10 +608,8 @@ def __init__(
634608 def inputs (self ) -> Dict [str , Dict [int , str ]]:
635609 common_inputs = super ().inputs
636610
637- if not self .legacy and not self ._config .alibi and self .task in ["text-generation" , "feature-extraction" ]:
638- # When alibi is used, position_ids are not used in Falcon.
639- # Reference: https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L1116
640- common_inputs ["position_ids" ] = {0 : "batch_size" , 1 : "sequence_length" }
611+ if self ._config .alibi :
612+ common_inputs .pop ("position_ids" , None )
641613
642614 return common_inputs
643615
@@ -836,7 +808,6 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
836808)
837809class BartOnnxConfig (M2M100OnnxConfig ):
838810 DEFAULT_ONNX_OPSET = 14 # Bart now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
839- MIN_TORCH_VERSION = version .parse ("2.1.2" )
840811
841812
842813@register_tasks_manager_onnx (
@@ -868,7 +839,7 @@ class BigBirdPegasusOnnxConfig(BartOnnxConfig):
868839 @property
869840 def inputs (self ) -> Dict [str , Dict [int , str ]]:
870841 inputs = super ().inputs
871- if self ._config .attention_type == "block_sparse" :
842+ if self ._config .attention_type == "block_sparse" and self . task != "text-generation" :
872843 # BigBirdPegasusEncoder creates its own attention_mask internally
873844 # https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py#L1875
874845 inputs .pop ("attention_mask" , None )
@@ -888,7 +859,6 @@ class MarianOnnxConfig(BartOnnxConfig):
888859@register_tasks_manager_onnx ("vit" , * ["feature-extraction" , "image-classification" , "masked-im" ])
889860class ViTOnnxConfig (VisionOnnxConfig ):
890861 NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
891- MIN_TORCH_VERSION = version .parse ("1.11" )
892862 DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
893863
894864 @property
@@ -1574,7 +1544,6 @@ class OwlViTOnnxConfig(CLIPOnnxConfig):
15741544 # Sets the absolute tolerance to when validating the exported ONNX model against the
15751545 # reference model.
15761546 ATOL_FOR_VALIDATION = 1e-4
1577- MIN_TORCH_VERSION = version .parse ("2.1" )
15781547
15791548 # needs einsum operator support, available since opset 12
15801549 DEFAULT_ONNX_OPSET = 12
@@ -1646,7 +1615,6 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
16461615 "layoutlmv3" , * ["feature-extraction" , "question-answering" , "text-classification" , "token-classification" ]
16471616)
16481617class LayoutLMv3OnnxConfig (TextAndVisionOnnxConfig ):
1649- MIN_TORCH_VERSION = version .parse ("1.12" )
16501618 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig .with_args (
16511619 allow_new = True ,
16521620 MAX_2D_POSITION_EMBEDDINGS = "max_2d_position_embeddings" ,
@@ -2570,8 +2538,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
25702538@register_tasks_manager_onnx ("sam" , * ["feature-extraction" ])
25712539class SamOnnxConfig (OnnxConfig ):
25722540 MIN_TRANSFORMERS_VERSION = version .parse ("4.29.0.dev0" )
2573- # Since ransformers 4.32.0, SAM uses repeat_interleave op that is broken in PyTorch 2.0.1: https://github.com/pytorch/pytorch/issues/100429
2574- MIN_TORCH_VERSION = version .parse ("2.0.99" )
25752541 NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
25762542 DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator , DummyPointsGenerator , DummyVisionEmbeddingsGenerator )
25772543 DEFAULT_ONNX_OPSET = 13 # Opset 12 for repeat_interleave falls back on the opset 9 implem, that raises Unsupported: ONNX export of repeat_interleave in opset 9.
0 commit comments