Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions docs/cli_options.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,19 @@ The standard deviation of the number of tokens in each output.

#### `--prompt-prefix-pool-size`, `--prefix-prompt-pool-size`, `--num-prefix-prompts` `<int>`

The total size of the prefix prompt pool to select prefixes from. If this value is not zero, these are prompts that are prepended to input prompts. This is useful for benchmarking models that use a K-V cache.
The total size of the prefix prompt pool to select prefixes from. If this value is not zero, these are prompts that are prepended to input prompts. This is useful for benchmarking models that use a K-V cache. This field cannot be used with --prefix-reuse-rate.
<br>_Default: `0`_

#### `--prompt-prefix-length`, `--prefix-prompt-length` `<int>`

The number of tokens in each prefix prompt. This is only used if "num" is greater than zero. Note that due to the prefix and user prompts being concatenated, the number of tokens in the final prompt may be off by one.
The number of tokens in each prefix prompt. This is only used if `--prefix-prompt-pool-size` is greater than zero. Note that due to the prefix and user prompts being concatenated, the number of tokens in the final prompt may be off by one. This field cannot be used with --prefix-reuse-rate.
<br>_Default: `0`_

#### `--prefix-reuse-rate` `<float>`

The portion of input sequence length (ISL) that should be from a reused prefix. This option automatically sets the prefix prompt pool size to 1. The prefix length will be calculated as: ISL * prefix_reuse_rate, rounded down to a whole number of tokens. For example, with ISL=1000 and prefix-reuse-rate=0.5, the prefix will be 500 tokens. Value must be between 0.0 and 1.0. Cannot be used with --prefix-prompt-pool-size or --prefix-prompt-length. Requires --isl (input sequence length mean) to be set to a value greater than 0.
<br>_Default: `0.0`_

## Conversation Input Options

#### `--conversation-num`, `--num-conversations`, `--num-sessions` `<int>`
Expand Down
1 change: 1 addition & 0 deletions src/aiperf/common/config/config_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class InputTokensDefaults:
class PrefixPromptDefaults:
    """Default values for the prefix prompt configuration fields."""

    # Size of the prefix prompt pool; per the CLI docs, prefixes are only
    # prepended when this value is not zero.
    POOL_SIZE = 0
    # Number of tokens in each prefix prompt.
    LENGTH = 0
    # Portion of the input sequence length that comes from a reused prefix.
    # The feature checks in this change all test `> 0`, so 0.0 leaves it off.
    PREFIX_REUSE_RATE = 0.0


@dataclass(frozen=True)
Expand Down
55 changes: 53 additions & 2 deletions src/aiperf/common/config/prompt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,33 @@ class PrefixPromptConfig(BaseConfig):

_CLI_GROUP = Groups.PREFIX_PROMPT

@model_validator(mode="after")
def validate_prefix_reuse_rate_conflicts(self) -> Self:
    """Reject configurations that combine prefix_reuse_rate with manual prefix sizing.

    A positive ``prefix_reuse_rate`` is mutually exclusive with a positive
    ``pool_size`` and with a positive ``length``.

    Returns:
        The validated model instance, unchanged.

    Raises:
        ValueError: If ``prefix_reuse_rate`` is positive together with a
            positive ``pool_size`` or a positive ``length``. The pool size
            conflict is reported first when both are present.
    """
    # Nothing can conflict while the reuse-rate feature is switched off.
    if not self.prefix_reuse_rate > 0:
        return self

    if self.pool_size > 0:
        raise ValueError(
            "Cannot use --prefix-reuse-rate with --prefix-prompt-pool-size. "
            "These options are mutually exclusive."
        )

    if self.length > 0:
        raise ValueError(
            "Cannot use --prefix-reuse-rate with --prefix-prompt-length. "
            "When using --prefix-reuse-rate, the prefix length is automatically "
            "calculated as: ISL * prefix_reuse_rate."
        )

    return self

