From 92dd3ac669eb5fe8838adef850084e7eafb49b02 Mon Sep 17 00:00:00 2001 From: codestory Date: Wed, 29 Jan 2025 18:45:14 +0000 Subject: [PATCH 1/3] feat: sync local changes --- runners/anthropic_runner.py | 130 ++++++----------- runners/api_runner.py | 146 +++++++++---------- runners/base_runner.py | 115 +++++++++++++++ runners/bedrock_runner.py | 226 +++++++++++++++--------------- runners/deepseek_runner.py | 232 ++++++++++++++++-------------- runners/gemini_runner.py | 220 ++++++++++++----------------- runners/hf_runner.py | 136 ++++++++++-------- runners/llama_cpp_runner.py | 174 +++++++++++++---------- runners/mistral_runner.py | 272 ++++++++++++++++-------------------- runners/mlx_runner.py | 174 +++++++++++++---------- runners/openai_runner.py | 238 ++++++++++++------------------- runners/together_runner.py | 231 ++++++++++++++++-------------- runners/vllm_runner.py | 134 ++++++++++-------- 13 files changed, 1243 insertions(+), 1185 deletions(-) create mode 100644 runners/base_runner.py diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py index 2081afb..f9b35cf 100644 --- a/runners/anthropic_runner.py +++ b/runners/anthropic_runner.py @@ -128,7 +128,7 @@ def process_row(row, model_name, args): def run_anthropic_eval(args): - # get params from args + """Run evaluation using Anthropic""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file output_file_list = args.output_file @@ -145,97 +145,55 @@ def run_anthropic_eval(args): print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" ) - question_query_df = prepare_questions_df( + df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - input_rows = question_query_df.to_dict("records") - output_rows = [] - with ThreadPoolExecutor(args.parallel_threads) as executor: - futures = [] - for row in input_rows: - generated_query_fut = executor.submit( - process_row, - row=row, - model_name=args.model, - args=args, - ) - futures.append(generated_query_fut) - - total_tried = 0 - total_correct = 0 - for f in (pbar := tqdm(as_completed(futures), total=len(futures))): - total_tried += 1 - i = futures.index(f) - row = input_rows[i] - result_dict = f.result() - query_gen = result_dict["query"] - reason = result_dict["reason"] - err = result_dict["err"] - # save custom metrics - if "latency_seconds" in result_dict: - row["latency_seconds"] = result_dict["latency_seconds"] - if "tokens_used" in result_dict: - row["tokens_used"] = result_dict["tokens_used"] - row["generated_query"] = query_gen - row["reason"] = reason - row["error_msg"] = err - # save failures into relevant columns in the dataframe - if "GENERATION ERROR" in err: - row["error_query_gen"] = 1 - elif "TIMEOUT" in err: - row["timeout"] = 1 - else: - expected_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - try: - is_correct = compare_query_results( - query_gold=expected_query, - query_gen=query_gen, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[db_type], - question=row["question"], - query_category=row["query_category"], - decimal_points=args.decimal_points, - ) - if is_correct: - total_correct += 1 - row["is_correct"] = 1 - row["error_msg"] = "" - else: - row["is_correct"] = 0 - row["error_msg"] = "INCORRECT RESULTS" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"EXECUTION ERROR: {str(e)}" - output_rows.append(row) - pbar.set_description( - f"Accuracy: {round(total_correct/total_tried * 100, 2)}% 
({total_correct}/{total_tried})" - ) + + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) - # save results to csv + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("is_correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + output_df.to_csv(output_file, index=False, float_format="%.2f") - # get average rate of correct results - avg_subset = output_df["is_correct"].sum() / len(output_df) - print(f"Average correct rate: {avg_subset:.2f}") - - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="anthropic", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, 'upload_url') and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="anthropic", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file diff --git a/runners/api_runner.py b/runners/api_runner.py index 6b2afdf..fb43299 100644 --- a/runners/api_runner.py +++ b/runners/api_runner.py @@ -206,34 +206,30 @@ def process_row( def run_api_eval(args): - # get params from args + """Run evaluation using API""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - api_url = args.api_url - api_type = args.api_type - output_file_list = args.output_file k_shot = args.k_shot - num_beams = args.num_beams - max_workers = args.parallel_threads + cot_table_alias = args.cot_table_alias db_type = args.db_type - decimal_points = args.decimal_points logprobs = args.logprobs - cot_table_alias = args.cot_table_alias - sql_lora_path = args.adapter if args.adapter else None - sql_lora_name = args.adapter_name if args.adapter_name else None - run_name = args.run_name if args.run_name else None + run_name = getattr(args, 'run_name', None) + sql_lora_path = getattr(args, 'adapter', None) + if sql_lora_path: print("Using LoRA adapter at:", sql_lora_path) + + # Logprobs visualization directory handling if logprobs: - # check that the eval-visualizer/public directory exists if not os.path.exists("./eval-visualizer"): - # thorow error raise Exception( - "The eval-visualizer directory does not exist. Please clone it with `git clone https://github.com/defog-ai/eval-visualizer/` before running sql-eval with the --logprobs flag." + "The eval-visualizer directory does not exist. 
Please clone it with " + "`git clone https://github.com/defog-ai/eval-visualizer/` before running " + "sql-eval with the --logprobs flag." ) - if not os.path.exists("./eval-visualizer/public"): os.makedirs("./eval-visualizer/public") @@ -241,7 +237,6 @@ def run_api_eval(args): questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -249,7 +244,8 @@ def run_api_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -262,65 +258,30 @@ def run_api_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, 'num_columns') else 40, args.shuffle_metadata, - row["table_aliases"], + row.get("table_aliases", ""), ), axis=1, ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append( - executor.submit( - process_row, - row, - api_url, - api_type, - num_beams, - decimal_points, - logprobs, - sql_lora_path, - sql_lora_name, - ) - ) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.api_url, process_row, args + ) output_df = pd.DataFrame(output_rows) - - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - results = output_df.to_dict("records") + # Handle logprobs visualization if logprobs: + results = output_df.to_dict("records") print( f"Writing logprobs to JSON file at eval-visualizer/public/{output_file.split('/')[-1].replace('.csv', '.json')}" ) @@ -330,27 +291,48 @@ def run_api_eval(args): ) as f: json.dump(results, f) - del output_df["prompt"] + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Clean up and save results + if "prompt" in output_df.columns: + del output_df["prompt"] + + output_dir = os.path.dirname(output_file) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - # upload results - # with open(prompt_file, 
"r") as f: - # prompt = f.read() - - if args.run_name is None: + # Handle run naming and result upload + if run_name is None: run_name = output_file.split("/")[-1].replace(".csv", "") - print( - "Run name not provided. Using a output filename for run name:", run_name - ) + print("Run name not provided. Using output filename for run name:", run_name) - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - args=args, - run_name=run_name, - ) + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + try: + if hasattr(args, 'upload_url') and args.upload_url: + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="api_runner", + args=args, + run_name=run_name, + ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file diff --git a/runners/base_runner.py b/runners/base_runner.py new file mode 100644 index 0000000..0f535fd --- /dev/null +++ b/runners/base_runner.py @@ -0,0 +1,115 @@ +import json +from time import time +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pandas as pd +import sqlparse +from tqdm import tqdm + +from eval.eval import compare_query_results +from utils.creds import db_creds_all +from utils.dialects import convert_postgres_ddl_to_dialect +from utils.gen_prompt import to_prompt_schema +from utils.questions import prepare_questions_df +from utils.reporting import upload_results + + +def generate_base_prompt( + prompt_file, + question, + db_name, + db_type, + instructions="", + k_shot_prompt="", + glossary="", + table_metadata_string="", + prev_invalid_sql="", + prev_error_msg="", + public_data=True, + shuffle=True, +): + """ + Base prompt generation logic used by all runners. 
+ """ + if public_data: + from defog_data.metadata import dbs + import defog_data.supplementary as sup + else: + from defog_data_private.metadata import dbs + import defog_data_private.supplementary as sup + + if table_metadata_string == "": + md = dbs[db_name]["table_metadata"] + pruned_metadata_ddl = to_prompt_schema(md, shuffle) + pruned_metadata_ddl = convert_postgres_ddl_to_dialect( + postgres_ddl=pruned_metadata_ddl, + to_dialect=db_type, + db_name=db_name, + ) + column_join = sup.columns_join.get(db_name, {}) + join_list = [] + for values in column_join.values(): + if isinstance(values[0], tuple): + for col_pair in values: + col_1, col_2 = col_pair + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + else: + col_1, col_2 = values[0] + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) if join_list else "" + pruned_metadata_str = pruned_metadata_ddl + join_str + else: + pruned_metadata_str = table_metadata_string + + return { + "prompt_file": prompt_file, + "question": question, + "db_type": db_type, + "instructions": instructions, + "table_metadata_string": pruned_metadata_str, + "k_shot_prompt": k_shot_prompt, + "glossary": glossary, + "prev_invalid_sql": prev_invalid_sql, + "prev_error_msg": prev_error_msg, + } + + +def extract_sql_from_response(content): + """Extract SQL from between ```sql blocks and format it.""" + try: + generated_query = content.split("```sql", 1)[-1].split("```", 1)[0].strip() + return sqlparse.format( + generated_query, + reindent=True, + keyword_case="upper" + ) + except: + return content + + +def run_eval_in_threadpool(df, model_name, process_row_func, args): + """Common threadpool execution pattern for all runners.""" + total_tried = 0 + total_correct = 0 + output_rows = [] + + print(f"Running evaluation using {model_name}...") + with ThreadPoolExecutor(max_workers=args.parallel_threads) as executor: + futures = [] + for row in df.to_dict("records"): + futures.append(executor.submit(process_row_func, row, model_name, args)) + + with tqdm(as_completed(futures), total=len(futures)) as pbar: + for f in pbar: + row = f.result() + output_rows.append(row) + if row.get("correct", 0): + total_correct += 1 + total_tried += 1 + pbar.set_description(f"Acc: {total_correct}/{total_tried}={total_correct/total_tried:.3f}") + + return output_rows, total_correct, total_tried \ No newline at end of file diff --git a/runners/bedrock_runner.py b/runners/bedrock_runner.py index 806402a..8ef3301 100644 --- a/runners/bedrock_runner.py +++ b/runners/bedrock_runner.py @@ -1,99 +1,100 @@ -import boto3 import json import os -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Optional - -from eval.eval import compare_query_results +import boto3 +from time import time import pandas as pd + +from runners.base_runner import generate_base_prompt, extract_sql_from_response, run_eval_in_threadpool from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time from utils.reporting import upload_results +from eval.eval import compare_query_results bedrock = boto3.client(service_name="bedrock-runtime") - -def process_row(row, model_id, decimal_points): +def process_row(row, model_id, args): + """Process a single row using AWS Bedrock""" start_time = 
time() - - body = json.dumps( - { + try: + # Bedrock-specific request payload + body = json.dumps({ "prompt": row["prompt"], "max_gen_len": 600, "temperature": 0, "top_p": 1, - } - ) - - accept = "application/json" - contentType = "application/json" - response = bedrock.invoke_model( - body=body, modelId=model_id, accept=accept, contentType=contentType - ) - model_response = json.loads(response["body"].read()) - - generated_query = model_response["generation"] - end_time = time() + }) - generated_query = ( - generated_query.split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";" - ) - - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - row["tokens_used"] = None - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=decimal_points, + accept = "application/json" + contentType = "application/json" + response = bedrock.invoke_model( + body=body, modelId=model_id, accept=accept, contentType=contentType ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" + model_response = json.loads(response["body"].read()) + generated_query = model_response["generation"] + end_time = time() + + # Bedrock-specific SQL extraction + generated_query = extract_sql_from_response(generated_query) + + # Store results + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + row["tokens_used"] = None # Bedrock doesn't provide token count + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] + + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - - return row + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + row["tokens_used"] = None + return row def run_bedrock_eval(args): - # get params from args + """Run evaluation using AWS Bedrock""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - output_file_list = args.output_file k_shot = args.k_shot - max_workers = args.parallel_threads db_type = args.db_type 
- decimal_points = args.decimal_points - model_id = args.model cot_table_alias = args.cot_table_alias for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -101,7 +102,8 @@ def run_bedrock_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -114,64 +116,68 @@ def run_bedrock_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, 'num_columns') else 40, args.shuffle_metadata, + row.get("table_aliases", ""), ), axis=1, ) + + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append( - executor.submit(process_row, row, model_id, decimal_points) - ) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) - + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, 'upload_url') and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + 
results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="bedrock", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file diff --git a/runners/deepseek_runner.py b/runners/deepseek_runner.py index 323c0c1..b6bb838 100644 --- a/runners/deepseek_runner.py +++ b/runners/deepseek_runner.py @@ -1,96 +1,107 @@ import os -from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict - -from eval.eval import compare_query_results +from time import time import pandas as pd + +from openai import OpenAI + +from runners.base_runner import generate_base_prompt, run_eval_in_threadpool from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time -from openai import OpenAI from utils.reporting import upload_results - +from eval.eval import compare_query_results client = OpenAI( base_url="https://api.deepseek.com", api_key=os.environ.get("DEEPSEEK_API_KEY") ) - -def process_row(row: Dict, model: str): +def process_row(row: Dict, model: str, args): + """Process a single row using Deepseek""" start_time = time() - messages = row["prompt"] - if model != "deepseek-reasoner": - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=800, - temperature=0.0, - ) - else: - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=800, - ) - content = response.choices[0].message.content - generated_query = content.replace("```sql", "").replace("```", "").strip() - end_time = time() - - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - row["tokens_used"] = None - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" + messages = row["prompt"] + # Deepseek-specific handling + if model != "deepseek-reasoner": + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=800, + temperature=0.0, + ) + else: + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=800, + ) + content = response.choices[0].message.content + # Deepseek-specific SQL extraction + generated_query = content.replace("```sql", "").replace("```", "").strip() + end_time = time() + + # Store results + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + row["tokens_used"] = None # Deepseek doesn't provide token count + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] + + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + 
question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - - return row + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + row["tokens_used"] = None + return row def run_deepseek_eval(args): - # get params from args + """Run evaluation using Deepseek""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - output_file_list = args.output_file k_shot = args.k_shot - max_workers = args.parallel_threads db_type = args.db_type - decimal_points = args.decimal_points - model = args.model cot_table_alias = args.cot_table_alias for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): + # Deepseek-specific JSON validation if not prompt_file.endswith(".json"): raise ValueError(f"Prompt file must be a JSON file. Got {prompt_file}") + print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -98,8 +109,8 @@ def run_deepseek_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question - # note that the prompt for together ai uses the openai chat API + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -112,63 +123,68 @@ def run_deepseek_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, 'num_columns') else 40, args.shuffle_metadata, - row["table_aliases"], + row.get("table_aliases", ""), ), axis=1, ) + + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) - + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", 
"error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, 'upload_url') and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="deepseek", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file diff --git a/runners/gemini_runner.py b/runners/gemini_runner.py index cd292c1..34bc325 100644 --- a/runners/gemini_runner.py +++ b/runners/gemini_runner.py @@ -1,18 +1,13 @@ -import os from time import time -from concurrent.futures import ThreadPoolExecutor, as_completed - +import os import pandas as pd -import sqlparse -from tqdm import tqdm -from eval.eval import compare_query_results -from utils.creds import db_creds_all -from utils.dialects import convert_postgres_ddl_to_dialect -from utils.gen_prompt import to_prompt_schema +from runners.base_runner import generate_base_prompt, extract_sql_from_response, run_eval_in_threadpool from utils.questions import prepare_questions_df -from utils.reporting import upload_results from utils.llm import chat_gemini +from utils.creds import db_creds_all +from utils.reporting import upload_results +from eval.eval import compare_query_results def generate_prompt( @@ -29,55 +24,24 @@ def generate_prompt( public_data=True, shuffle=True, ): - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - # raise Exception("Replace this with your private data import") - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup + """Gemini-specific prompt handling""" + # Get base prompt data + base_data = generate_base_prompt( + prompt_file, question, db_name, db_type, instructions, + k_shot_prompt, glossary, table_metadata_string, + prev_invalid_sql, prev_error_msg, public_data, shuffle + ) + # Load and format Gemini text prompt with open(prompt_file, "r") as f: prompt = f.read() - if table_metadata_string == "": - md = dbs[db_name]["table_metadata"] - pruned_metadata_ddl = to_prompt_schema(md, shuffle) - pruned_metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=pruned_metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - # get 
join_str from column_join - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) - else: - join_str = "" - pruned_metadata_str = pruned_metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string - + # Format the prompt with all parameters prompt = prompt.format( user_question=question, db_type=db_type, instructions=instructions, - table_metadata_string=pruned_metadata_str, + table_metadata_string=base_data["table_metadata_string"], k_shot_prompt=k_shot_prompt, glossary=glossary, prev_invalid_sql=prev_invalid_sql, @@ -87,68 +51,63 @@ def generate_prompt( def process_row(row, model_name, args): + """Process a single row using Gemini""" start_time = time() + # Prompt already in row from DataFrame preprocessing messages = [{"role": "user", "content": row["prompt"]}] try: response = chat_gemini(messages=messages, model=model_name, temperature=0.0) - generated_query = ( - response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() - ) - try: - generated_query = sqlparse.format( - generated_query, - strip_comments=True, - strip_whitespace=True, - keyword_case="upper", - ) - except: - pass + generated_query = extract_sql_from_response(response.content) + + # Gemini-specific result handling row["generated_query"] = generated_query row["latency_seconds"] = response.time row["tokens_used"] = response.input_tokens + response.output_tokens + + # Verify results with exact_match + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row except Exception as e: - row["error_db_exec"] = 1 + row["error_query_gen"] = 1 + row["generated_query"] = "" row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + row["tokens_used"] = 0 return row - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[db_type], - question=question, - query_category=query_category, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION 
ERROR: {e}" - - return row - def run_gemini_eval(args): - # get params from args + """Run evaluation using Gemini""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - model_name = args.model - output_file_list = args.output_file k_shot = args.k_shot - max_workers = args.parallel_threads db_type = args.db_type cot_table_alias = args.cot_table_alias @@ -164,6 +123,7 @@ def run_gemini_eval(args): questions_file, db_type, num_questions, k_shot, cot_table_alias ) + # Gemini-specific: preprocess prompts into DataFrame df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -181,50 +141,52 @@ def run_gemini_eval(args): ), axis=1, ) + + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - print(f"Running evaluation using {model_name}...") - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model_name, args)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row.get("correct", 0): - total_correct += 1 - total_tried += 1 - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) - + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - try: - output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) + + output_df.to_csv(output_file, index=False, float_format="%.2f") - results = output_df.to_dict("records") + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") - if args.upload_url is not None: - with open(prompt_file, "r") as f: - prompt = f.read() + # Upload results if URL provided + try: + if hasattr(args, 'upload_url') and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() upload_results( - results=results, + results=output_df.to_dict("records"), url=args.upload_url, - runner_type="api_runner", + runner_type="gemini", prompt=prompt, args=args, ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file diff --git a/runners/hf_runner.py b/runners/hf_runner.py index 9046a65..ada7f9d 100644 --- a/runners/hf_runner.py +++ b/runners/hf_runner.py @@ -1,26 +1,24 @@ import os from typing import Optional - -from eval.eval import compare_query_results -import pandas as pd import torch +import gc +import pandas as pd from transformers import ( AutoTokenizer, 
AutoModelForCausalLM, pipeline, ) +from tqdm import tqdm +from psycopg2.extensions import QueryCanceledError + from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from psycopg2.extensions import QueryCanceledError -from time import time -import gc from utils.reporting import upload_results +from eval.eval import compare_query_results device_map = "mps" if torch.backends.mps.is_available() else "auto" - def get_tokenizer_model(model_name: Optional[str], adapter_path: Optional[str]): """ Load a HuggingFace tokenizer and model. @@ -62,15 +60,23 @@ def get_tokenizer_model(model_name: Optional[str], adapter_path: Optional[str]): return tokenizer, model +def extract_hf_sql(text: str, has_sql_tag: bool) -> str: + """HuggingFace-specific SQL extraction""" + if not has_sql_tag: + return text.split("```")[0].split(";")[0].strip() + ";" + else: + return text.split("[/SQL]")[0].split(";")[0].strip() + ";" + + def run_hf_eval(args): - # get params from args + """Run evaluation using HuggingFace models""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data model_name = args.model adapter_path = args.adapter - output_file_list = args.output_file k_shot = args.k_shot db_type = args.db_type num_beams = args.num_beams @@ -94,8 +100,6 @@ def run_hf_eval(args): print("model loaded\nnow generating and evaluating predictions...") - # from here, we generate and evaluate predictions - # eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0] pipe = pipeline( "text-generation", model=model, tokenizer=tokenizer, batch_size=args.batch_size ) @@ -104,7 +108,6 @@ def run_hf_eval(args): questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -112,7 +115,8 @@ def run_hf_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -125,15 +129,16 @@ def run_hf_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, 'num_columns') else 40, args.shuffle_metadata, + row.get("table_aliases", ""), ), axis=1, ) @@ -165,30 +170,16 @@ def chunk_dataframe(df, chunk_size): top_p=None, ) gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() for row, result in zip(batch.to_dict("records"), generated_queries): total_tried += 1 - # we set return_full_text to False so that we don't get the prompt text in the generated text - # this simplifies our postprocessing to deal with just the truncation of the end of the query - - if "[SQL]" not in row["prompt"]: - generated_query = ( - result[0]["generated_text"] - 
.split("```")[0] - .split(";")[0] - .strip() - + ";" - ) - else: - generated_query = ( - result[0]["generated_text"] - .split("[/SQL]")[0] - .split(";")[0] - .strip() - + ";" - ) + has_sql_tag = "[SQL]" in row["prompt"] + generated_query = extract_hf_sql( + result[0]["generated_text"], has_sql_tag + ) gc.collect() if torch.cuda.is_available(): @@ -203,8 +194,6 @@ def chunk_dataframe(df, chunk_size): question = row["question"] query_category = row["query_category"] table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - db_creds = db_creds_all[db_type] try: exact_match, correct = compare_query_results( @@ -212,14 +201,15 @@ def chunk_dataframe(df, chunk_size): query_gen=generated_query, db_name=db_name, db_type=db_type, - db_creds=db_creds, + db_creds=db_creds_all[db_type], question=question, query_category=query_category, table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, + decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) + row["is_correct"] = int(correct) # For base runner compatibility row["error_msg"] = "" if correct: total_correct += 1 @@ -236,25 +226,47 @@ def chunk_dataframe(df, chunk_size): f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" ) + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + output_df.to_csv(output_file, index=False, float_format="%.2f") - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="hf_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, 'upload_url') and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="hf_runner", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file diff --git a/runners/llama_cpp_runner.py b/runners/llama_cpp_runner.py index 0297ca0..c2babc4 100644 --- a/runners/llama_cpp_runner.py +++ b/runners/llama_cpp_runner.py @@ -1,84 +1,93 @@ import os - -from eval.eval import compare_query_results +from time import time import pandas as pd +from llama_cpp import Llama +from tqdm import tqdm + from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import 
db_creds_all -from tqdm import tqdm -from time import time from utils.reporting import upload_results -from llama_cpp import Llama +from eval.eval import compare_query_results def process_row(llm, row, args): + """Process a single row using Llama.cpp""" start_time = time() - prompt = row["prompt"] - generated_query = ( - llm( + try: + prompt = row["prompt"] + response = llm( prompt, max_tokens=512, temperature=0, top_p=1, echo=False, repeat_penalty=1.0, - )["choices"][0]["text"] - .split(";")[0] - .split("```")[0] - .strip() - + ";" - ) - end_time = time() - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" + # Llama.cpp-specific SQL extraction + generated_query = response["choices"][0]["text"].split(";")[0].split("```")[0].strip() + ";" + end_time = time() + + # Store results + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] + + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - return row + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + return row def run_llama_cpp_eval(args): - # get params from args + """Run evaluation using Llama.cpp""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data model_path = args.model - output_file_list = args.output_file k_shot = args.k_shot db_type = args.db_type cot_table_alias = args.cot_table_alias + # Load Llama.cpp model llm = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=4096) for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) 
from {questions_file}" @@ -86,7 +95,8 @@ def run_llama_cpp_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -99,19 +109,21 @@ def run_llama_cpp_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, 'num_columns') else 40, args.shuffle_metadata, + row.get("table_aliases", ""), ), axis=1, ) + # Process rows with direct iteration (no threading) total_tried = 0 total_correct = 0 output_rows = [] @@ -120,7 +132,7 @@ def run_llama_cpp_eval(args): for row in df.to_dict("records"): row = process_row(llm, row, args) output_rows.append(row) - if row["correct"]: + if row.get("correct", 0): total_correct += 1 total_tried += 1 pbar.update(1) @@ -128,28 +140,50 @@ def run_llama_cpp_eval(args): f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" ) + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="llama_cpp_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, 'upload_url') and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="llama_cpp_runner", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file diff --git a/runners/mistral_runner.py b/runners/mistral_runner.py index 4abdf81..3063021 100644 --- a/runners/mistral_runner.py +++ b/runners/mistral_runner.py @@ -1,18 +1,15 @@ import os from time import time -from concurrent.futures import ThreadPoolExecutor, as_completed +import pandas as pd from mistralai.client import MistralClient from 
mistralai.models.chat_completion import ChatMessage -import pandas as pd -from tqdm import tqdm -from eval.eval import compare_query_results -from utils.creds import db_creds_all -from utils.gen_prompt import to_prompt_schema -from utils.dialects import convert_postgres_ddl_to_dialect +from runners.base_runner import generate_base_prompt, extract_sql_from_response, run_eval_in_threadpool from utils.questions import prepare_questions_df +from utils.creds import db_creds_all from utils.reporting import upload_results +from eval.eval import compare_query_results api_key = os.environ.get("MISTRAL_API_KEY") client = MistralClient(api_key=api_key) @@ -32,141 +29,113 @@ def generate_prompt( public_data=True, shuffle=True, ): + """Mistral-specific prompt handling with System/User format""" + # Get base prompt data + base_data = generate_base_prompt( + prompt_file, question, db_name, db_type, instructions, + k_shot_prompt, glossary, table_metadata_string, + prev_invalid_sql, prev_error_msg, public_data, shuffle + ) + + # Load and parse Mistral-specific prompt format with open(prompt_file, "r") as f: prompt = f.read() # Check that System and User prompts are in the prompt file if "System:" not in prompt or "User:" not in prompt: raise ValueError("Invalid prompt file. Please use prompt_mistral.md") + sys_prompt = prompt.split("System:")[1].split("User:")[0].strip() user_prompt = prompt.split("User:")[1].strip() - if table_metadata_string == "": - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup - - md = dbs[db_name]["table_metadata"] - metadata_ddl = to_prompt_schema(md, shuffle) - metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - # get join_str from column_join - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) - else: - join_str = "" - pruned_metadata_str = metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string - + # Format the user prompt with parameters user_prompt = user_prompt.format( user_question=question, instructions=instructions, - table_metadata_string=pruned_metadata_str, + table_metadata_string=base_data["table_metadata_string"], k_shot_prompt=k_shot_prompt, glossary=glossary, prev_invalid_sql=prev_invalid_sql, prev_error_msg=prev_error_msg, ) - messages = [ - ChatMessage( - role="system", - content=sys_prompt, - ), - ChatMessage( - role="user", - content=user_prompt, - ), + + # Return Mistral-specific message format + return [ + ChatMessage(role="system", content=sys_prompt), + ChatMessage(role="user", content=user_prompt), ] - return messages def process_row(row, model, args): + """Process a single row using Mistral""" start_time = time() - chat_response = client.chat( - model=model, - messages=row["prompt"], - temperature=0, - max_tokens=600, - ) - end_time = time() - generated_query = 
chat_response.choices[0].message.content - try: - # replace all backslashes with empty string - generated_query = generated_query.replace("\\", "") - - generated_query = generated_query.split(";")[0].split("```sql")[-1].strip() - generated_query = [i for i in generated_query.split("```") if i.strip() != ""][ - 0 - ] + ";" - except Exception as e: - print(e) + chat_response = client.chat( + model=model, + messages=row["prompt"], + temperature=0, + max_tokens=600, + ) + end_time = time() generated_query = chat_response.choices[0].message.content - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" + # Mistral-specific SQL extraction with backslash handling + try: + generated_query = generated_query.replace("\\", "") + generated_query = generated_query.split(";")[0].split("```sql")[-1].strip() + generated_query = [i for i in generated_query.split("```") if i.strip() != ""][0] + ";" + except Exception as e: + print(e) + generated_query = chat_response.choices[0].message.content + + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] + + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - - return row + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + return row def run_mistral_eval(args): - # get params from args + """Run evaluation using Mistral""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - model = args.model - output_file_list = args.output_file k_shot = args.k_shot - max_workers = args.parallel_threads db_type = args.db_type cot_table_alias = args.cot_table_alias @@ -174,7 +143,6 @@ def run_mistral_eval(args): questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file 
{prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -182,7 +150,8 @@ def run_mistral_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Mistral-specific: preprocess prompts into DataFrame df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -200,49 +169,52 @@ def run_mistral_eval(args): ), axis=1, ) + + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model, args)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row.get("correct", 0): - total_correct += 1 - total_tried += 1 - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) - + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + + output_df.to_csv(output_file, index=False, float_format="%.2f") + + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided try: - output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) - - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="mistral_runner", - prompt=prompt, - args=args, - ) + if hasattr(args, 'upload_url') and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="mistral", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file diff --git a/runners/mlx_runner.py b/runners/mlx_runner.py index e773008..762175b 100644 --- a/runners/mlx_runner.py +++ b/runners/mlx_runner.py @@ -1,78 +1,87 @@ import os - -from eval.eval import compare_query_results +from time import time import pandas as pd +from tqdm import tqdm +from mlx_lm import load, generate + from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time from 
utils.reporting import upload_results -from mlx_lm import load, generate +from eval.eval import compare_query_results def process_row(model, tokenizer, row, args): + """Process a single row using MLX""" start_time = time() - prompt = row["prompt"] - - generated_query = ( - generate(model, tokenizer, prompt=prompt, max_tokens=512, temp=0, verbose=True) - .split(";")[0] - .split("```")[0] - .strip() - + ";" - ) - end_time = time() - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" + prompt = row["prompt"] + + # MLX-specific generation + generated_text = generate(model, tokenizer, prompt=prompt, max_tokens=512, temp=0, verbose=True) + generated_query = generated_text.split(";")[0].split("```")[0].strip() + ";" + end_time = time() + + # Store results + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] + + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - return row + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + return row def run_mlx_eval(args): - # get params from args + """Run evaluation using MLX""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data model_path = args.model - output_file_list = args.output_file k_shot = args.k_shot db_type = args.db_type cot_table_alias = args.cot_table_alias + # MLX-specific model loading model, tokenizer = load(model_path) for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -80,7 +89,8 @@ def run_mlx_eval(args): 
df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -93,19 +103,21 @@ def run_mlx_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, 'num_columns') else 40, args.shuffle_metadata, + row.get("table_aliases", ""), ), axis=1, ) + # Process rows sequentially with tqdm total_tried = 0 total_correct = 0 output_rows = [] @@ -114,7 +126,7 @@ def run_mlx_eval(args): for row in df.to_dict("records"): row = process_row(model, tokenizer, row, args) output_rows.append(row) - if row["correct"]: + if row.get("correct", 0): total_correct += 1 total_tried += 1 pbar.update(1) @@ -122,28 +134,50 @@ def run_mlx_eval(args): f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" ) + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="mlx_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, 'upload_url') and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="mlx_runner", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file diff --git a/runners/openai_runner.py b/runners/openai_runner.py index 5d207ef..b89e54a 100644 --- a/runners/openai_runner.py +++ b/runners/openai_runner.py @@ -1,19 +1,14 @@ -import os from time import time -from concurrent.futures import ThreadPoolExecutor, as_completed import json - +import os import pandas as pd -import sqlparse -from tqdm import tqdm -from eval.eval import compare_query_results -from utils.creds import db_creds_all -from utils.dialects import 
convert_postgres_ddl_to_dialect -from utils.gen_prompt import to_prompt_schema +from runners.base_runner import generate_base_prompt, extract_sql_from_response, run_eval_in_threadpool from utils.questions import prepare_questions_df -from utils.reporting import upload_results from utils.llm import chat_openai +from utils.creds import db_creds_all +from utils.reporting import upload_results +from eval.eval import compare_query_results def generate_prompt( @@ -30,46 +25,20 @@ def generate_prompt( public_data=True, shuffle=True, ): - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup + """OpenAI-specific prompt handling""" + # Get base prompt data + base_data = generate_base_prompt( + prompt_file, question, db_name, db_type, instructions, + k_shot_prompt, glossary, table_metadata_string, + prev_invalid_sql, prev_error_msg, public_data, shuffle + ) + # Load and format OpenAI-specific JSON prompt with open(prompt_file, "r") as f: prompt = json.load(f) - if table_metadata_string == "": - md = dbs[db_name]["table_metadata"] - pruned_metadata_ddl = to_prompt_schema(md, shuffle) - pruned_metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=pruned_metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) - else: - join_str = "" - pruned_metadata_str = pruned_metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string - + pruned_metadata_str = base_data["table_metadata_string"] + if prompt[0]["role"] == "system": prompt[0]["content"] = prompt[0]["content"].format( db_type=db_type, @@ -81,7 +50,7 @@ def generate_prompt( k_shot_prompt=k_shot_prompt, ) else: - prompt[0]["content"] = prompt[1]["content"].format( + prompt[0]["content"] = prompt[0]["content"].format( db_type=db_type, user_question=question, instructions=instructions, @@ -92,6 +61,7 @@ def generate_prompt( def process_row(row, model_name, args): + """Process a single row using OpenAI""" start_time = time() messages = generate_prompt( prompt_file=args.prompt_file[0], @@ -109,34 +79,55 @@ def process_row(row, model_name, args): ) try: response = chat_openai(messages=messages, model=model_name, temperature=0.0) - generated_query = ( - response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() - ) - try: - generated_query = sqlparse.format( - generated_query, reindent=True, keyword_case="upper" - ) - except: - pass - return { - "query": generated_query, + generated_query = extract_sql_from_response(response.content) + + result = { + "generated_query": generated_query, "reason": "", - "err": "", + "error_msg": "", "latency_seconds": time() - start_time, "tokens_used": response.input_tokens + response.output_tokens, } + + # Verify results + expected_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + try: + is_correct = compare_query_results( + query_gold=expected_query, + 
query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=row["question"], + query_category=row["query_category"], + decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + ) + if is_correct: + row["is_correct"] = 1 + else: + row["is_correct"] = 0 + result["error_msg"] = "INCORRECT RESULTS" + except Exception as e: + row["error_db_exec"] = 1 + result["error_msg"] = f"EXECUTION ERROR: {str(e)}" + + # Update row with result data + row.update(result) + return row except Exception as e: - return { - "query": "", - "reason": "", - "err": f"GENERATION ERROR: {str(e)}", - "latency_seconds": time() - start_time, - "tokens_used": 0, - } + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["reason"] = "" + row["error_msg"] = f"GENERATION ERROR: {str(e)}" + row["latency_seconds"] = time() - start_time + row["tokens_used"] = 0 + return row def run_openai_eval(args): - # get params from args + """Run evaluation using OpenAI""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file output_file_list = args.output_file @@ -153,78 +144,21 @@ def run_openai_eval(args): print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" ) - question_query_df = prepare_questions_df( + df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - input_rows = question_query_df.to_dict("records") - output_rows = [] - with ThreadPoolExecutor(args.parallel_threads) as executor: - futures = [] - for row in input_rows: - generated_query_fut = executor.submit( - process_row, - row=row, - model_name=args.model, - args=args, - ) - futures.append(generated_query_fut) - - total_tried = 0 - total_correct = 0 - for f in (pbar := tqdm(as_completed(futures), total=len(futures))): - total_tried += 1 - i = futures.index(f) - row = input_rows[i] - result_dict = f.result() - query_gen = result_dict["query"] - reason = result_dict["reason"] - err = result_dict["err"] - # save custom metrics - if "latency_seconds" in result_dict: - row["latency_seconds"] = result_dict["latency_seconds"] - if "tokens_used" in result_dict: - row["tokens_used"] = result_dict["tokens_used"] - row["generated_query"] = query_gen - row["reason"] = reason - row["error_msg"] = err - # save failures into relevant columns in the dataframe - if "GENERATION ERROR" in err: - row["error_query_gen"] = 1 - else: - expected_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - try: - is_correct = compare_query_results( - query_gold=expected_query, - query_gen=query_gen, - db_name=db_name, - db_type=db_type, - question=row["question"], - query_category=row["query_category"], - db_creds=db_creds_all[db_type], - ) - if is_correct: - total_correct += 1 - row["is_correct"] = 1 - row["error_msg"] = "" - else: - row["is_correct"] = 0 - row["error_msg"] = "INCORRECT RESULTS" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"EXECUTION ERROR: {str(e)}" - output_rows.append(row) - pbar.set_description( - f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" - ) + + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) - # save results to csv + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) if "prompt" in output_df.columns: del output_df["prompt"] - # get num 
rows, mean correct, mean error_db_exec for each query_category + + # Get stats by query category agg_stats = ( output_df.groupby("query_category") .agg( @@ -235,26 +169,30 @@ def run_openai_eval(args): .reset_index() ) print(agg_stats) - # get directory of output_file and create if not exist + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + output_df.to_csv(output_file, index=False, float_format="%.2f") - # get average rate of correct results - avg_subset = output_df["correct"].sum() / len(output_df) - print(f"Average correct rate: {avg_subset:.2f}") - - results = output_df.to_dict("records") + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="openai", - prompt=prompt, - args=args, - ) + # Upload results if URL provided + try: + if hasattr(args, 'upload_url') and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="openai", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file diff --git a/runners/together_runner.py b/runners/together_runner.py index 0414e57..826408f 100644 --- a/runners/together_runner.py +++ b/runners/together_runner.py @@ -1,96 +1,106 @@ import os -from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict - -from eval.eval import compare_query_results +from time import time import pandas as pd +from copy import deepcopy + +from together import Together + +from runners.base_runner import generate_base_prompt, run_eval_in_threadpool from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time -from together import Together from utils.reporting import upload_results - +from eval.eval import compare_query_results client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) - def process_row(row: Dict, model: str): + """Process a single row using Together""" start_time = time() - if model.startswith("meta-llama"): - stop = ["<|eot_id|>", "<|eom_id|>"] - else: - print( - "Undefined stop token(s). Please specify the stop token(s) for the model." 
- ) - stop = [] - messages = row["prompt"] - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=800, - temperature=0.0, - stop=stop, - stream=False, - ) - content = response.choices[0].message.content - generated_query = content.split("```", 1)[0].strip() - end_time = time() - - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - row["tokens_used"] = None - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, + # Together-specific stop tokens + if model.startswith("meta-llama"): + stop = ["<|eot_id|>", "<|eom_id|>"] + else: + print("Undefined stop token(s). Please specify the stop token(s) for the model.") + stop = [] + + messages = row["prompt"] + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=800, + temperature=0.0, + stop=stop, + stream=False, ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" + # Together-specific SQL extraction + content = response.choices[0].message.content + generated_query = content.split("```", 1)[0].strip() + end_time = time() + + # Store results + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + row["tokens_used"] = None # Together doesn't provide token count + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] + + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - - return row + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + row["tokens_used"] = None + return row def run_together_eval(args): - # get params from args + """Run evaluation using Together""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - output_file_list = args.output_file k_shot = args.k_shot - max_workers = args.parallel_threads db_type = args.db_type - decimal_points = args.decimal_points - model = args.model cot_table_alias = args.cot_table_alias for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): + # Together-specific JSON 
validation if not prompt_file.endswith(".json"): raise ValueError(f"Prompt file must be a JSON file. Got {prompt_file}") + print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -98,8 +108,8 @@ def run_together_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question - # note that the prompt for together ai uses the openai chat API + + # Together-specific: use full generate_prompt with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -112,63 +122,68 @@ def run_together_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, 'num_columns') else 40, args.shuffle_metadata, - row["table_aliases"], + row.get("table_aliases", ""), ), axis=1, ) + + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) - + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, 'upload_url') and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + 
results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="together", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file diff --git a/runners/vllm_runner.py b/runners/vllm_runner.py index 59ed962..ba94207 100644 --- a/runners/vllm_runner.py +++ b/runners/vllm_runner.py @@ -1,40 +1,45 @@ -import json import os -from typing import List import sqlparse +import time +import torch +import pandas as pd +from typing import List +from tqdm import tqdm + from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest -from eval.eval import compare_query_results -import pandas as pd +from transformers import AutoTokenizer + from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -import time -import torch -from transformers import AutoTokenizer -from tqdm import tqdm from utils.reporting import upload_results +from eval.eval import compare_query_results def run_vllm_eval(args): - # get params from args + """Run evaluation using VLLM with batching""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data model_name = args.model - output_file_list = args.output_file num_beams = args.num_beams k_shot = args.k_shot db_type = args.db_type cot_table_alias = args.cot_table_alias + + # VLLM-specific LoRA handling enable_lora = True if args.adapter else False lora_request = LoRARequest("sql_adapter", 1, args.adapter) if args.adapter else None - # initialize model only once as it takes a while + # Initialize VLLM model and tokenizer print(f"Preparing {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token_id = tokenizer.eos_token_id + + # VLLM-specific model initialization if not args.quantized: llm = LLM( model=model_name, @@ -66,7 +71,6 @@ def run_vllm_eval(args): questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -74,7 +78,8 @@ def run_vllm_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -87,15 +92,16 @@ def run_vllm_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, 'num_columns') else 40, args.shuffle_metadata, + row.get("table_aliases", ""), ), axis=1, ) @@ -106,51 +112,44 @@ def chunk_dataframe(df, chunk_size): df_chunks = [] for i in range(0, len(df), chunk_size): df_i = df.iloc[i : min(i + chunk_size, len(df))] - print( - f"Chunk {i//chunk_size+1}/{len(df)//chunk_size+1} with {len(df_i)} questions" - ) + print(f"Chunk {i//chunk_size+1}/{len(df)//chunk_size+1} with {len(df_i)} questions") df_chunks.append(df_i) return df_chunks + # VLLM-specific 
batch processing df_chunks = chunk_dataframe(df, args.batch_size) - total_tried = 0 total_correct = 0 output_rows = [] - print(f"Generating completions") - + print("Generating completions") for batch in (pbar := tqdm(df_chunks, total=len(df))): prompts = batch["prompt"].tolist() print(f"Generating completions for {len(prompts)} prompts") + + # VLLM-specific token handling prompt_tokens = [] prompt_token_sizes = [] for prompt in prompts: token_ids = tokenizer.encode(prompt, add_special_tokens=False) - # add bos token if not already present in prompt if token_ids[0] != tokenizer.bos_token_id: token_ids = [tokenizer.bos_token_id] + token_ids prompt_tokens.append(token_ids) prompt_token_sizes.append(len(token_ids)) - print( - f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}" - ) + print(f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}") + start_time = time.time() - # outputs = llm.generate(prompts, sampling_params) # if you prefer to use prompts instead of token_ids outputs = llm.generate( sampling_params=sampling_params, prompt_token_ids=prompt_tokens, use_tqdm=False, lora_request=lora_request, ) - print( - f"Generated {len(outputs)} completions in {time.time() - start_time:.2f} seconds" - ) + print(f"Generated {len(outputs)} completions in {time.time() - start_time:.2f} seconds") time_taken = time.time() - start_time + for row, output in zip(batch.to_dict("records"), outputs): - generated_query = ( - output.outputs[0].text.split(";")[0].split("```")[0].strip() + ";" - ) + generated_query = output.outputs[0].text.split(";")[0].split("```")[0].strip() + ";" normalized_query = sqlparse.format( generated_query, keyword_case="upper", strip_whitespace=True ) @@ -158,28 +157,29 @@ def chunk_dataframe(df, chunk_size): row["tokens_used"] = len(output.outputs[0].token_ids) row["latency_seconds"] = time_taken / len(batch) + # Verify results golden_query = row["query"] db_name = row["db_name"] db_type = row["db_type"] question = row["question"] query_category = row["query_category"] table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - db_creds = db_creds_all[db_type] + try: exact_match, correct = compare_query_results( query_gold=golden_query, query_gen=generated_query, db_name=db_name, db_type=db_type, - db_creds=db_creds, + db_creds=db_creds_all[db_type], question=question, query_category=query_category, table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, + decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) + row["is_correct"] = int(correct) # For base runner compatibility row["error_msg"] = "" if correct: total_correct += 1 @@ -189,31 +189,45 @@ def chunk_dataframe(df, chunk_size): total_tried += 1 output_rows.append(row) + pbar.update(len(batch)) - pbar.set_description( - f"Correct so far: {total_correct}/{(total_tried)} ({100*total_correct/(total_tried):.2f}%)" - ) + pbar.set_description(f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)") + + # Process results df = pd.DataFrame(output_rows) - del df["prompt"] - print(df.groupby("query_category")[["exact_match", "correct"]].mean()) + if "prompt" in df.columns: + del df["prompt"] + + # Get stats by query category + agg_stats = df.groupby("query_category")[["exact_match", "correct"]].mean() + print(agg_stats) df = df.sort_values(by=["db_name", "query_category", "question"]) print(f"Average tokens 
generated: {df['tokens_used'].mean():.1f}") - # get directory of output_file and create if not exist + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + df.to_csv(output_file, index=False, float_format="%.2f") print(f"Saved results to {output_file}") - results = df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="vllm_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, 'upload_url') and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=df.to_dict("records"), + url=args.upload_url, + runner_type="vllm_runner", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") \ No newline at end of file From 47e62abfbc59956e8e9667e7183c92465605a3f2 Mon Sep 17 00:00:00 2001 From: Rishabh Srivastava Date: Thu, 30 Jan 2025 02:49:19 +0800 Subject: [PATCH 2/3] linting --- runners/anthropic_runner.py | 8 +++--- runners/api_runner.py | 16 ++++++------ runners/base_runner.py | 18 +++++++------ runners/bedrock_runner.py | 41 ++++++++++++++++++------------ runners/deepseek_runner.py | 23 +++++++++-------- runners/gemini_runner.py | 41 ++++++++++++++++++++---------- runners/hf_runner.py | 19 +++++++++----- runners/llama_cpp_runner.py | 22 +++++++++------- runners/mistral_runner.py | 49 ++++++++++++++++++++++++------------ runners/mlx_runner.py | 24 ++++++++++-------- runners/openai_runner.py | 41 ++++++++++++++++++++---------- runners/together_runner.py | 25 +++++++++++-------- runners/vllm_runner.py | 50 ++++++++++++++++++++++++------------- 13 files changed, 236 insertions(+), 141 deletions(-) diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py index f9b35cf..264f2e1 100644 --- a/runners/anthropic_runner.py +++ b/runners/anthropic_runner.py @@ -148,7 +148,7 @@ def run_anthropic_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - + output_rows, total_correct, total_tried = run_eval_in_threadpool( df, args.model, process_row, args ) @@ -175,7 +175,7 @@ def run_anthropic_eval(args): output_dir = os.path.dirname(output_file) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - + output_df.to_csv(output_file, index=False, float_format="%.2f") # Print summary stats @@ -185,7 +185,7 @@ def run_anthropic_eval(args): # Upload results if URL provided try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: with open(prompt_file, "r") as f: prompt = f.read() upload_results( @@ -196,4 +196,4 @@ def run_anthropic_eval(args): args=args, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") diff --git a/runners/api_runner.py b/runners/api_runner.py index fb43299..a8b57f1 100644 --- a/runners/api_runner.py +++ b/runners/api_runner.py @@ -216,8 +216,8 @@ def run_api_eval(args): cot_table_alias = args.cot_table_alias db_type = args.db_type logprobs = args.logprobs - run_name = getattr(args, 
'run_name', None) - sql_lora_path = getattr(args, 'adapter', None) + run_name = getattr(args, "run_name", None) + sql_lora_path = getattr(args, "adapter", None) if sql_lora_path: print("Using LoRA adapter at:", sql_lora_path) @@ -265,7 +265,7 @@ def run_api_eval(args): row.get("cot_instructions", ""), row.get("cot_pregen", False), public_data, - args.num_columns if hasattr(args, 'num_columns') else 40, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, row.get("table_aliases", ""), ), @@ -306,7 +306,7 @@ def run_api_eval(args): # Clean up and save results if "prompt" in output_df.columns: del output_df["prompt"] - + output_dir = os.path.dirname(output_file) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) @@ -319,14 +319,16 @@ def run_api_eval(args): # Handle run naming and result upload if run_name is None: run_name = output_file.split("/")[-1].replace(".csv", "") - print("Run name not provided. Using output filename for run name:", run_name) + print( + "Run name not provided. Using output filename for run name:", run_name + ) print(f"Total questions: {total_tried}") print(f"Total correct: {total_correct}") print(f"Accuracy: {total_correct/total_tried:.3f}") try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: upload_results( results=output_df.to_dict("records"), url=args.upload_url, @@ -335,4 +337,4 @@ def run_api_eval(args): run_name=run_name, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") diff --git a/runners/base_runner.py b/runners/base_runner.py index 0f535fd..4669ed0 100644 --- a/runners/base_runner.py +++ b/runners/base_runner.py @@ -60,7 +60,11 @@ def generate_base_prompt( join_str = f"{col_1} can be joined with {col_2}" if join_str not in join_list: join_list.append(join_str) - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) if join_list else "" + join_str = ( + "\nHere is a list of joinable columns:\n" + "\n".join(join_list) + if join_list + else "" + ) pruned_metadata_str = pruned_metadata_ddl + join_str else: pruned_metadata_str = table_metadata_string @@ -82,11 +86,7 @@ def extract_sql_from_response(content): """Extract SQL from between ```sql blocks and format it.""" try: generated_query = content.split("```sql", 1)[-1].split("```", 1)[0].strip() - return sqlparse.format( - generated_query, - reindent=True, - keyword_case="upper" - ) + return sqlparse.format(generated_query, reindent=True, keyword_case="upper") except: return content @@ -110,6 +110,8 @@ def run_eval_in_threadpool(df, model_name, process_row_func, args): if row.get("correct", 0): total_correct += 1 total_tried += 1 - pbar.set_description(f"Acc: {total_correct}/{total_tried}={total_correct/total_tried:.3f}") + pbar.set_description( + f"Acc: {total_correct}/{total_tried}={total_correct/total_tried:.3f}" + ) - return output_rows, total_correct, total_tried \ No newline at end of file + return output_rows, total_correct, total_tried diff --git a/runners/bedrock_runner.py b/runners/bedrock_runner.py index 8ef3301..4c46769 100644 --- a/runners/bedrock_runner.py +++ b/runners/bedrock_runner.py @@ -4,7 +4,11 @@ from time import time import pandas as pd -from runners.base_runner import generate_base_prompt, extract_sql_from_response, run_eval_in_threadpool +from runners.base_runner import ( + generate_base_prompt, + extract_sql_from_response, + run_eval_in_threadpool, +) from 
utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all @@ -13,17 +17,20 @@ bedrock = boto3.client(service_name="bedrock-runtime") + def process_row(row, model_id, args): """Process a single row using AWS Bedrock""" start_time = time() try: # Bedrock-specific request payload - body = json.dumps({ - "prompt": row["prompt"], - "max_gen_len": 600, - "temperature": 0, - "top_p": 1, - }) + body = json.dumps( + { + "prompt": row["prompt"], + "max_gen_len": 600, + "temperature": 0, + "top_p": 1, + } + ) accept = "application/json" contentType = "application/json" @@ -41,7 +48,7 @@ def process_row(row, model_id, args): row["generated_query"] = generated_query row["latency_seconds"] = end_time - start_time row["tokens_used"] = None # Bedrock doesn't provide token count - + # Verify results golden_query = row["query"] db_name = row["db_name"] @@ -49,7 +56,7 @@ def process_row(row, model_id, args): question = row["question"] query_category = row["query_category"] table_metadata_string = row["table_metadata_string"] - + try: exact_match, correct = compare_query_results( query_gold=golden_query, @@ -60,7 +67,9 @@ def process_row(row, model_id, args): question=question, query_category=query_category, table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) @@ -69,7 +78,7 @@ def process_row(row, model_id, args): except Exception as e: row["error_db_exec"] = 1 row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - + return row except Exception as e: row["error_query_gen"] = 1 @@ -123,13 +132,13 @@ def run_bedrock_eval(args): row.get("cot_instructions", ""), row.get("cot_pregen", False), public_data, - args.num_columns if hasattr(args, 'num_columns') else 40, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, row.get("table_aliases", ""), ), axis=1, ) - + output_rows, total_correct, total_tried = run_eval_in_threadpool( df, args.model, process_row, args ) @@ -156,7 +165,7 @@ def run_bedrock_eval(args): output_dir = os.path.dirname(output_file) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: @@ -169,7 +178,7 @@ def run_bedrock_eval(args): # Upload results if URL provided try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: with open(prompt_file, "r") as f: prompt = f.read() upload_results( @@ -180,4 +189,4 @@ def run_bedrock_eval(args): args=args, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") diff --git a/runners/deepseek_runner.py b/runners/deepseek_runner.py index b6bb838..8fb0172 100644 --- a/runners/deepseek_runner.py +++ b/runners/deepseek_runner.py @@ -16,6 +16,7 @@ base_url="https://api.deepseek.com", api_key=os.environ.get("DEEPSEEK_API_KEY") ) + def process_row(row: Dict, model: str, args): """Process a single row using Deepseek""" start_time = time() @@ -44,7 +45,7 @@ def process_row(row: Dict, model: str, args): row["generated_query"] = generated_query row["latency_seconds"] = end_time - start_time row["tokens_used"] = None # Deepseek doesn't provide token count - + # Verify results golden_query = row["query"] db_name = 
row["db_name"] @@ -52,7 +53,7 @@ def process_row(row: Dict, model: str, args): question = row["question"] query_category = row["query_category"] table_metadata_string = row["table_metadata_string"] - + try: exact_match, correct = compare_query_results( query_gold=golden_query, @@ -63,7 +64,9 @@ def process_row(row: Dict, model: str, args): question=question, query_category=query_category, table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) @@ -72,7 +75,7 @@ def process_row(row: Dict, model: str, args): except Exception as e: row["error_db_exec"] = 1 row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - + return row except Exception as e: row["error_query_gen"] = 1 @@ -100,7 +103,7 @@ def run_deepseek_eval(args): # Deepseek-specific JSON validation if not prompt_file.endswith(".json"): raise ValueError(f"Prompt file must be a JSON file. Got {prompt_file}") - + print(f"Using prompt file {prompt_file}") print("Preparing questions...") print( @@ -130,13 +133,13 @@ def run_deepseek_eval(args): row.get("cot_instructions", ""), row.get("cot_pregen", False), public_data, - args.num_columns if hasattr(args, 'num_columns') else 40, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, row.get("table_aliases", ""), ), axis=1, ) - + output_rows, total_correct, total_tried = run_eval_in_threadpool( df, args.model, process_row, args ) @@ -163,7 +166,7 @@ def run_deepseek_eval(args): output_dir = os.path.dirname(output_file) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: @@ -176,7 +179,7 @@ def run_deepseek_eval(args): # Upload results if URL provided try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: with open(prompt_file, "r") as f: prompt = f.read() upload_results( @@ -187,4 +190,4 @@ def run_deepseek_eval(args): args=args, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") diff --git a/runners/gemini_runner.py b/runners/gemini_runner.py index 34bc325..db9c923 100644 --- a/runners/gemini_runner.py +++ b/runners/gemini_runner.py @@ -2,7 +2,11 @@ import os import pandas as pd -from runners.base_runner import generate_base_prompt, extract_sql_from_response, run_eval_in_threadpool +from runners.base_runner import ( + generate_base_prompt, + extract_sql_from_response, + run_eval_in_threadpool, +) from utils.questions import prepare_questions_df from utils.llm import chat_gemini from utils.creds import db_creds_all @@ -27,9 +31,18 @@ def generate_prompt( """Gemini-specific prompt handling""" # Get base prompt data base_data = generate_base_prompt( - prompt_file, question, db_name, db_type, instructions, - k_shot_prompt, glossary, table_metadata_string, - prev_invalid_sql, prev_error_msg, public_data, shuffle + prompt_file, + question, + db_name, + db_type, + instructions, + k_shot_prompt, + glossary, + table_metadata_string, + prev_invalid_sql, + prev_error_msg, + public_data, + shuffle, ) # Load and format Gemini text prompt @@ -58,19 +71,19 @@ def process_row(row, model_name, args): try: response = chat_gemini(messages=messages, model=model_name, temperature=0.0) generated_query = 
extract_sql_from_response(response.content) - + # Gemini-specific result handling row["generated_query"] = generated_query row["latency_seconds"] = response.time row["tokens_used"] = response.input_tokens + response.output_tokens - + # Verify results with exact_match golden_query = row["query"] db_name = row["db_name"] db_type = row["db_type"] question = row["question"] query_category = row["query_category"] - + try: exact_match, correct = compare_query_results( query_gold=golden_query, @@ -80,7 +93,9 @@ def process_row(row, model_name, args): db_creds=db_creds_all[db_type], question=question, query_category=query_category, - decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) @@ -89,7 +104,7 @@ def process_row(row, model_name, args): except Exception as e: row["error_db_exec"] = 1 row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - + return row except Exception as e: row["error_query_gen"] = 1 @@ -141,7 +156,7 @@ def run_gemini_eval(args): ), axis=1, ) - + output_rows, total_correct, total_tried = run_eval_in_threadpool( df, args.model, process_row, args ) @@ -168,7 +183,7 @@ def run_gemini_eval(args): output_dir = os.path.dirname(output_file) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - + output_df.to_csv(output_file, index=False, float_format="%.2f") # Print summary stats @@ -178,7 +193,7 @@ def run_gemini_eval(args): # Upload results if URL provided try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: with open(prompt_file, "r") as f: prompt = f.read() upload_results( @@ -189,4 +204,4 @@ def run_gemini_eval(args): args=args, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") diff --git a/runners/hf_runner.py b/runners/hf_runner.py index ada7f9d..1e601e9 100644 --- a/runners/hf_runner.py +++ b/runners/hf_runner.py @@ -19,6 +19,7 @@ device_map = "mps" if torch.backends.mps.is_available() else "auto" + def get_tokenizer_model(model_name: Optional[str], adapter_path: Optional[str]): """ Load a HuggingFace tokenizer and model. 
@@ -136,7 +137,7 @@ def run_hf_eval(args): row.get("cot_instructions", ""), row.get("cot_pregen", False), public_data, - args.num_columns if hasattr(args, 'num_columns') else 40, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, row.get("table_aliases", ""), ), @@ -205,11 +206,17 @@ def chunk_dataframe(df, chunk_size): question=question, query_category=query_category, table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + decimal_points=( + args.decimal_points + if hasattr(args, "decimal_points") + else 2 + ), ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) - row["is_correct"] = int(correct) # For base runner compatibility + row["is_correct"] = int( + correct + ) # For base runner compatibility row["error_msg"] = "" if correct: total_correct += 1 @@ -248,7 +255,7 @@ def chunk_dataframe(df, chunk_size): output_dir = os.path.dirname(output_file) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - + output_df.to_csv(output_file, index=False, float_format="%.2f") # Print summary stats @@ -258,7 +265,7 @@ def chunk_dataframe(df, chunk_size): # Upload results if URL provided try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: with open(prompt_file, "r") as f: prompt = f.read() upload_results( @@ -269,4 +276,4 @@ def chunk_dataframe(df, chunk_size): args=args, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") diff --git a/runners/llama_cpp_runner.py b/runners/llama_cpp_runner.py index c2babc4..16131a7 100644 --- a/runners/llama_cpp_runner.py +++ b/runners/llama_cpp_runner.py @@ -25,13 +25,15 @@ def process_row(llm, row, args): repeat_penalty=1.0, ) # Llama.cpp-specific SQL extraction - generated_query = response["choices"][0]["text"].split(";")[0].split("```")[0].strip() + ";" + generated_query = ( + response["choices"][0]["text"].split(";")[0].split("```")[0].strip() + ";" + ) end_time = time() # Store results row["generated_query"] = generated_query row["latency_seconds"] = end_time - start_time - + # Verify results golden_query = row["query"] db_name = row["db_name"] @@ -39,7 +41,7 @@ def process_row(llm, row, args): question = row["question"] query_category = row["query_category"] table_metadata_string = row["table_metadata_string"] - + try: exact_match, correct = compare_query_results( query_gold=golden_query, @@ -50,7 +52,9 @@ def process_row(llm, row, args): question=question, query_category=query_category, table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) @@ -59,7 +63,7 @@ def process_row(llm, row, args): except Exception as e: row["error_db_exec"] = 1 row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - + return row except Exception as e: row["error_query_gen"] = 1 @@ -116,7 +120,7 @@ def run_llama_cpp_eval(args): row.get("cot_instructions", ""), row.get("cot_pregen", False), public_data, - args.num_columns if hasattr(args, 'num_columns') else 40, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, row.get("table_aliases", ""), ), @@ -162,7 +166,7 @@ def run_llama_cpp_eval(args): output_dir = os.path.dirname(output_file) if output_dir and 
not os.path.exists(output_dir): os.makedirs(output_dir) - + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: @@ -175,7 +179,7 @@ def run_llama_cpp_eval(args): # Upload results if URL provided try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: with open(prompt_file, "r") as f: prompt = f.read() upload_results( @@ -186,4 +190,4 @@ def run_llama_cpp_eval(args): args=args, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") diff --git a/runners/mistral_runner.py b/runners/mistral_runner.py index 3063021..97b287e 100644 --- a/runners/mistral_runner.py +++ b/runners/mistral_runner.py @@ -5,7 +5,11 @@ from mistralai.client import MistralClient from mistralai.models.chat_completion import ChatMessage -from runners.base_runner import generate_base_prompt, extract_sql_from_response, run_eval_in_threadpool +from runners.base_runner import ( + generate_base_prompt, + extract_sql_from_response, + run_eval_in_threadpool, +) from utils.questions import prepare_questions_df from utils.creds import db_creds_all from utils.reporting import upload_results @@ -32,9 +36,18 @@ def generate_prompt( """Mistral-specific prompt handling with System/User format""" # Get base prompt data base_data = generate_base_prompt( - prompt_file, question, db_name, db_type, instructions, - k_shot_prompt, glossary, table_metadata_string, - prev_invalid_sql, prev_error_msg, public_data, shuffle + prompt_file, + question, + db_name, + db_type, + instructions, + k_shot_prompt, + glossary, + table_metadata_string, + prev_invalid_sql, + prev_error_msg, + public_data, + shuffle, ) # Load and parse Mistral-specific prompt format @@ -44,7 +57,7 @@ def generate_prompt( # Check that System and User prompts are in the prompt file if "System:" not in prompt or "User:" not in prompt: raise ValueError("Invalid prompt file. 
Please use prompt_mistral.md") - + sys_prompt = prompt.split("System:")[1].split("User:")[0].strip() user_prompt = prompt.split("User:")[1].strip() @@ -58,7 +71,7 @@ def generate_prompt( prev_invalid_sql=prev_invalid_sql, prev_error_msg=prev_error_msg, ) - + # Return Mistral-specific message format return [ ChatMessage(role="system", content=sys_prompt), @@ -83,14 +96,16 @@ def process_row(row, model, args): try: generated_query = generated_query.replace("\\", "") generated_query = generated_query.split(";")[0].split("```sql")[-1].strip() - generated_query = [i for i in generated_query.split("```") if i.strip() != ""][0] + ";" + generated_query = [ + i for i in generated_query.split("```") if i.strip() != "" + ][0] + ";" except Exception as e: print(e) generated_query = chat_response.choices[0].message.content - + row["generated_query"] = generated_query row["latency_seconds"] = end_time - start_time - + # Verify results golden_query = row["query"] db_name = row["db_name"] @@ -98,7 +113,7 @@ def process_row(row, model, args): question = row["question"] query_category = row["query_category"] table_metadata_string = row["table_metadata_string"] - + try: exact_match, correct = compare_query_results( query_gold=golden_query, @@ -109,7 +124,9 @@ def process_row(row, model, args): question=question, query_category=query_category, table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) @@ -118,7 +135,7 @@ def process_row(row, model, args): except Exception as e: row["error_db_exec"] = 1 row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - + return row except Exception as e: row["error_query_gen"] = 1 @@ -169,7 +186,7 @@ def run_mistral_eval(args): ), axis=1, ) - + output_rows, total_correct, total_tried = run_eval_in_threadpool( df, args.model, process_row, args ) @@ -196,7 +213,7 @@ def run_mistral_eval(args): output_dir = os.path.dirname(output_file) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - + output_df.to_csv(output_file, index=False, float_format="%.2f") # Print summary stats @@ -206,7 +223,7 @@ def run_mistral_eval(args): # Upload results if URL provided try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: with open(prompt_file, "r") as f: prompt = f.read() upload_results( @@ -217,4 +234,4 @@ def run_mistral_eval(args): args=args, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") diff --git a/runners/mlx_runner.py b/runners/mlx_runner.py index 762175b..0f1601a 100644 --- a/runners/mlx_runner.py +++ b/runners/mlx_runner.py @@ -16,16 +16,18 @@ def process_row(model, tokenizer, row, args): start_time = time() try: prompt = row["prompt"] - + # MLX-specific generation - generated_text = generate(model, tokenizer, prompt=prompt, max_tokens=512, temp=0, verbose=True) + generated_text = generate( + model, tokenizer, prompt=prompt, max_tokens=512, temp=0, verbose=True + ) generated_query = generated_text.split(";")[0].split("```")[0].strip() + ";" end_time = time() # Store results row["generated_query"] = generated_query row["latency_seconds"] = end_time - start_time - + # Verify results golden_query = row["query"] db_name = row["db_name"] @@ -33,7 +35,7 @@ def process_row(model, tokenizer, row, args): 
question = row["question"] query_category = row["query_category"] table_metadata_string = row["table_metadata_string"] - + try: exact_match, correct = compare_query_results( query_gold=golden_query, @@ -44,7 +46,9 @@ def process_row(model, tokenizer, row, args): question=question, query_category=query_category, table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) @@ -53,7 +57,7 @@ def process_row(model, tokenizer, row, args): except Exception as e: row["error_db_exec"] = 1 row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - + return row except Exception as e: row["error_query_gen"] = 1 @@ -110,7 +114,7 @@ def run_mlx_eval(args): row.get("cot_instructions", ""), row.get("cot_pregen", False), public_data, - args.num_columns if hasattr(args, 'num_columns') else 40, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, row.get("table_aliases", ""), ), @@ -156,7 +160,7 @@ def run_mlx_eval(args): output_dir = os.path.dirname(output_file) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: @@ -169,7 +173,7 @@ def run_mlx_eval(args): # Upload results if URL provided try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: with open(prompt_file, "r") as f: prompt = f.read() upload_results( @@ -180,4 +184,4 @@ def run_mlx_eval(args): args=args, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") diff --git a/runners/openai_runner.py b/runners/openai_runner.py index b89e54a..862613a 100644 --- a/runners/openai_runner.py +++ b/runners/openai_runner.py @@ -3,7 +3,11 @@ import os import pandas as pd -from runners.base_runner import generate_base_prompt, extract_sql_from_response, run_eval_in_threadpool +from runners.base_runner import ( + generate_base_prompt, + extract_sql_from_response, + run_eval_in_threadpool, +) from utils.questions import prepare_questions_df from utils.llm import chat_openai from utils.creds import db_creds_all @@ -28,9 +32,18 @@ def generate_prompt( """OpenAI-specific prompt handling""" # Get base prompt data base_data = generate_base_prompt( - prompt_file, question, db_name, db_type, instructions, - k_shot_prompt, glossary, table_metadata_string, - prev_invalid_sql, prev_error_msg, public_data, shuffle + prompt_file, + question, + db_name, + db_type, + instructions, + k_shot_prompt, + glossary, + table_metadata_string, + prev_invalid_sql, + prev_error_msg, + public_data, + shuffle, ) # Load and format OpenAI-specific JSON prompt @@ -38,7 +51,7 @@ def generate_prompt( prompt = json.load(f) pruned_metadata_str = base_data["table_metadata_string"] - + if prompt[0]["role"] == "system": prompt[0]["content"] = prompt[0]["content"].format( db_type=db_type, @@ -80,7 +93,7 @@ def process_row(row, model_name, args): try: response = chat_openai(messages=messages, model=model_name, temperature=0.0) generated_query = extract_sql_from_response(response.content) - + result = { "generated_query": generated_query, "reason": "", @@ -88,7 +101,7 @@ def process_row(row, model_name, args): "latency_seconds": time() - start_time, "tokens_used": response.input_tokens + response.output_tokens, } - + # Verify 
results expected_query = row["query"] db_name = row["db_name"] @@ -102,7 +115,9 @@ def process_row(row, model_name, args): db_creds=db_creds_all[db_type], question=row["question"], query_category=row["query_category"], - decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), ) if is_correct: row["is_correct"] = 1 @@ -112,7 +127,7 @@ def process_row(row, model_name, args): except Exception as e: row["error_db_exec"] = 1 result["error_msg"] = f"EXECUTION ERROR: {str(e)}" - + # Update row with result data row.update(result) return row @@ -147,7 +162,7 @@ def run_openai_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - + output_rows, total_correct, total_tried = run_eval_in_threadpool( df, args.model, process_row, args ) @@ -174,7 +189,7 @@ def run_openai_eval(args): output_dir = os.path.dirname(output_file) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - + output_df.to_csv(output_file, index=False, float_format="%.2f") # Print summary stats @@ -184,7 +199,7 @@ def run_openai_eval(args): # Upload results if URL provided try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: with open(prompt_file, "r") as f: prompt = f.read() upload_results( @@ -195,4 +210,4 @@ def run_openai_eval(args): args=args, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") diff --git a/runners/together_runner.py b/runners/together_runner.py index 826408f..35a22fb 100644 --- a/runners/together_runner.py +++ b/runners/together_runner.py @@ -15,6 +15,7 @@ client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) + def process_row(row: Dict, model: str): """Process a single row using Together""" start_time = time() @@ -23,9 +24,11 @@ def process_row(row: Dict, model: str): if model.startswith("meta-llama"): stop = ["<|eot_id|>", "<|eom_id|>"] else: - print("Undefined stop token(s). Please specify the stop token(s) for the model.") + print( + "Undefined stop token(s). Please specify the stop token(s) for the model." + ) stop = [] - + messages = row["prompt"] response = client.chat.completions.create( model=model, @@ -44,7 +47,7 @@ def process_row(row: Dict, model: str): row["generated_query"] = generated_query row["latency_seconds"] = end_time - start_time row["tokens_used"] = None # Together doesn't provide token count - + # Verify results golden_query = row["query"] db_name = row["db_name"] @@ -52,7 +55,7 @@ def process_row(row: Dict, model: str): question = row["question"] query_category = row["query_category"] table_metadata_string = row["table_metadata_string"] - + try: exact_match, correct = compare_query_results( query_gold=golden_query, @@ -71,7 +74,7 @@ def process_row(row: Dict, model: str): except Exception as e: row["error_db_exec"] = 1 row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - + return row except Exception as e: row["error_query_gen"] = 1 @@ -99,7 +102,7 @@ def run_together_eval(args): # Together-specific JSON validation if not prompt_file.endswith(".json"): raise ValueError(f"Prompt file must be a JSON file. 
Got {prompt_file}") - + print(f"Using prompt file {prompt_file}") print("Preparing questions...") print( @@ -129,13 +132,13 @@ def run_together_eval(args): row.get("cot_instructions", ""), row.get("cot_pregen", False), public_data, - args.num_columns if hasattr(args, 'num_columns') else 40, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, row.get("table_aliases", ""), ), axis=1, ) - + output_rows, total_correct, total_tried = run_eval_in_threadpool( df, args.model, process_row, args ) @@ -162,7 +165,7 @@ def run_together_eval(args): output_dir = os.path.dirname(output_file) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: @@ -175,7 +178,7 @@ def run_together_eval(args): # Upload results if URL provided try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: with open(prompt_file, "r") as f: prompt = f.read() upload_results( @@ -186,4 +189,4 @@ def run_together_eval(args): args=args, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") diff --git a/runners/vllm_runner.py b/runners/vllm_runner.py index ba94207..57bd960 100644 --- a/runners/vllm_runner.py +++ b/runners/vllm_runner.py @@ -29,7 +29,7 @@ def run_vllm_eval(args): k_shot = args.k_shot db_type = args.db_type cot_table_alias = args.cot_table_alias - + # VLLM-specific LoRA handling enable_lora = True if args.adapter else False lora_request = LoRARequest("sql_adapter", 1, args.adapter) if args.adapter else None @@ -38,7 +38,7 @@ def run_vllm_eval(args): print(f"Preparing {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token_id = tokenizer.eos_token_id - + # VLLM-specific model initialization if not args.quantized: llm = LLM( @@ -99,7 +99,7 @@ def run_vllm_eval(args): row.get("cot_instructions", ""), row.get("cot_pregen", False), public_data, - args.num_columns if hasattr(args, 'num_columns') else 40, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, row.get("table_aliases", ""), ), @@ -112,7 +112,9 @@ def chunk_dataframe(df, chunk_size): df_chunks = [] for i in range(0, len(df), chunk_size): df_i = df.iloc[i : min(i + chunk_size, len(df))] - print(f"Chunk {i//chunk_size+1}/{len(df)//chunk_size+1} with {len(df_i)} questions") + print( + f"Chunk {i//chunk_size+1}/{len(df)//chunk_size+1} with {len(df_i)} questions" + ) df_chunks.append(df_i) return df_chunks @@ -126,7 +128,7 @@ def chunk_dataframe(df, chunk_size): for batch in (pbar := tqdm(df_chunks, total=len(df))): prompts = batch["prompt"].tolist() print(f"Generating completions for {len(prompts)} prompts") - + # VLLM-specific token handling prompt_tokens = [] prompt_token_sizes = [] @@ -136,8 +138,10 @@ def chunk_dataframe(df, chunk_size): token_ids = [tokenizer.bos_token_id] + token_ids prompt_tokens.append(token_ids) prompt_token_sizes.append(len(token_ids)) - print(f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}") - + print( + f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}" + ) + start_time = time.time() outputs = llm.generate( sampling_params=sampling_params, @@ -145,11 +149,15 @@ def chunk_dataframe(df, chunk_size): use_tqdm=False, lora_request=lora_request, ) - print(f"Generated {len(outputs)} completions in {time.time() - start_time:.2f} seconds") + print( + f"Generated 
{len(outputs)} completions in {time.time() - start_time:.2f} seconds" + ) time_taken = time.time() - start_time - + for row, output in zip(batch.to_dict("records"), outputs): - generated_query = output.outputs[0].text.split(";")[0].split("```")[0].strip() + ";" + generated_query = ( + output.outputs[0].text.split(";")[0].split("```")[0].strip() + ";" + ) normalized_query = sqlparse.format( generated_query, keyword_case="upper", strip_whitespace=True ) @@ -164,7 +172,7 @@ def chunk_dataframe(df, chunk_size): question = row["question"] query_category = row["query_category"] table_metadata_string = row["table_metadata_string"] - + try: exact_match, correct = compare_query_results( query_gold=golden_query, @@ -175,7 +183,11 @@ def chunk_dataframe(df, chunk_size): question=question, query_category=query_category, table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + decimal_points=( + args.decimal_points + if hasattr(args, "decimal_points") + else 2 + ), ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) @@ -189,15 +201,17 @@ def chunk_dataframe(df, chunk_size): total_tried += 1 output_rows.append(row) - + pbar.update(len(batch)) - pbar.set_description(f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)") + pbar.set_description( + f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" + ) # Process results df = pd.DataFrame(output_rows) if "prompt" in df.columns: del df["prompt"] - + # Get stats by query category agg_stats = df.groupby("query_category")[["exact_match", "correct"]].mean() print(agg_stats) @@ -208,7 +222,7 @@ def chunk_dataframe(df, chunk_size): output_dir = os.path.dirname(output_file) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - + df.to_csv(output_file, index=False, float_format="%.2f") print(f"Saved results to {output_file}") @@ -219,7 +233,7 @@ def chunk_dataframe(df, chunk_size): # Upload results if URL provided try: - if hasattr(args, 'upload_url') and args.upload_url: + if hasattr(args, "upload_url") and args.upload_url: with open(prompt_file, "r") as f: prompt = f.read() upload_results( @@ -230,4 +244,4 @@ def chunk_dataframe(df, chunk_size): args=args, ) except Exception as e: - print(f"Error uploading results: {e}") \ No newline at end of file + print(f"Error uploading results: {e}") From 96bf3317ffba2e264cfe27c2512e94b54bf43e7b Mon Sep 17 00:00:00 2001 From: codestory Date: Wed, 29 Jan 2025 19:01:22 +0000 Subject: [PATCH 3/3] feat: sync local changes --- runners/anthropic_runner.py | 151 +++++++++++++++++++++++------------- utils/questions.py | 30 +++---- 2 files changed, 113 insertions(+), 68 deletions(-) diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py index 264f2e1..77b509b 100644 --- a/runners/anthropic_runner.py +++ b/runners/anthropic_runner.py @@ -1,11 +1,12 @@ -import os from time import time from concurrent.futures import ThreadPoolExecutor, as_completed +import os import pandas as pd import sqlparse from tqdm import tqdm +from runners.base_runner import run_eval_in_threadpool from eval.eval import compare_query_results from utils.creds import db_creds_all from utils.dialects import convert_postgres_ddl_to_dialect @@ -29,61 +30,79 @@ def generate_prompt( public_data=True, shuffle=True, ): - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - from defog_data_private.metadata import dbs - import 
defog_data_private.supplementary as sup - - with open(prompt_file, "r") as f: - prompt = f.read() - - if table_metadata_string == "": - md = dbs[db_name]["table_metadata"] - pruned_metadata_ddl = to_prompt_schema(md, shuffle) - pruned_metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=pruned_metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair + try: + if public_data: + from defog_data.metadata import dbs + import defog_data.supplementary as sup + else: + from defog_data_private.metadata import dbs + import defog_data_private.supplementary as sup + + with open(prompt_file, "r") as f: + prompt = f.read() + + if table_metadata_string == "": + md = dbs[db_name]["table_metadata"] + pruned_metadata_ddl = to_prompt_schema(md, shuffle) + pruned_metadata_ddl = convert_postgres_ddl_to_dialect( + postgres_ddl=pruned_metadata_ddl, + to_dialect=db_type, + db_name=db_name, + ) + column_join = sup.columns_join.get(db_name, {}) + join_list = [] + for values in column_join.values(): + if isinstance(values[0], tuple): + for col_pair in values: + col_1, col_2 = col_pair + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + else: + col_1, col_2 = values[0] join_str = f"{col_1} can be joined with {col_2}" if join_str not in join_list: join_list.append(join_str) + if len(join_list) > 0: + join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) else: - col_1, col_2 = values[0] - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) + join_str = "" + pruned_metadata_str = pruned_metadata_ddl + join_str else: - join_str = "" - pruned_metadata_str = pruned_metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string - - prompt = prompt.format( - user_question=question, - db_type=db_type, - instructions=instructions, - table_metadata_string=pruned_metadata_str, - k_shot_prompt=k_shot_prompt, - glossary=glossary, - prev_invalid_sql=prev_invalid_sql, - prev_error_msg=prev_error_msg, - ) - return prompt + pruned_metadata_str = table_metadata_string + + prompt = prompt.format( + user_question=question, + db_type=db_type, + instructions=instructions, + table_metadata_string=pruned_metadata_str, + k_shot_prompt=k_shot_prompt, + glossary=glossary, + prev_invalid_sql=prev_invalid_sql, + prev_error_msg=prev_error_msg, + ) + return prompt + except ImportError: + # When defog_data is not available, just format with the existing table_metadata_string + with open(prompt_file, "r") as f: + prompt = f.read() + + prompt = prompt.format( + user_question=question, + db_type=db_type, + instructions=instructions, + table_metadata_string=table_metadata_string, + k_shot_prompt=k_shot_prompt, + glossary=glossary, + prev_invalid_sql=prev_invalid_sql, + prev_error_msg=prev_error_msg, + ) + return prompt def process_row(row, model_name, args): start_time = time() + result_row = row.copy() # Create a copy of the original row to maintain all data prompt = generate_prompt( prompt_file=args.prompt_file[0], question=row["question"], @@ -110,21 +129,43 @@ def process_row(row, model_name, args): ) except: pass - return { - "query": generated_query, + result_row.update({ + "generated_query": 
generated_query, "reason": "", - "err": "", + "error_msg": "", "latency_seconds": time() - start_time, "tokens_used": response.input_tokens + response.output_tokens, - } + }) + + # Verify the generated query + try: + exact_match, correct = compare_query_results( + query_gold=row["query"], + query_gen=generated_query, + db_name=row["db_name"], + db_type=args.db_type, + db_creds=db_creds_all[args.db_type], + question=row["question"], + query_category=row["query_category"], + decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + ) + result_row["exact_match"] = int(exact_match) + result_row["correct"] = int(correct) + result_row["is_correct"] = int(correct) + except Exception as e: + result_row["error_db_exec"] = 1 + result_row["error_msg"] = f"EXECUTION ERROR: {str(e)}" + result_row["is_correct"] = 0 except Exception as e: - return { - "query": "", + result_row.update({ + "generated_query": "", "reason": "", - "err": f"GENERATION ERROR: {str(e)}", + "error_msg": f"GENERATION ERROR: {str(e)}", "latency_seconds": time() - start_time, "tokens_used": 0, - } + "is_correct": 0, + }) + return result_row def run_anthropic_eval(args): diff --git a/utils/questions.py b/utils/questions.py index 414a89d..bc89354 100644 --- a/utils/questions.py +++ b/utils/questions.py @@ -1,19 +1,23 @@ -from typing import Optional import pandas as pd +from typing import Optional def get_table_aliases(db_name: str) -> str: - from defog_data.metadata import dbs - from utils.aliases import generate_aliases - - metadata = dbs[db_name]["table_metadata"] - table_names = list(metadata.keys()) - aliases = generate_aliases(table_names) - aliases_instruction = ( - "Use the following table aliases when referencing tables in the query:\n" - + aliases - ) - return aliases_instruction + try: + from defog_data.metadata import dbs + from utils.aliases import generate_aliases + + metadata = dbs[db_name]["table_metadata"] + table_names = list(metadata.keys()) + aliases = generate_aliases(table_names) + aliases_instruction = ( + "Use the following table aliases when referencing tables in the query:\n" + + aliases + ) + return aliases_instruction + except ImportError: + # Return empty string when defog_data is not available + return "" def prepare_questions_df( @@ -142,4 +146,4 @@ def prepare_questions_df( elif cot_table_alias == "pregen": question_query_df["cot_pregen"] = True - return question_query_df + return question_query_df \ No newline at end of file
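
Note on the shared helpers: every runner touched in this change now delegates its per-question fan-out to runners/base_runner.py (generate_base_prompt, extract_sql_from_response, run_eval_in_threadpool). The base module's body is not part of these hunks, so the following is only a minimal sketch of what the call sites appear to assume -- a thread pool over the question rows that tallies is_correct and returns (output_rows, total_correct, total_tried) -- not the committed implementation; the process_row(row, model, args) submission signature is likewise an assumption inferred from the runners above.

# Illustrative sketch only; see note above. Assumes args.parallel_threads exists
# and that each process_row returns its row dict with an "is_correct" field set.
from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm import tqdm


def run_eval_in_threadpool(df, model_name, process_row, args):
    """Run process_row over every question row in parallel and track accuracy."""
    input_rows = df.to_dict("records")
    output_rows = []
    total_tried = 0
    total_correct = 0
    with ThreadPoolExecutor(args.parallel_threads) as executor:
        # Submit one future per question row; the worker does generation + checking.
        futures = [
            executor.submit(process_row, row, model_name, args) for row in input_rows
        ]
        for f in (pbar := tqdm(as_completed(futures), total=len(futures))):
            row = f.result()
            total_tried += 1
            if row.get("is_correct"):
                total_correct += 1
            output_rows.append(row)
            pbar.set_description(
                f"Correct so far: {total_correct}/{total_tried} "
                f"({100 * total_correct / total_tried:.2f}%)"
            )
    return output_rows, total_correct, total_tried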