Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions fast_llm/data/preprocessing/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ class TokenizerConfig(PreprocessingConfig):
desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.",
hint=FieldHint.core,
)
allow_no_bos: bool = Field(
default=False,
desc="Allow the tokenizer to not have a BOS token. Set to True for tokenizers without BOS (e.g. Qwen).",
hint=FieldHint.core,
)
max_vocab_size: int | None = Field(
default=None,
desc="Constrain output tokens to a specific range. Used for testing.",
Expand Down Expand Up @@ -63,8 +68,8 @@ def __init__(self, config: ConfigType):
self.tokenizer.bos_token = self._config.bos_token
if self.tokenizer.eos_token_id is None:
raise ValueError("Tokenizer does not have an EOS token.")
if self.tokenizer.bos_token_id is None:
raise ValueError("Tokenizer does not have an BOS token.")
if self.tokenizer.bos_token_id is None and not self._config.allow_no_bos:
raise ValueError("Tokenizer does not have a BOS token. Set allow_no_bos=True to allow this.")
self.eod_id = self.tokenizer.eos_token_id
self.bod_id = self.tokenizer.bos_token_id

Expand All @@ -89,9 +94,9 @@ def tokenize(
import torch

tokens = self.tokenizer.encode(text, add_special_tokens=False)
if begin:
if (begin and self.bod_id is not None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems confusing and redundant with the `begin` and `end` arguments, which are the intended way to disable the addition of BOS/EOS tokens.
My suggestion: add an explicit `add_document_breaks` parameter to the dataset preparator instead, and use it to set the `begin` and `end` arguments. Raise an error for a missing BOS/EOS token only if adding that token was requested.

tokens.insert(0, self.bod_id)
if end:
if (end and self.eod_id is not None):
tokens.append(self.eod_id)

if self._config.max_vocab_size is not None:
Expand Down Expand Up @@ -271,10 +276,10 @@ def tokenize_chat(
# Prepend BOS / append EOS if not already present anywhere in the sequence.
# We check anywhere (not just first/last) because some chat templates add trailing
# whitespace after the final EOS token, e.g. "<|im_end|>\n".
prepend_bos = begin and self.bod_id not in tokens
append_eos = end and self.eod_id not in tokens
prepend_bos = begin and self.bod_id is not None and self.bod_id not in tokens
append_eos = end and self.eod_id is not None and self.eod_id not in tokens
tokens = [self.bod_id] * prepend_bos + list(tokens) + [self.eod_id] * append_eos
train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [False] * append_eos
train_mask = [False] * prepend_bos + [bool(m) for m in train_mask] + [True] * append_eos

# Convert boolean train mask to loss masking spans (spans where train_mask[i] == False)
loss_masking_spans = _train_mask_to_loss_spans(train_mask)
Expand Down
22 changes: 11 additions & 11 deletions tests/data/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ def test_validate_chat_template_with_markers(common_tokenizer):
("messages", "expected_tokens", "expected_loss_masking_spans"),
(
# Single turn: full assistant turn (<assistant>Hello</assistant>) is trainable
# 15 tokens, trainable indices 7-13, loss mask spans cover 0-6 and 14
# 15 tokens, trainable indices 7-14, loss mask spans cover 0-6
(
[{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}],
[49152, 27, 789, 29, 16946, 750, 789, 2293, 17822, 29, 7371, 750, 17822, 29, 49152],
[(0, 7), (14, 15)],
[(0, 7)],
),
# Multi-turn: both assistant turns are fully trainable
# 27 tokens, trainable indices 7-13 and 19-25
# 27 tokens, trainable indices 7-13 and 19-26
(
[
{"role": "user", "content": "A"},
Expand Down Expand Up @@ -123,10 +123,10 @@ def test_validate_chat_template_with_markers(common_tokenizer):
29,
49152,
],
[(0, 7), (14, 19), (26, 27)],
[(0, 7), (14, 19)],
),
# System + user + assistant: full assistant turn trainable
# 23 tokens, trainable indices 15-21
# 23 tokens, trainable indices 15-22
(
[
{"role": "system", "content": "You are helpful."},
Expand Down Expand Up @@ -158,17 +158,17 @@ def test_validate_chat_template_with_markers(common_tokenizer):
29,
49152,
],
[(0, 15), (22, 23)],
[(0, 15)],
),
# User only: no trainable tokens
# 9 tokens, no trainable indices
# User only: no trainable tokens except EOS
# 9 tokens, trainable index 8 (EOS)
(
[{"role": "user", "content": "Hi"}],
[49152, 27, 789, 29, 16946, 750, 789, 29, 49152],
[(0, 9)],
[(0, 8)],
),
# Long multi-turn (85 tokens, 3 assistant responses with tags, tests span machinery)
# Trainable: indices 27-40, 49-62, 70-83
# Trainable: indices 27-40, 49-62, 70-84
(
[
{"role": "system", "content": "You are a helpful assistant that answers questions."},
Expand Down Expand Up @@ -266,7 +266,7 @@ def test_validate_chat_template_with_markers(common_tokenizer):
29,
49152,
],
[(0, 27), (41, 49), (63, 70), (84, 85)],
[(0, 27), (41, 49), (63, 70)],
),
),
)
Expand Down