pool_size: Annotated[
int,
Field(
ge=0,
description=(
"The total size of the prefix prompt pool to select prefixes from.\n"
"If this value is not zero, these are prompts that are prepended to input prompts.\n"
"This is useful for benchmarking models that use a K-V cache."
"This is useful for benchmarking models that use a K-V cache. "
"This field cannot be used with --prefix-reuse-rate."
),
),
CLIParameter(
Expand All @@ -151,7 +170,8 @@ class PrefixPromptConfig(BaseConfig):
"The number of tokens in each prefix prompt.\n"
'This is only used if "num" is greater than zero.\n'
"Note that due to the prefix and user prompts being concatenated,\n"
"the number of tokens in the final prompt may be off by one."
"the number of tokens in the final prompt may be off by one.\n"
"This field cannot be used with --prefix-reuse-rate."
),
),
CLIParameter(
Expand All @@ -163,6 +183,27 @@ class PrefixPromptConfig(BaseConfig):
),
] = PrefixPromptDefaults.LENGTH

# Fraction of the input sequence length (ISL) that comes from a shared, reused
# prefix. The field constraints bound it to [0.0, 1.0]; the default of 0.0
# leaves the feature off. The help text lists both options that the model
# validator rejects in combination with a positive value.
prefix_reuse_rate: Annotated[
    float,
    Field(
        ge=0.0,
        le=1.0,
        description=(
            "The portion of input sequence length (ISL) that should be from a reused prefix.\n"
            "This option automatically sets the prefix prompt pool size to 1.\n"
            "The prefix length will be calculated as: ISL * prefix_reuse_rate.\n"
            "For example, with ISL=1000 and prefix-reuse-rate=0.5, the prefix will be 500 tokens.\n"
            "Value must be between 0.0 and 1.0.\n"
            "Cannot be used with --prefix-prompt-pool-size or --prefix-prompt-length.\n"
            "Requires --isl (input sequence length mean) to be set."
        ),
    ),
    CLIParameter(
        name="--prefix-reuse-rate",
        group=_CLI_GROUP,
    ),
] = PrefixPromptDefaults.PREFIX_REUSE_RATE


class PromptConfig(BaseConfig):
"""
Expand All @@ -187,6 +228,16 @@ def validate_sequence_distribution_format(self) -> Self:
raise ValueError(f"Invalid sequence distribution format: {e}") from e
return self

@model_validator(mode="after")
def validate_prefix_reuse_rate_requires_isl(self) -> Self:
    """Ensure an input sequence length is configured whenever a reuse rate is set.

    The prefix length is derived from the ISL mean, so a positive
    ``prefix_prompt.prefix_reuse_rate`` with ``input_tokens.mean == 0`` is
    rejected.

    Returns:
        The validated model instance, unchanged.

    Raises:
        ValueError: If ``prefix_reuse_rate`` is positive while the ISL mean is 0.
    """
    reuse_requested = self.prefix_prompt.prefix_reuse_rate > 0
    # The ISL mean is only inspected when the reuse-rate feature is active.
    if reuse_requested and self.input_tokens.mean == 0:
        raise ValueError(
            "When using --prefix-reuse-rate, you must also specify "
            "--isl (input sequence length mean) with a value greater than 0."
        )
    return self

batch_size: Annotated[
int,
Field(
Expand Down
10 changes: 9 additions & 1 deletion src/aiperf/dataset/composer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,12 @@ def _finalize_turn(self, turn: Turn) -> None:

@property
def prefix_prompt_enabled(self) -> bool:
    """Whether a prefix prompt should be prepended to generated prompts.

    Returns:
        True when ``prefix_reuse_rate`` is positive, or when both the prefix
        ``length`` and the prefix ``pool_size`` are positive; False otherwise.
    """
    prefix_config = self.config.input.prompt.prefix_prompt

    # A positive reuse rate enables the prefix on its own; length and
    # pool_size are not consulted in that mode.
    if prefix_config.prefix_reuse_rate > 0:
        return True

    # Otherwise both a token length and a non-empty pool are required. The
    # former length-only check returned early and hid both conditions above.
    return prefix_config.length > 0 and prefix_config.pool_size > 0
26 changes: 21 additions & 5 deletions src/aiperf/dataset/composer/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,29 @@ def _generate_text_payloads(self, turn: Turn, is_first: bool) -> Text:
)

for _ in range(self.config.input.prompt.batch_size):
# Generate prompt content using the sampled input sequence length
content = self.prompt_generator.generate(mean=isl, stddev=stddev)

# Add prefix prompt if this is the first turn and prefix is enabled
# If prefix is enabled and this is the first turn
if is_first and self.prefix_prompt_enabled:
# Get the prefix
prefix = self.prompt_generator.get_random_prefix_prompt()
content = f"{prefix} {content}"

# Calculate unique content length
# Total ISL = prefix length + unique content length
prefix_length = self.prompt_generator.get_prefix_length()
unique_content_length = max(0, isl - prefix_length)

# Generate unique content with adjusted length
if unique_content_length > 0:
unique_content = self.prompt_generator.generate(
mean=unique_content_length, stddev=stddev
)
else:
unique_content = ""

# Concatenate prefix + unique content (no space to avoid extra tokens)
content = f"{prefix}{unique_content}"
else:
# Generate prompt content using the sampled input sequence length
content = self.prompt_generator.generate(mean=isl, stddev=stddev)

text.contents.append(content)

Expand Down
48 changes: 38 additions & 10 deletions src/aiperf/dataset/generator/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,21 @@ def __init__(self, config: PromptConfig, tokenizer: Tokenizer, **kwargs):
if self._tokenized_corpus is None:
self._initialize_corpus()

# Initialize prefix prompts pool if the pool size > 0
if self.config.prefix_prompt.pool_size > 0:
self._create_prefix_prompt_pool()
# Initialize prefix prompts pool
# If prefix_reuse_rate is set, use pool_size of 1 and calculate length from ISL
# Otherwise use the configured pool_size
if self.config.prefix_prompt.prefix_reuse_rate > 0:
effective_pool_size = 1
# Calculate prefix length as: ISL * prefix_reuse_rate
effective_length = int(
self.config.input_tokens.mean
* self.config.prefix_prompt.prefix_reuse_rate
)
self._create_prefix_prompt_pool(effective_pool_size, effective_length)
elif self.config.prefix_prompt.pool_size > 0:
self._create_prefix_prompt_pool(
self.config.prefix_prompt.pool_size, self.config.prefix_prompt.length
)

def _initialize_corpus(self) -> None:
"""Load and tokenize the corpus once, storing it for reuse.
Expand Down Expand Up @@ -117,17 +129,19 @@ def tokenize_chunk(chunk):
f"from {len(chunks)} chunks using {num_threads} threads"
)

