64 changes: 59 additions & 5 deletions src/encoding.rs
@@ -1032,6 +1032,7 @@ pub struct StreamableParser {
    stop_tokens: HashSet<Rank>,
    last_content_delta: Option<String>,
    undecoded_tokens: Vec<Rank>,
+    undecoded_bytes: Vec<u8>,
}

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
@@ -1068,6 +1069,7 @@ impl StreamableParser {
            stop_tokens,
            last_content_delta: None,
            undecoded_tokens: Vec::new(),
+            undecoded_bytes: Vec::new(),
        })
    }

@@ -1148,14 +1150,60 @@ impl StreamableParser {
        match self
            .encoding
            .tokenizer()
-            .decode_utf8(&self.undecoded_tokens)
+            .decode_bytes(&self.undecoded_tokens)
        {
-            Ok(decoded) => {
-                content_tokens.extend(self.undecoded_tokens.iter().copied());
-                self.last_content_delta = Some(decoded);
+            Ok(decoded_bytes) => {
+                self.undecoded_bytes.extend(decoded_bytes.iter().copied());
+                match String::from_utf8(self.undecoded_bytes.clone()) {
+                    Ok(decoded_str) => {
+                        self.encoding
+                            .render_text_into(&decoded_str, content_tokens)?;
+                        self.last_content_delta = Some(decoded_str);
+                        self.undecoded_bytes.clear();
+                    }
+                    Err(e) => {
+                        let utf8_error = e.utf8_error();
+                        let decoded_bytes = e.into_bytes();

+                        let valid_len = utf8_error.valid_up_to();

+                        let mut content_delta = String::new();
+                        if valid_len > 0 {
+                            let valid_str = String::from_utf8(
+                                decoded_bytes[..valid_len].to_vec(),
+                            )
+                            .unwrap();
+                            self.encoding
+                                .render_text_into(&valid_str, content_tokens)?;
+                            content_delta.push_str(&valid_str);
+                            self.undecoded_bytes.drain(..valid_len);
+                        }

+                        match utf8_error.error_len() {
+                            Some(error_len) => {
+                                let replacement = '\u{FFFD}'.to_string();
+                                self.encoding.render_text_into(
+                                    &replacement,
+                                    content_tokens,
+                                )?;
+                                content_delta.push_str(&replacement);
+                                self.undecoded_bytes.drain(..error_len);
+                            }
+                            None => {
+                                // waiting on next byte in our utf-8 sequence
+                                self.last_content_delta = None;
+                            }
+                        }

+                        if !content_delta.is_empty() {
+                            self.last_content_delta = Some(content_delta);
+                        }
+                    }
+                }
                self.undecoded_tokens.clear();
            }
            Err(_) => {
                // Invalid bytes, so wait on the next token
                self.last_content_delta = None;
            }
        }
@@ -1167,7 +1215,13 @@ impl StreamableParser {
            true
        };
        if is_eos {
-            let text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
+            let content_text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
+            let tokens_text = self
+                .encoding
+                .tokenizer()
+                .decode_utf8(self.undecoded_tokens.clone())?;
+            let bytes_text = String::from_utf8_lossy(&self.undecoded_bytes);
+            let text = content_text + &tokens_text + &bytes_text;
            let message = Message {
                author: header.author.clone(),
                recipient: header.recipient.clone(),
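The strategy above, buffering raw bytes across token boundaries, surfacing text only once it forms valid UTF-8, and emitting a single U+FFFD for each byte run that can never become valid, is mirrored by Python's incremental UTF-8 decoder. Below is a minimal sketch of the same idea, an illustration only rather than the PR's Rust implementation, using just the standard library; the byte string mirrors token 9552's bytes from the tests below.

import codecs

# Incomplete trailing sequences stay buffered; byte runs that can never
# become valid utf-8 are replaced with a single U+FFFD (errors="replace").
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

# Token 9552 decodes to the bytes b" \xf0\x9f": a space plus the first
# two bytes of a four-byte utf-8 sequence.
assert decoder.decode(b" \xf0\x9f") == " "          # partial tail held back
assert decoder.decode(b" \xf0\x9f") == "\ufffd "    # old tail proved invalid
assert decoder.decode(b"\x98\x8a") == "\U0001f60a"  # tail completes to an emoji

# At end of stream the patch flushes leftover bytes with from_utf8_lossy;
# final=True gives the incremental decoder matching behavior.
decoder.decode(b" \xf0\x9f")
assert decoder.decode(b"", final=True) == "\ufffd"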
156 changes: 156 additions & 0 deletions tests/test_harmony.py
@@ -981,3 +981,159 @@ def test_streamable_parser_tool_call_with_constrain_adjacent():
    ]

    assert parser.messages == expected


+def test_streamable_parser_invalid_utf8_decoding():
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

+    # Confirm our token sequence is invalid utf-8:
+    # token 9552 corresponds to the bytes [32, 240, 159];
+    # 32 is a space, and 240, 159 is the truncated start of a multi-byte utf-8 sequence
+    invalid_token_sequence = [9552, 9552]
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)

+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("worked<|end|>", allowed_special="all")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+    for token in tokens:
+        parser.process(token)

+    expected = [
+        # Confirm we got a utf-8 replacement character for each invalid sequence
+        # and kept the remaining valid utf-8
+        Message.from_role_and_content(Role.ASSISTANT, " \uFFFD \uFFFDworked"),
+    ]
+    assert parser.messages == expected


+def test_streamable_parser_invalid_utf8_decoding_split_across_tokens():
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

+    valid_token_sequence = encoding.encode("XY")
+    encoding.decode_utf8(valid_token_sequence)

+    # Confirm prepending a specific token makes the sequence invalid utf-8:
+    # token 9552 ends with the start of a multi-byte utf-8 sequence,
+    # so prepending it to our previously valid sequence
+    # makes the whole thing invalid utf-8
+    invalid_token_sequence = [9552] + valid_token_sequence
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)

+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("<|end|>", allowed_special="all")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+    for token in tokens:
+        parser.process(token)

+    expected = [
+        # One utf-8 replacement character, but we otherwise kept the space
+        # (from token 9552) and the "X" and "Y" tokens
+        Message.from_role_and_content(Role.ASSISTANT, " \uFFFDXY"),
+    ]
+    assert parser.messages == expected


+def test_streamable_parser_invalid_utf8_decoding_multi_byte_token():
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

+    # Valid utf-8 sequence
+    valid_token_sequence = encoding.encode(" interesting")
+    encoding.decode_utf8(valid_token_sequence)

+    # Confirm prepending a specific token makes the sequence invalid utf-8:
+    # token 9552 ends with the start of a multi-byte utf-8 sequence,
+    # so prepending it to our previously valid sequence
+    # makes the whole thing invalid utf-8
+    invalid_token_sequence = [9552] + valid_token_sequence
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)

+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("<|end|>", allowed_special="all")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+    for token in tokens:
+        parser.process(token)

+    expected = [
+        # One utf-8 replacement character and the contents of our second token,
+        # which maps to the text " interesting"
+        Message.from_role_and_content(Role.ASSISTANT, " \uFFFD interesting"),
+    ]
+    assert parser.messages == expected


+def test_streamable_parser_invalid_utf8_decoding_multi_byte_token_no_eos_marker():
+    """Ensure we don't leave partially decoded tokens behind when there is no EOS marker."""
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

+    # Valid utf-8 sequence
+    valid_token_sequence = encoding.encode(" interesting")
+    encoding.decode_utf8(valid_token_sequence)

+    # Confirm prepending a specific token makes the sequence invalid utf-8:
+    # token 9552 ends with the start of a multi-byte utf-8 sequence,
+    # so prepending it to our previously valid sequence
+    # makes the whole thing invalid utf-8
+    invalid_token_sequence = [9552] + valid_token_sequence
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)

+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode(" story")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)

+    content_deltas = []
+    for token in tokens:
+        parser.process(token)
+        if parser.last_content_delta is not None:
+            content_deltas.append(parser.last_content_delta)

+    # No EOS, so no full message, but make sure we have the current content
+    assert parser.current_content == " \uFFFD interesting story"

+    # Ensure all the deltas combine to form our expected content
+    assert "".join(content_deltas) == " \uFFFD interesting story"

+    # Confirm we can keep accumulating content deltas and content
+    one_more_token = encoding.encode("Y")[0]
+    parser.process(one_more_token)
+    assert parser.last_content_delta == "Y"
+    assert parser.current_content == " \uFFFD interesting storyY"


+def test_streamable_parser_tricky_utf8_decoding():
+    """Try text with various types of utf-8 sequences that are more likely to fail."""
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

+    tricky_utf8_text = (
+        "Hello Müller, Γειά σου, Привет, שלום, مرحبا, नमस्ते, こんにちは, 안녕하세요,"
+        " 你好. Normalized (naïve) vs. decomposed (naïve) characters. "
+        "Some emojis: 😊👋🏾👨‍👩‍👧‍👦🇺🇸."
+    )
+    valid_token_sequence = encoding.encode(tricky_utf8_text)

+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("<|end|>", allowed_special="all")
+    tokens = prefix_tokens + valid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)

+    content_deltas = []
+    for token in tokens:
+        parser.process(token)
+        if parser.last_content_delta is not None:
+            content_deltas.append(parser.last_content_delta)

+    expected = [
+        Message.from_role_and_content(Role.ASSISTANT, tricky_utf8_text),
+    ]
+    # Ensure we got the entirety of our tricky utf-8 text as message content
+    assert parser.messages == expected

+    # Ensure that accumulating content deltas still yields the full utf-8 text
+    assert "".join(content_deltas) == tricky_utf8_text