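When all `GUESS_SIZE` elements of the `myguess` list match the first `GUESS_SIZE` elements of the `correct` list, every guessed token has been accepted. In that case the last element of `correct` is also a correct token, so it should be added to the `hits` list. The original `for` loop leaves `gg` at `GUESS_SIZE - 1` after a full match, which is the same value it has after a mismatch at the last position, so that final token is never added.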
https://github.com/hao-ai-lab/LookaheadDecoding/blob/9d50de4a81d1b473bfce104ace18fbbbb6dc3255/lade/decoding.py#L1068C1-L1085C88
Original code:

    hits = [first_guess] + [0] * (GUESS_SIZE - 1)
    #multi-level window is filled
    #match guess tokens
    if guess_tokens is not None:
        guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
        for eg in range(len(guess_results) // GUESS_SIZE):
            egx = eg * GUESS_SIZE
            correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
            myguess = guess_tokens[egx:egx + GUESS_SIZE]
            gg = 0
            for gg in range(len(myguess)):
                if myguess[gg] != correct[gg]:
                    break
            if gg > max_hit:
                max_hit = gg
                max_hit_idx = eg
                hits[:max_hit + 1] = correct[:max_hit + 1]
    #max_hit is the length of longest accepted sequence in verification branch

Modified code:

    hits = [first_guess] + [0] * GUESS_SIZE
    #multi-level window is filled
    #match guess tokens
    if guess_tokens is not None:
        guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
        for eg in range(len(guess_results) // GUESS_SIZE):
            egx = eg * GUESS_SIZE
            correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
            myguess = guess_tokens[egx:egx + GUESS_SIZE]
            gg = 0
            while gg < len(myguess):
                if myguess[gg] != correct[gg]:
                    break
                gg += 1
            if gg > max_hit:
                max_hit = gg
                max_hit_idx = eg
                hits[:max_hit + 1] = correct[:max_hit + 1]
    #max_hit is the length of longest accepted sequence in verification branch
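
The difference between the two loops can be checked in isolation. The sketch below is not code from the repository: `matched_len_for` and `matched_len_while` are hypothetical helper names that copy only the inner loop of each version, and the token values are made up for illustration.

```python
def matched_len_for(myguess, correct):
    """Inner loop of the original code: gg is the last index visited."""
    gg = 0
    for gg in range(len(myguess)):
        if myguess[gg] != correct[gg]:
            break
    return gg


def matched_len_while(myguess, correct):
    """Inner loop of the modified code: gg is the number of matched tokens."""
    gg = 0
    while gg < len(myguess):
        if myguess[gg] != correct[gg]:
            break
        gg += 1
    return gg


# Illustrative values with GUESS_SIZE = 3: correct = [first_guess] + guess_results
correct = [10, 11, 12, 13]
full_match = [10, 11, 12]   # every guessed token is right
late_miss = [10, 11, 99]    # mismatch at the last position

# The for loop cannot tell these two cases apart.
print(matched_len_for(full_match, correct), matched_len_for(late_miss, correct))      # 2 2
# The while loop counts one more token for the full match.
print(matched_len_while(full_match, correct), matched_len_while(late_miss, correct))  # 3 2
```

In both versions the accepted tokens are `correct[:gg + 1]`, so only the full-match case changes: the modified loop accepts `[10, 11, 12, 13]` where the original stops at `[10, 11, 12]`. This is also why `hits` needs `GUESS_SIZE + 1` slots in the modified code.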