2 changes: 1 addition & 1 deletion benchtools/benchmark.py
@@ -237,7 +237,7 @@ def initialize_dir(self, no_git=False):
# store tasks
task_types = set([task.storage_type for task in self.tasks.values()])
if 'csv' in task_types:
os.mkdir(self.bench_path,'tasks')
os.mkdir(os.path.join(self.bench_path,'tasks'))
for task_name, task_object in self.tasks.items():
task_object.write(self.bench_path)

4 changes: 2 additions & 2 deletions benchtools/cli.py
@@ -135,7 +135,7 @@ def add_task(task_name, bench_path, task_source,task_type):
@benchtool.command()
@click.argument('benchmark-path', required = True, type=str)
@click.argument('task_name', required = True)
@click.option('-r', '--runner-type', type=click.Choice(['ollama', 'openai', 'aws']),
@click.option('-r', '--runner-type', type=click.Choice(['ollama', 'openai', 'bedrock']),
default="ollama", help="The engine that will run your LLM.")
@click.option('-m', '--model', type=str, default="gemma3",
help="The LLM to be benchmarked.")
@@ -162,7 +162,7 @@ def run_task(benchmark_path: str, task_name, runner_type, model, api_url, log_pa

@benchtool.command()
@click.argument('benchmark-path', required = False, type=str, default='.')
@click.option('-r', '--runner-type', type=click.Choice(['ollama', 'openai', 'aws']),
@click.option('-r', '--runner-type', type=click.Choice(['ollama', 'openai', 'bedrock']),
default="ollama", help="The engine that will run your LLM.")
@click.option('-m', '--model', type=str, default="gemma3",
help="The LLM to be benchmarked.")
19 changes: 19 additions & 0 deletions benchtools/response.py
@@ -0,0 +1,19 @@
from pydantic import BaseModel


class StringAnswer(BaseModel):
answer: str

class IntAnswer(BaseModel):
answer: int

class FloatAnswer(BaseModel):
answer: float

class StringJustification(BaseModel):
answer: str
justification: str

class IntJustification(BaseModel):
answer: int
justification: str
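
The response models above are plain Pydantic schemas that task.py resolves by name and passes to the runner as a structured-output format. A minimal sketch of that flow, assuming the ollama Python client's structured-output support (`format=...model_json_schema()`); the model name and prompt are placeholders, not taken from this PR:

```python
# Sketch: resolve a response format by name and request schema-constrained output from ollama.
# Assumptions: the ollama Python client with structured-output support is installed and a
# local model named "gemma3" is available; model name and prompt are placeholders.
import benchtools.response as response_formats
from ollama import chat

format_name = "StringAnswer"
FormatClass = getattr(response_formats, format_name)  # same by-name lookup as Task.__init__

completion = chat(
    model="gemma3",  # example model name
    format=FormatClass.model_json_schema(),  # constrain the reply to JSON matching the schema
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)

# the reply content is JSON conforming to the schema; validate it back into the Pydantic model
answer = FormatClass.model_validate_json(completion.message.content)
print(answer.answer)
```
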
116 changes: 99 additions & 17 deletions benchtools/task.py
@@ -1,29 +1,45 @@
# defines a class object for a task
# from openai import OpenAI
import os
import yaml # requires pyyaml
import yaml
import json
import boto3
import pandas as pd
from ollama import chat, ChatResponse, Client
from benchtools.logger import init_log_folder, log_interaction
from .logger import init_log_folder, log_interaction
from pathlib import PurePath
from datasets import load_dataset
from benchtools.runner import BenchRunner
from .runner import BenchRunner
import sys
from .response import StringAnswer, StringJustification, IntAnswer, IntJustification

from benchtools.scorers import scoring_fx_list, contains, exact_match
from .scorers import scoring_fx_list, contains, exact_match

from .utils import concatenator_id_generator, selector_id_generator

prompt_id_fx = {'concatenator_id_generator':concatenator_id_generator,
'selector_id_generator':selector_id_generator}

class UnMatchedModel(Exception):
"""
Exception raised for a Bedrock model that isn't accounted for in the match statement.
See https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html for the list of models available on Bedrock and their inference parameters.
"""
def __init__(self, model):
self.model = model
message = f"Cannot call the model {model} using AWS Bedrock. Please fetch the correct inference parameters for it and add them in a PR to BenchTools."
super().__init__(message) # Call the base class constructor


class Task:
"""
defines a basic prompt task with a simple scoring function
"""

def __init__(self, task_name, template, reference=None, scoring_function=None,
variant_values = None, storage_type = 'yaml', description = None,
prompt_id_generator_fx = concatenator_id_generator):
prompt_id_generator_fx = concatenator_id_generator,
format='StringAnswer'):
"""
init a task object from a prompt and reference, and a scoring function. If no scoring function is provided, defaults to exact match.

@@ -47,11 +63,15 @@ def __init__(self, task_name, template, reference=None, scoring_function=None,
self.template = template
self.variant_values = variant_values
self.reference = reference

# set up to name individual prompts
if not callable(prompt_id_generator_fx):
prompt_id_generator_fx = prompt_id_fx[prompt_id_generator_fx]

self.prompt_id_generator = prompt_id_generator_fx

# setup for response format
mod = sys.modules[__name__]
self.FormatClass = getattr(mod,format)

self.storage_type = storage_type
if scoring_function:
@@ -90,7 +110,7 @@ def from_txt_csv(cls, source_folder, task_name = None, scoring_function = None,
# load and strip whitespace from column names
value_answer_df = pd.read_csv(values_file).rename(columns=lambda x: x.strip())

variant_values = value_answer_df.drop(columns='reference').to_dict(orient='records')
variant_values = value_answer_df.drop(columns='reference').to_dict(orient='records') # one dict of variant values per row
reference = value_answer_df['reference'].tolist()

if 'id' in value_answer_df.columns:
@@ -119,6 +139,7 @@ def from_example(cls, task_name, storage_type):
denoted in brackets. {verb} matching ' + supplemental_files[storage_type]
variant_values = {'noun':['text','task'],
'verb':['use','select']}
variant_values = [{k:v for k,v in zip(variant_values.keys(),vals)} for vals in zip(*variant_values.values())]
description = 'give your task a short description '
return cls(task_name, template= template, variant_values = variant_values,
description = description, reference='',
@@ -204,11 +225,13 @@ def generate_prompts(self):
# TODO: consider if this could be a generator function if there are a lot of variants, to avoid memory issues. For now, we will assume that the number of variants is small enough to generate all prompts at once.
if self.variant_values:
id_prompt_list = []

for value_set in self.variant_values:
prompt = self.template
prompt = prompt.format(**value_set)
prompt_id = self.prompt_id_generator(self.task_id,value_set)
id_prompt_list.append((prompt_id,prompt))

return id_prompt_list
else:
return [(self.name, self.template)]
@@ -260,8 +283,9 @@ def write_csv(self, target_folder):
'''
write the task to a csv file with a task.txt template file
'''
if not os.path.exists(os.path.join(target_folder, self.task_id)):
os.mkdir(os.path.join(target_folder, self.task_id))
# Create task folder
os.mkdir(os.path.join(target_folder, self.task_id))

# write the template
with open(os.path.join(target_folder,self.task_id, 'template.txt'), 'w') as f:
f.write(self.template)
@@ -313,17 +337,21 @@ def run(self, runner=BenchRunner(), log_dir='logs', benchmark=None, bench_path=N
print(f"Couldn't create log directory in {log_dir}...\n{e}")


for prompt_name, sub_task in self.generate_prompts():

for prompt_name, prompt in self.generate_prompts():

error = None
response = ''
try:
match runner.runner_type:
case "ollama":
completion: ChatResponse = chat(model=runner.model, messages=[
completion: ChatResponse = chat(
model=runner.model,
format = self.FormatClass.model_json_schema(),
messages=[
{
'role': 'user',
'content':sub_task,
'content':prompt,
},
])
# print("response: " + response.message.content)
@@ -334,16 +362,17 @@ def run(self, runner=BenchRunner(), log_dir='logs', benchmark=None, bench_path=N
client = Client(
host=runner.api_url if runner.api_url else "http://localhost:11434",
)
completeion = client.chat(
completion = client.chat(
runner.model,
format = self.FormatClass.model_json_schema(),
messages=[
{
"role": "user",
"content": sub_task,
"content": prompt,
},
],
)
response = completeion["message"]["content"]
response = completion["message"]["content"]
responses.append(response)

case "openai":
@@ -355,18 +384,71 @@ def run(self, runner=BenchRunner(), log_dir='logs', benchmark=None, bench_path=N
messages=[
{
"role": "user",
"content": sub_task,
"content": prompt,
}
],
)
response = chat_completion.choices[0].message.content
responses.append(response)
case "bedrock":
bedrock_client = boto3.client('bedrock-runtime')
# Bedrock hosts multiple foundation models that each differ in request parameters and response fields; we include cases for a couple of them.
# For the available foundation models and their inference parameters, see
# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
# Determine the model family first
model_fam = None
if runner.model.startswith("meta"): model_fam = "llama"
elif runner.model.startswith("google"): model_fam = "gemma"
match model_fam:
case "llama":
# Embed the prompt in Llama 3's instruction format.
formatted_prompt = f"""
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{prompt}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
# Format the request payload using the model's native structure.
request = {
"prompt": formatted_prompt,
# "max_gen_len": 512,
# "temperature": 0.5,
}
# Convert the native request to JSON.
request = json.dumps(request)
completion = bedrock_client.invoke_model(
modelId = runner.model,
body = request
)
# Decode the response body.
response = json.loads(completion["body"].read())
response = response["generation"]
case "gemma":
completion = bedrock_client.invoke_model(
modelId = runner.model,
body = json.dumps(
{
'messages': [
{
'role': 'user',
'content': prompt
}
]
}
)
)
# Decode the response body.
response = json.loads(completion['body'].read())
response = response['choices'][0]['message']['content']
case _:
raise UnMatchedModel(runner.model)
responses.append(response)
case _:
print(f"Runner type {runner.runner_type} not supported")
return None
except Exception as e:
error = e
log_interaction(run_log, prompt_name, sub_task, response, str(error))
log_interaction(run_log, prompt_name, prompt, response, str(error))



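
For context on the new `bedrock` case above: a standalone sketch of the Llama-family branch, assuming the boto3 `bedrock-runtime` client, configured AWS credentials, and a Meta Llama model ID enabled in the account (the model ID and prompt are examples, not taken from this PR):

```python
# Sketch: invoke a Meta Llama model on Amazon Bedrock, as the new "bedrock" runner case does.
# Assumptions: AWS credentials are configured and the example model ID is enabled for the account.
import json

import boto3

client = boto3.client("bedrock-runtime")
model_id = "meta.llama3-8b-instruct-v1:0"  # example Llama model ID

# Llama models take a single "prompt" string wrapped in the Llama 3 instruction format.
formatted_prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n"
    "Name one planet in the solar system.\n"
    "<|eot_id|>\n"
    "<|start_header_id|>assistant<|end_header_id|>\n"
)

completion = client.invoke_model(modelId=model_id, body=json.dumps({"prompt": formatted_prompt}))

# invoke_model returns a streaming body; decode it and read the Llama-specific "generation" field
payload = json.loads(completion["body"].read())
print(payload["generation"])
```

Other model families return differently shaped payloads (for example a `choices` list), which is why `run()` matches on the model-ID prefix and raises `UnMatchedModel` for anything it does not recognize.
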
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -10,8 +10,11 @@ dependencies = [
"pyyaml",
"pandas",
"datasets",
"tabulate",
"openai",
"ollama"
"ollama",
"boto3",
"pydantic"
]
requires-python = ">=3.10"
authors = [