
Commit 31d4ea9

Fix batched inference/generation, position_ids creation, falcon alibi, gpt_bigcode multi-query,.. (#2326)
* test left-padded batched inference
* demonstrate batched text generation failure
* fix remote code
* fix
* fix position_ids generation inside ORTModelForCausalLM class
* it works until transformers 4.52 -_-
* now run with latest transformers
* boolean 4D mask is actually not supported by torch onnx exporter
* only test generation with batched inputs, for logits are a bit off because of transformers using boolean mask
* boolean mask safe softmax batched inference
* style
* use old typing
* don't do unnecessary patching
* try to avoid spamming the hub for an image
* update min transformers version
* better and direct torch patching
* more batched generation special cases
* style
* initialize the pil image instead of downloading it
* use random pil image
* test different versions of transformers in fast tests
* fix
* revert diffusers changes for now
* mask padding kv cache as well
* fix masking for old bloom
* use constant image to avoid image loading errors
* style
* test diffusers in series to avoid runner dying
* fix
* cleanup and some comments
* fix and test falcon alibi
* style
* fix, support and test multi_query=False as well
* only apply masked testing for transformers version previous to 4.39
* Update optimum/onnxruntime/modeling_decoder.py
* use text decoder position ids onnx config but test its sync with list
* fix opt
* style
* fix sdpa without overriding torch onnx exporter
* use inplace op ;-;
* fix st test
* patch directly in onnx because patch needs to happen after softmax
1 parent 689c0b5 commit 31d4ea9
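Several of the commit messages above revolve around one failure mode: with left padding, a batch can contain attention rows whose keys are all masked, and a plain softmax over such a row yields NaN. A minimal PyTorch sketch of the problem and of a "safe" softmax — an illustration only, not the PR's exact ONNX-level patch; the mask and shapes here are made up:

```python
import torch

# [batch, heads, queries, keys]; the second query row attends only to padding.
scores = torch.randn(1, 1, 2, 4)
keep = torch.tensor([[True, True, False, False],
                     [False, False, False, False]])  # second row fully masked

# Naive masking: exp(-inf) / sum(exp(-inf)) = 0 / 0 -> NaN for the full row,
# which then poisons the batched logits.
naive = scores.masked_fill(~keep, float("-inf")).softmax(dim=-1)
assert naive[0, 0, 1].isnan().all()

# Safe variant: fill with the dtype minimum (a fully-masked row becomes uniform
# instead of NaN), then zero the masked probabilities after the softmax --
# which is why the patch "needs to happen after softmax".
safe = scores.masked_fill(~keep, torch.finfo(scores.dtype).min).softmax(dim=-1)
safe = safe.masked_fill(~keep, 0.0)
```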

File tree

14 files changed: 515 additions & 648 deletions

.github/workflows/test_onnxruntime.yml

Lines changed: 18 additions & 4 deletions
```diff
@@ -27,6 +27,7 @@ jobs:
       matrix:
         python-version: [3.9]
         runs-on: [ubuntu-22.04]
+        transformers_version: [latest, 4.36.*, 4.45.*]
         test_file:
           [
             test_timm.py,
@@ -59,13 +60,26 @@ jobs:
           pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
           pip install .[tests,onnxruntime] diffusers
 
-      - name: Test with pytest (in series)
-        if: matrix.test_file == 'test_modeling.py'
+      - name: Install transformers ${{ matrix.transformers-version }}
         run: |
-          pytest tests/onnxruntime/test_modeling.py -m "run_in_series" --durations=0 -vvvv
+          if [ "${{ matrix.transformers_version }}" == '4.36.*' ]; then
+            pip install "transformers==4.36.*" "diffusers<0.32.0"
+          elif [ "${{ matrix.transformers_version }}" == '4.45.*' ]; then
+            pip install "transformers==4.45.*" "diffusers<0.33.0"
+          else
+            pip install transformers;
+          fi
 
       - name: Test with pytest (in parallel)
+        if: matrix.test_file != 'test_diffusion.py'
+        run: |
+          pytest tests/onnxruntime/${{ matrix.test_file }} --durations=0 -vvvv -n auto
+        env:
+          HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
+
+      - name: Test with pytest (in series)
+        if: matrix.test_file == 'test_diffusion.py'
         run: |
-          pytest tests/onnxruntime/${{ matrix.test_file }} -m "not run_in_series" --durations=0 -vvvv -n auto
+          pytest tests/onnxruntime/${{ matrix.test_file }} --durations=0 -vvvv
         env:
           HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
```

.github/workflows/test_onnxruntime_slow.yml

Lines changed: 3 additions & 14 deletions
```diff
@@ -36,16 +36,15 @@ jobs:
         python-version: [3.9]
         transformers-version: [latest]
         runs-on: [ubuntu-22.04, windows-2022]
-        include:
-          - {python-version: 3.9, transformers-version: 4.36.*, runs-on: ubuntu-22.04}
-          - {python-version: 3.9, transformers-version: 4.45.*, runs-on: ubuntu-22.04}
 
     runs-on: ${{ matrix.runs-on }}
 
     steps:
       - name: Free Disk Space (Ubuntu)
         if: matrix.runs-on == 'ubuntu-22.04'
         uses: jlumbroso/free-disk-space@main
+        with:
+          swap-storage: false
 
       - name: Free Disk Space (macOS)
         if: matrix.runs-on == 'macos-15'
@@ -69,22 +68,12 @@ jobs:
           pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
           pip install .[tests,onnxruntime] diffusers
 
-      - name: Install transformers ${{ matrix.transformers-version }}
-        if: ${{ matrix.transformers-version == '4.36.*' }}
-        run: |
-          pip install "transformers==${{ matrix.transformers-version }}" "diffusers<0.32.0"
-
-      - name: Install transformers ${{ matrix.transformers-version }}
-        if: ${{ matrix.transformers-version == '4.45.*' }}
-        run: |
-          pip install "transformers==${{ matrix.transformers-version }}" "diffusers<0.33.0"
-
       - name: Test with pytest (in series)
         run: |
           pytest tests/onnxruntime -m "run_in_series" --durations=0 -vvvv
         env:
           RUN_SLOW: 1
-
+
       - name: Test with pytest (in parallel)
         run: |
           pytest tests/onnxruntime -m "not run_in_series" --durations=0 -vvvv
```

optimum/exporters/onnx/config.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -94,13 +94,14 @@ def __init__(
     def inputs(self) -> Dict[str, Dict[int, str]]:
         if self.use_past_in_inputs:
             common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
+            common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"}
             self.add_past_key_values(common_inputs, direction="inputs")
-            common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
         else:
             common_inputs = {
                 "input_ids": {0: "batch_size", 1: "sequence_length"},
                 "attention_mask": {0: "batch_size", 1: "sequence_length"},
             }
+
         return common_inputs
 
     @property
```
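The renamed dynamic axis ("past_sequence_length + sequence_length" instead of "past_sequence_length + 1") reflects that, with a cache, the attention mask must cover past and current tokens rather than a single decode step. For context, this is the usual recipe for deriving position_ids from such a mask under left padding — it mirrors the standard transformers pattern (note the in-place fill, cf. "use inplace op" in the commit log), not necessarily the exact code added to ORTModelForCausalLM:

```python
import torch

# Two left-padded sequences; 0 marks padding.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)  # dummy value on padding
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
# At decode time with cache, only the positions of the new tokens are kept:
# position_ids = position_ids[:, -input_ids.shape[1]:]
```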

optimum/exporters/onnx/model_configs.py

Lines changed: 20 additions & 54 deletions
```diff
@@ -92,9 +92,7 @@
 from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
 from .model_patcher import (
     CLIPModelPatcher,
-    FalconModelPatcher,
     MgpstrModelPatcher,
-    MistralModelPatcher,
     MusicgenModelPatcher,
     Qwen3MoeModelPatcher,
     SAMModelPatcher,
@@ -409,20 +407,12 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
-# OPT does not take position_ids as input for transfomers < v4.46, needs it for transformers >= v4.46
-if is_transformers_version(">=", "4.46.0"):
-
-    @register_tasks_manager_onnx("opt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"])
-    class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
-        DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
-        NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
-
-else:
-
-    @register_tasks_manager_onnx("opt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"])
-    class OPTOnnxConfig(TextDecoderOnnxConfig):
-        DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
-        NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+@register_tasks_manager_onnx("opt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"])
+class OPTOnnxConfig(
+    TextDecoderWithPositionIdsOnnxConfig if is_transformers_version(">=", "4.46.0") else TextDecoderOnnxConfig
+):
+    DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
 @register_tasks_manager_onnx("llama", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
```
```diff
@@ -477,7 +467,6 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
 @register_tasks_manager_onnx("granite", *COMMON_TEXT_GENERATION_TASKS)
 class GraniteOnnxConfig(LlamaOnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")
-    MIN_TORCH_VERSION = version.parse("2.5.0")
 
 
 @register_tasks_manager_onnx("phi", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
@@ -502,17 +491,11 @@ class InternLM2OnnxConfig(LlamaOnnxConfig):
 
 @register_tasks_manager_onnx("mistral", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
 class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
-    # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
-    MIN_TRANSFORMERS_VERSION = version.parse("4.35.0")
-
     # The ONNX export of this architecture needs the Trilu operator support, available since opset 14
     DEFAULT_ONNX_OPSET = 14
-    DUMMY_INPUT_GENERATOR_CLASSES = (
-        MistralDummyPastKeyValuesGenerator,
-    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
-    _MODEL_PATCHER = MistralModelPatcher
 
 
 @register_tasks_manager_onnx("mpt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"])
@@ -556,9 +539,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
     "gpt_bigcode", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"]
 )
 class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
-    DUMMY_INPUT_GENERATOR_CLASSES = (
-        GPTBigCodeDummyPastKeyValuesGenerator,
-    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GPTBigCodeDummyPastKeyValuesGenerator)
     DEFAULT_ONNX_OPSET = 14  # GPT BigCode now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
     DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("gpt_bigcode")
@@ -571,36 +552,29 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
             decoder_sequence_name = "past_sequence_length"
             name = "past_key_values"
         else:
-            decoder_sequence_name = "past_sequence_length + 1"
+            decoder_sequence_name = "past_sequence_length + sequence_length"
             name = "present"
 
         for i in range(self._normalized_config.num_layers):
-            # No dim for `n_head` when using multi-query attention
-            inputs_or_outputs[f"{name}.{i}.key_value"] = {
-                0: "batch_size",
-                1: decoder_sequence_name,
-            }
+            if self._normalized_config.multi_query:
+                # No dim for `n_head` when using multi-query attention
+                inputs_or_outputs[f"{name}.{i}.key_value"] = {0: "batch_size", 1: decoder_sequence_name}
+            else:
+                inputs_or_outputs[f"{name}.{i}.key_value"] = {0: "batch_size", 2: decoder_sequence_name}
 
     def flatten_past_key_values(self, flattened_output, name, idx, t):
         flattened_output[f"{name}.{idx}.key_value"] = t
 
 
 @register_tasks_manager_onnx("falcon", *COMMON_TEXT_GENERATION_TASKS + ["question-answering", "token-classification"])
-class FalconOnnxConfig(TextDecoderOnnxConfig):
-    # This is due to the cache refactoring for Falcon in 4.36
-    MIN_TRANSFORMERS_VERSION = version.parse("4.35.99")
+class FalconOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
+    MIN_TRANSFORMERS_VERSION = version.parse("4.36.0")
 
-    DUMMY_INPUT_GENERATOR_CLASSES = (
-        FalconDummyPastKeyValuesGenerator,
-    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, FalconDummyPastKeyValuesGenerator)
     DEFAULT_ONNX_OPSET = 14  # Falcon uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
     DUMMY_PKV_GENERATOR_CLASS = FalconDummyPastKeyValuesGenerator
 
-    # we need to set output_attentions=True in the model input to avoid calling
-    # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
-    _MODEL_PATCHER = FalconModelPatcher
-
     def __init__(
         self,
         config: "PretrainedConfig",
```
634608
def inputs(self) -> Dict[str, Dict[int, str]]:
635609
common_inputs = super().inputs
636610

637-
if not self.legacy and not self._config.alibi and self.task in ["text-generation", "feature-extraction"]:
638-
# When alibi is used, position_ids are not used in Falcon.
639-
# Reference: https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L1116
640-
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
611+
if self._config.alibi:
612+
common_inputs.pop("position_ids", None)
641613

642614
return common_inputs
643615
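With alibi enabled, Falcon encodes position through linear attention biases built from the attention mask, so position_ids are simply dropped from the exported inputs. A rough sketch of how such a bias is derived from the mask — simplified from transformers' build_alibi_tensor, with a slope formula that assumes a power-of-two head count:

```python
import torch

def alibi_bias(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
    # Geometric per-head slopes, as in the ALiBi paper.
    slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)])
    # Positions come from the mask itself (cumsum over kept tokens), which is
    # why left padding must be handled here rather than via position_ids.
    positions = (attention_mask.cumsum(dim=-1) - 1) * attention_mask  # [batch, seq]
    return slopes[None, :, None] * positions[:, None, :]  # [batch, num_heads, seq]

mask = torch.tensor([[0, 0, 1, 1, 1]])  # left-padded example
print(alibi_bias(mask, num_heads=4).shape)  # torch.Size([1, 4, 5])
```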

```diff
@@ -836,7 +808,6 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
 )
 class BartOnnxConfig(M2M100OnnxConfig):
     DEFAULT_ONNX_OPSET = 14  # Bart now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
-    MIN_TORCH_VERSION = version.parse("2.1.2")
 
 
 @register_tasks_manager_onnx(
@@ -868,7 +839,7 @@ class BigBirdPegasusOnnxConfig(BartOnnxConfig):
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         inputs = super().inputs
-        if self._config.attention_type == "block_sparse":
+        if self._config.attention_type == "block_sparse" and self.task != "text-generation":
             # BigBirdPegasusEncoder creates its own attention_mask internally
             # https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py#L1875
             inputs.pop("attention_mask", None)
@@ -888,7 +859,6 @@ class MarianOnnxConfig(BartOnnxConfig):
 @register_tasks_manager_onnx("vit", *["feature-extraction", "image-classification", "masked-im"])
 class ViTOnnxConfig(VisionOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
-    MIN_TORCH_VERSION = version.parse("1.11")
     DEFAULT_ONNX_OPSET = 14  # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
 
     @property
@@ -1574,7 +1544,6 @@ class OwlViTOnnxConfig(CLIPOnnxConfig):
     # Sets the absolute tolerance to when validating the exported ONNX model against the
     # reference model.
     ATOL_FOR_VALIDATION = 1e-4
-    MIN_TORCH_VERSION = version.parse("2.1")
 
     # needs einsum operator support, available since opset 12
     DEFAULT_ONNX_OPSET = 12
@@ -1646,7 +1615,6 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
     "layoutlmv3", *["feature-extraction", "question-answering", "text-classification", "token-classification"]
 )
 class LayoutLMv3OnnxConfig(TextAndVisionOnnxConfig):
-    MIN_TORCH_VERSION = version.parse("1.12")
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
         allow_new=True,
         MAX_2D_POSITION_EMBEDDINGS="max_2d_position_embeddings",
@@ -2570,8 +2538,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
 @register_tasks_manager_onnx("sam", *["feature-extraction"])
 class SamOnnxConfig(OnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.29.0.dev0")
-    # Since ransformers 4.32.0, SAM uses repeat_interleave op that is broken in PyTorch 2.0.1: https://github.com/pytorch/pytorch/issues/100429
-    MIN_TORCH_VERSION = version.parse("2.0.99")
     NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyPointsGenerator, DummyVisionEmbeddingsGenerator)
     DEFAULT_ONNX_OPSET = 13  # Opset 12 for repeat_interleave falls back on the opset 9 implem, that raises Unsupported: ONNX export of repeat_interleave in opset 9.
```
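Taken together, these changes target the scenario below: exporting a decoder-only model to ONNX and generating from a left-padded batch through ORTModelForCausalLM. A sketch of that usage, where gpt2 is just an example checkpoint:

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.padding_side = "left"           # decoder-only models pad on the left
tokenizer.pad_token = tokenizer.eos_token

# export=True converts the checkpoint to ONNX on the fly.
model = ORTModelForCausalLM.from_pretrained("gpt2", export=True)

inputs = tokenizer(["Hello", "The quick brown fox"], return_tensors="pt", padding=True)
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```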