def _create_prefix_prompt_pool(self, pool_size: int, length: int) -> None:
    """Generate a pool of prefix prompts to sample from.

    The pool is stored on ``self._prefix_prompts``; each entry is produced by
    ``self._generate_prompt(length)``.

    Args:
        pool_size: The size of the prefix prompt pool to create.
        length: The length (in tokens) of each prefix prompt.

    Raises:
        NotInitializedError: If the tokenized corpus has not been initialized.
    """
    if self._tokenized_corpus is None:
        raise NotInitializedError("Tokenized corpus is not initialized.")

    # Size the pool from the arguments rather than from the config, so the
    # caller can pass values derived from prefix_reuse_rate.
    self._prefix_prompts = [self._generate_prompt(length) for _ in range(pool_size)]
    self.debug(
        lambda: f"Initialized prefix prompts pool with {len(self._prefix_prompts)} prompts of length {length}"
    )

def generate(
Expand Down Expand Up @@ -269,3 +283,17 @@ def get_random_prefix_prompt(self) -> str:
"Please ensure that the prefix prompts pool is initialized."
)
return self._prefix_rng.choice(self._prefix_prompts)

def get_prefix_length(self) -> int:
    """Return the prefix length derived from the reuse rate and the ISL mean.

    Returns:
        ``int(input_tokens.mean * prefix_reuse_rate)`` when
        ``prefix_reuse_rate`` is positive, otherwise 0.
    """
    reuse_rate = self.config.prefix_prompt.prefix_reuse_rate
    if not reuse_rate > 0:
        return 0
    # int() truncates, so a fractional product is rounded toward zero.
    return int(self.config.input_tokens.mean * reuse_rate)
108 changes: 108 additions & 0 deletions tests/unit/common/config/test_prompt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def test_prefix_prompt_config_defaults():
config = PrefixPromptConfig()
assert config.pool_size == PrefixPromptDefaults.POOL_SIZE
assert config.length == PrefixPromptDefaults.LENGTH
assert config.prefix_reuse_rate == PrefixPromptDefaults.PREFIX_REUSE_RATE


def test_prefix_prompt_config_custom_values():
Expand Down Expand Up @@ -161,3 +162,110 @@ def test_prompt_config_sequence_distribution_none_handling():
config = PromptConfig(sequence_distribution=None)
assert config.sequence_distribution is None
assert config.get_sequence_distribution() is None


