Better support for prompt_edit_token_weights parsing #22

@mix1009

Instead of counting token indices to pass into prompt_edit_token_weights, it would be easier to reference tokens by word.
parse_edit_weights accepts weights keyed by a word, a list of words, or an int index, and converts them all to weights keyed by int indices:

prompt = 'the quick brown fox jumps over the lazy dog'
parse_edit_weights(prompt, None, [('brown', -1), (2, 0.5), (['lazy', 'dog'], -1.5)])

The returned result is [(3, -1), (2, 0.5), (8, -1.5), (9, -1.5)].
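
The index arithmetic behind that result can be checked without loading CLIP. This is a minimal sketch, assuming every word maps to exactly one token (which holds for this prompt, but not for CLIP's tokenizer in general). words_of and word_index are hypothetical stand-ins for illustration, with a plain whitespace split taking the place of the real tokenizer:

```python
def words_of(prompt):
    # Hypothetical stand-in for the tokenizer: one "token" per whitespace-separated word.
    return prompt.split()

def word_index(prompt, word):
    # +1 because the real token sequence begins with a start-of-text token at index 0.
    return words_of(prompt).index(word) + 1

prompt = 'the quick brown fox jumps over the lazy dog'
indices = {w: word_index(prompt, w) for w in ('brown', 'lazy', 'dog')}
print(indices)  # {'brown': 3, 'lazy': 8, 'dog': 9}
```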

Here's the code:

def sep_token(prompt):
    # Assumes a global clip_tokenizer; input_ids[0] is the token sequence for the prompt.
    tokens = clip_tokenizer(prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True).input_ids[0]
    words = []
    # Index 0 holds the start-of-text token, so decoding starts at 1.
    for index in range(1, len(tokens)):
        word = clip_tokenizer.decode(tokens[index:index+1])
        # Stop at the end-of-text token or an empty decode.
        if not word or word == '<|endoftext|>':
            break
        words.append(word)
    return words

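def parse_edit_weights(prompt, prompt_edit, edit_weights):
    # Indices refer to prompt_edit when it is given, otherwise to prompt.
    tokens = sep_token(prompt_edit if prompt_edit else prompt)

    prompt_edit_token_weights = []
    for tl, w in edit_weights:
        # Accept a single word/index or a list/tuple of them.
        if not isinstance(tl, (list, tuple)):
            tl = [tl]
        for t in tl:
            if isinstance(t, str):
                try:
                    # +1 accounts for the start-of-text token at index 0.
                    idx = tokens.index(t) + 1
                except ValueError as e:
                    print(f'error {e}')
                    continue
            elif isinstance(t, int):
                idx = t
            else:
                # Skip unsupported entries instead of reusing a stale idx.
                print(f'error unsupported token reference {t!r}')
                continue
            prompt_edit_token_weights.append((idx, w))

    return prompt_edit_token_weights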
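
One behaviour worth noting: tokens.index(t) returns only the first match, so a word that appears more than once (such as 'the' in the example prompt) always resolves to its first position. Below is a minimal sketch of one way to collect every occurrence instead, assuming a pre-split word list is passed in. all_indices is a hypothetical helper for illustration, not part of the code above:

```python
def all_indices(words, word):
    # Return every 1-based position of `word`; index 0 is reserved for the start-of-text token.
    return [i + 1 for i, w in enumerate(words) if w == word]

words = 'the quick brown fox jumps over the lazy dog'.split()
print(all_indices(words, 'the'))  # [1, 7]
print(all_indices(words, 'cat'))  # []
```

Returning an empty list for a missing word would also let the caller decide how to report it, rather than catching ValueError.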