def test_prefix_prompt_config_prefix_reuse_rate_custom_value():
    """A custom prefix_reuse_rate is stored and pool_size keeps its value of 0."""
    prefix_config = PrefixPromptConfig(prefix_reuse_rate=0.5)

    observed = (prefix_config.prefix_reuse_rate, prefix_config.pool_size)

    # The config object itself is not rewritten: pool_size stays at 0.
    assert observed == (0.5, 0)


def test_prefix_prompt_config_prefix_reuse_rate_conflicts_with_pool_size():
    """Combining a positive prefix_reuse_rate with a positive pool_size is rejected."""
    expected_message = "Cannot use --prefix-reuse-rate with --prefix-prompt-pool-size"

    with pytest.raises(ValueError, match=expected_message):
        PrefixPromptConfig(prefix_reuse_rate=0.5, pool_size=5)


def test_prefix_prompt_config_prefix_reuse_rate_zero_with_pool_size():
    """A zero prefix_reuse_rate does not conflict with a positive pool_size."""
    prefix_config = PrefixPromptConfig(prefix_reuse_rate=0.0, pool_size=5, length=100)

    observed = (prefix_config.prefix_reuse_rate, prefix_config.pool_size)

    assert observed == (0.0, 5)


def test_prefix_prompt_config_pool_size_zero_with_prefix_reuse_rate():
    """A zero pool_size does not conflict with a positive prefix_reuse_rate."""
    prefix_config = PrefixPromptConfig(prefix_reuse_rate=0.5, pool_size=0)

    observed = (prefix_config.prefix_reuse_rate, prefix_config.pool_size)

    assert observed == (0.5, 0)


def test_prefix_prompt_config_prefix_reuse_rate_bounds():
    """prefix_reuse_rate accepts the closed range [0.0, 1.0] and rejects values outside it."""
    # Both endpoints of the range are valid.
    for accepted in (0.0, 1.0):
        prefix_config = PrefixPromptConfig(prefix_reuse_rate=accepted)
        assert prefix_config.prefix_reuse_rate == accepted

    # Values just outside either endpoint fail validation.
    for rejected in (-0.1, 1.1):
        with pytest.raises(ValueError):
            PrefixPromptConfig(prefix_reuse_rate=rejected)


def test_prompt_config_prefix_reuse_rate_requires_isl():
    """A positive prefix_reuse_rate is only valid when input_tokens.mean is non-zero."""

    def build(mean, reuse_rate):
        # Start from defaults, then set the two fields the validator inspects.
        prompt_config = PromptConfig()
        prompt_config.input_tokens.mean = mean
        prompt_config.prefix_prompt.prefix_reuse_rate = reuse_rate
        return prompt_config

    # Reuse rate set but ISL mean of 0: re-validation must fail.
    missing_isl = build(0, 0.5)
    with pytest.raises(
        ValueError, match="When using --prefix-reuse-rate, you must also specify --isl"
    ):
        missing_isl.model_validate(missing_isl.model_dump())

    # Reuse rate and ISL mean both set: re-validation succeeds.
    both_set = build(1000, 0.5)
    revalidated = both_set.model_validate(both_set.model_dump())
    assert revalidated.prefix_prompt.prefix_reuse_rate == 0.5
    assert revalidated.input_tokens.mean == 1000

    # Reuse rate of 0 (feature disabled): an ISL mean of 0 is acceptable.
    feature_off = build(0, 0.0)
    revalidated = feature_off.model_validate(feature_off.model_dump())
    assert revalidated.prefix_prompt.prefix_reuse_rate == 0.0
    assert revalidated.input_tokens.mean == 0


def test_prefix_prompt_config_prefix_reuse_rate_conflicts_with_length():
    """Combining a positive prefix_reuse_rate with a positive length is rejected."""
    expected_message = "Cannot use --prefix-reuse-rate with --prefix-prompt-length"

    with pytest.raises(ValueError, match=expected_message):
        PrefixPromptConfig(prefix_reuse_rate=0.5, length=100)


def test_prefix_prompt_config_prefix_reuse_rate_zero_with_length():
    """A zero prefix_reuse_rate does not conflict with a positive length."""
    prefix_config = PrefixPromptConfig(prefix_reuse_rate=0.0, length=100)

    observed = (prefix_config.prefix_reuse_rate, prefix_config.length)

    assert observed == (0.0, 100)
Loading
Loading