diff --git a/.env b/.env new file mode 100644 index 0000000..0605d28 --- /dev/null +++ b/.env @@ -0,0 +1,2 @@ +# Store API keys here +API_KEYS=apikey1,apikey2,apikey3 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 8d970b7..7bcaa84 100644 --- a/.gitignore +++ b/.gitignore @@ -6,5 +6,9 @@ __pycache__/ checkpoints/ results/ wandb/ +outputs/ -*.log \ No newline at end of file +*.log +*.egg-info +.vscode +build/ \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..9366673 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://github.com/psf/black + rev: 22.10.0 + hooks: + - id: black + language_version: python3.10 \ No newline at end of file diff --git a/README.md b/README.md index 975fca7..9f05343 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Its main features are as follows: -## Prerequisites +## TODO: Prerequisites - Your emails, exported to `mbox` format (see tutorial below). - A computer, preferably with a NVIDIA GPU with at least 24 GiB of memory (alternatively, check out [running in Google Colab](#cloud-try-out-panza-in-google-colab)). - A Hugging Face [account](https://huggingface.co/login) to download the models (free of charge). @@ -62,30 +62,21 @@ The overall structure of Panza is as follows: ### Conda 1. Make sure you have a version of [conda](https://docs.anaconda.com/free/miniconda/miniconda-install/) installed. -2. Run `source prepare_env.sh`. This script will create a conda environment named `panza` and install the required packages. - -### Docker -As an alternative to the conda option above, you can run the following commands to pull a docker image with all the dependencies installed. -``` -docker pull istdaslab/panzamail -``` - -or alternatively, you can build the image yourself: -``` -docker build . -f Dockerfile -t istdaslab/panzamail -``` - -Then run it with: -``` -docker run -it --gpus all istdaslab/panzamail /bin/bash +2. Create a new conda environment named 'panza' (or something else) and activate it: +``` bash +conda create -n panza python=3.10 -y +conda activate panza ``` - -In the docker you can activate the `panza` environment with: +3. Install the required packages: +``` bash +pip install . ``` -micromamba activate panza +4. If you want to also finetune models using Panza, you will need to install the additional packages: +``` bash +pip install .[training] ``` -## :rocket: Getting started +## TODO: :rocket: Getting started To quickly get started with building your own personalized email assistant, follow the steps bellow: @@ -118,16 +109,26 @@ At the end of this step you should have the downloaded emails placed inside `dat ### Step 1: Environment configuration -Panza is configured through a set of environment variables defined in `scripts/config.sh` and shared along all running scripts. +Panza is configured through a set of yaml configurations defined in `configs/`. There is a single high-level config under `configs/base.yaml`, and the rest are organized under the main functionalities of the code. +Note that these task-specific configs can, in some cases, be used to override base configs. + Specific use cases, such as hyperparameter tuning, are covered in more detail in `scripts/README.md`. (TODO jen: write this up.) - -The LLM prompt is controlled by a set of `prompt_preambles` that give the model more insight about its role, the user and how to reuse existing emails for *Retrieval-Augmented Generation (RAG)*. 
See more details in the [prompting section](prompt_preambles/README.md). +1. Data preparation: `configs/data_preparation.yaml`. Additionally, a custom user config must be added under `config/users/` (see below). +1. Finetuning: the main config is in `configs/panza_finetuning.yaml` and the method-specific ones are in `configs/finetuning/` +1. Serving: Serving consists of two parts - a serving infrastructure (that we call 'writer') that runs the LLM and so converts prompts to Panza outputs, and an `interface`, which presents the outputs in a useful form - through a command-line interface, a web interface, a gmail client (TODO:Sean), or in a bulk `.json` format (useful for evaluation). The configs for serving are in `panza_writer.yaml`, and for the interfaces, under `configs/interfaces`. + +These scripts are described in more detail in `scripts/README.md`, but a few customizations need to happen immediately. :warning: Before continuing, make sure you complete the following setup: - - Modifiy the environment variable `PANZA_EMAIL_ADDRESS` inside `scripts/config.sh` with your own email address. - - Modifiy `prompt_preambles/user_preamble.txt` with your own information. If you choose, this can even be empty. +- Copy `users/default.yaml` to `users/[YOURNAME].yaml`. If this is skipped, perform the following modifications on `users/default.yaml` directly. A useful tip for choosing the name of `[YOURNAME]` is to set it to the output of `whoami`. If you modify the default yaml, you will need specify `user=default` as an extra flag in the succeeding steps. +- In the user config, set the email address and username. The email address should be the sender address in the exported emails. (Panza uses this to edit out responses and other emails sent by a different author in the `.mbox` dump.). The username does not have to link to the email itself - it is simply used as a name for the various data files that will come out of the data preparation process. A handy way to set this is if you set it to be the output of the `whoami` call in your shell. +- Modify the personal prompt in `prompt_preambles/user_preamble.txt` to include some basic information about yourself that Panza can use to customize your emails with your correct full name, address, phone number, etc. + + +Additionally, please perform the following login steps to be able to download the base model. - Login to Hugging Face to be able to download pretrained models: `huggingface-cli login`. - - [Optional] Login to Weights & Biases to log metrics during training: `wandb login`. Then, set `PANZA_WANDB_DISABLED=False` in `scripts/config.sh`. + - [Optional] Login to Weights & Biases to log metrics during training: `wandb login`. Then, set `wandb_disabled=false` in `configs/finetuning/base.yaml`. + You are now ready to move to `scripts`. ``` bash @@ -137,62 +138,73 @@ cd scripts ### Step 2: Extract emails -1. Run `./extract_emails.sh`. This extracts your emails in text format to `data/_clean.jsonl` which you can manually inspect. - -2. If you wish to eliminate any emails from the training set (e.g. containing certain personal information), you can simply remove the corresponding rows. - -### Step 3: Prepare dataset - - -1. Simply run `./prepare_dataset.sh`.
+1. Run `CUDA_VISIBLE_DEVICES=X ./prepare_data.sh`.
This script takes care of all the prerequisites before training (expand for details).
 - Extracts your emails in text format to `data/_clean.jsonl`, which you can manually inspect.
 - Creates synthetic prompts for your emails as described in the [data playback](#film_projector-step-1-data-playback) section. The results are stored in `data/_clean_summarized.jsonl` and you can inspect the `"summary"` field.
 - Splits the data into training and test subsets. See `data/train.jsonl` and `data/test.jsonl`.
 - Creates a vector database from the embeddings of the training emails, which will later be used for *Retrieval-Augmented Generation (RAG)*. See `data/.pkl` and `data/.faiss`.
+**NB**: if you did not change the default configuration in `user/default.yaml` to reflect your particulars, but instead created a new file, you need to pass an additional flag to the above command: `user=x`, where `x.yaml` is the name of your config file.
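For example, with a hypothetical user config saved as `configs/user/jane.yaml` (the config name and GPU index below are placeholders), the call becomes:

``` bash
# 'user=jane' tells the script to load configs/user/jane.yaml instead of the default user config
CUDA_VISIBLE_DEVICES=0 ./prepare_data.sh user=jane
```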
+ FAQs. + When running the above script, you may encounter an OutOfMemoryError. If this is the case, you can either: +
    +
  1. Reduce the batch size for the data processing step. This can be set in `configs/panza_preparation.yaml` (see the example after this list).
  2. Move to a machine that has more memory. +
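As a sketch of option 1, the `batch_size` key in `configs/panza_preparation.yaml` (8 by default) can also be overridden on the command line; the right value depends on your GPU memory, and the GPU index and value below are placeholders:

``` bash
# Use a smaller summarization batch size to reduce peak GPU memory usage
CUDA_VISIBLE_DEVICES=0 ./prepare_data.sh batch_size=4
```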
+
+ -### Step 4: Train a LLM on your emails - +### Step 3: Train a LLM on your emails + We currently support `LLaMA3-8B-Instruct` and `Mistral-Instruct-v0.2` LLMs as base models; the former is the default, but we obtained good results with either model. 1. [Recommended] For parameter efficient fine-tuning, run `./train_rosa.sh`. If a larger GPU is available and full-parameter fine-tuning is possible, run `./train_fft.sh`. -2. We have prepopulated the training scripts with parameter values that worked best for us. We recommend you try those first, but you can also experiment with different hyper-parameters by passing extra arguments to the training script, such as `LR`, `LORA_LR`, `NUM_EPOCHS`. All the trained models are saved in the `checkpoints` directory. +2. We have prepopulated the training configs with parameter values that worked best for us. We recommend you try those first, but you can also experiment with different hyper-parameters by passing extra arguments to the training script, such as `lr`, `lora_lr`, `num_epochs`. All the trained models are saved in the `checkpoints` directory. Examples: ``` bash -./train_rosa.sh # Will use the default parameters. +CUDA_VISIBLE_DEVICES=X ./train_rosa.sh # Will use the default parameters. -./train_rosa.sh LR=1e-6 LORA_LR=1e-6 NUM_EPOCHS=7 # Will override LR, LORA_LR, and NUM_EPOCHS. +CUDA_VISIBLE_DEVICES=X ./train_rosa.sh finetuning.lr=1e-6 finetuning.rosa_lr=1e-6 finetuning.max_duration=7ep ``` +
+ FAQs. + The bash scripts that are used to execute the finetuning procedure assume by default that your username is what is returned by the whoami command. This is used to locate the name of the user configs inside the configs/user directory as above. If you directly modified default.yaml, or created another yaml file where the name of that file does not match with the output of whoami, there will be an error. This is an easy fix. You can either: +
    +
  1. Change the name of the yaml file to be the output of `whoami`.
  2. You can override the username manually when you launch the bash script by adding `user=x`, where `x` is the name of the yaml file you created. For example: `./train_rosa.sh user=alonso`.
+
If you wish to specify a particular GPU, add `export CUDA_VISIBLE_DEVICES=x` directly in the shell script, where `x` is the ID of the GPU you wish to use.

A known issue is that when you fine-tune your model with RAG, the tokenization of the dataset can appear to hang. This is due to a known bug in HF's `map` function when `num_proc>1`. To work around it, you can set `torch.set_num_threads(1)` in `src/panza/finetuning/train.py`, or set the equivalent parameter in `configs/finetuning/rosa.yaml`.
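If you prefer the config route, a minimal launch-time override might look like this (assuming the `num_cpu_threads` key in `configs/finetuning/rosa.yaml` is the parameter referred to above; the GPU index is a placeholder):

``` bash
# Force single-threaded tokenization to work around the HF map() hang
CUDA_VISIBLE_DEVICES=0 ./train_rosa.sh finetuning.num_cpu_threads=1
```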
-### Step 5: Launch Panza! - - -1. Run `./run_panza_gui.sh MODEL=` to serve the trained model in a friendly GUI. -Alternatively, if you prefer using the CLI to interact with Panza, run `./run_panza_cli.sh` instead. -You can experiment with the following arguments: -- If `MODEL` is not specified, it will use a pretrained `Meta-Llama-3-8B-Instruct` model by default, although Panza also works with `Mistral-7B-Instruct-v2`. Try it out to compare the syle difference! -- To disable RAG, run with `PANZA_DISABLE_RAG_INFERENCE=1`. +On a smaller GPU, it may be necessary to further train in lower precision (QRoSA). This can be run as follows: -Example: ``` bash -./run_panza_gui.sh \ - MODEL=/local/path/to/this/repo/checkpoints/models/panza-rosa_1e-6-seed42_7908 \ - PANZA_DISABLE_RAG_INFERENCE=0 # this is the default behaviour, so you can omit it +./train_rosa.sh finetuning.precision=amp_bf16 finetuning.model.weight_bias_dtype=4bit ``` -:email: **Have fun with your new email writing assistant!** :email: - +### Step 5: Launch Panza! + -## :cloud: Try out Panza in Google Colab +- To run Panza after a full training run, run a command like `CUDA_VISIBLE_DEVICES=0 ./runner.sh user=USERNAME interfaces=cli writer/llm=transformers model=latest`. +- To run Panza after a RoSA or LoRA training run, replace `writer/llm=transformers` with `writer/llm=peft` -- You can run Panza in a Google Colab instance [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/IST-DASLab/PanzaMail/blob/main/notebooks/panza_colab.ipynb). + +:email: **Have fun with your new email writing assistant!** :email: + + ## :microscope: Advanced usage @@ -200,11 +212,28 @@ Example: - [Hyper-Parameter Tuning Guide](./scripts/README.md#hyper-parameter-tuning-guide) - [Prompt Preambles Tutorial](prompt_preambles/README.md) +## :woman_technologist: Contributing +If you liked our work and want to contribute to improve the system, please feel free to do so! Make a _fork_ of our repository and once you have made your changes, submit a pull request so that we can review! + +One thing to mention: we want to make sure that we all adhere to the same coding standards, so we have added Black, a code formatter, as a prehook. To ensure that all your files are formatted with Black, do the following: + +1. Install the necessary dependencies +``` +pip install .[contributing] +``` + +2. Run the precommit command +``` +pre-commit install +``` + +3. Continue adding code as usual. All your code will be formatted by Black before commiting! + ## Authors Panza was conceived by Nir Shavit and Dan Alistarh and built by the [Distributed Algorithms and Systems group](https://ist.ac.at/en/research/alistarh-group/) at IST Austria. The contributors are (in alphabetical order): -Dan Alistarh, Eugenia Iofinova, Eldar Kurtic, Ilya Markov, Armand Nicolicioiu, Mahdi Nikdan, Andrei Panferov, and Nir Shavit. +Dan Alistarh, Eugenia Iofinova, Andrej Jovanovic, Eldar Kurtic, Ilya Markov, Armand Nicolicioiu, Mahdi Nikdan, Andrei Panferov, Nir Shavit, and Sean Yang. Contact: dan.alistarh@ist.ac.at diff --git a/README_old.md b/README_old.md new file mode 100644 index 0000000..c2c1d09 --- /dev/null +++ b/README_old.md @@ -0,0 +1,234 @@ +
+ panza demo +
+ +# Panza: A personal email assistant, trained and running on-device + + + +## What is Panza? + + + + +Panza is an automated email assistant customized to your writing style and past email history. \ +Its main features are as follows: +* Panza produces a fine-tuned LLM that matches your writing style, pairing it with a Retrieval-Augmented Generation (RAG) component which helps it produce relevant emails. +* Panza **can be trained and run entirely locally**. Currently, it requires a single GPU with +16-24 GiB of memory, but we also plan to release a CPU-only version. **At no point in training or execution is your data shared with the entities that trained the original LLMs, with LLM distribution services such as Huggingface, or with us.** +* Training and execution are also quick - for a dataset on the order of 1000 emails, training Panza takes well under an hour, and generating a new email takes a few seconds at most. + +
+ panza logo +
+ + +## TODO: Prerequisites +- Your emails, exported to `mbox` format (see tutorial below). +- A computer, preferably with a NVIDIA GPU with at least 24 GiB of memory (alternatively, check out [running in Google Colab](#cloud-try-out-panza-in-google-colab)). +- A Hugging Face [account](https://huggingface.co/login) to download the models (free of charge). +- [Optional] A Weights & Biases [account](https://wandb.ai/login) to log metrics during training (free of charge). +- Basic Python and Unix knowledge, such as building environments and running python scripts. +- *No prior LLMs experience is needed*. + + +## How it works + +### :film_projector: Step 1: Data playback + +For most email clients, it is possible to download a user's past emails in a machine-friendly .mbox format. For example, GMail allows you to do this via [Google Takeout](https://takeout.google.com), whereas Thunderbird allows one to do this via various plugins. + +One key part of Panza is a dataset-generation technique we call **data playback**: Given some of your past emails in .mbox format, we automatically create a training set for Panza by using a pretrained LLM to summarize the emails in instruction form; each email becomes a `(synthetic instruction, real email)` pair. +Given a dataset consisting of all pairs, we use these pairs to "play back" your sent emails: the LLM receives only the instruction, and has to generate the "ground truth" email as a training target. + +We find that this approach is very useful for the LLM to "learn" the user's writing style. + + +### :weight_lifting: Step 2: Local Fine-Tuning via Robust Adaptation (RoSA) + +We then use parameter-efficient finetuning to train the LLM on this dataset, locally. We found that we get the best results with the [RoSA method](https://arxiv.org/pdf/2401.04679.pdf), which combines low-rank (LoRA) and sparse finetuning. If parameter efficiency is not a concern, that is, you have a more powerful GPU, then regular, full-rank/full-parameter finetuning can also be used. We find that a moderate amount of further training strikes the right balance between matching the writer's style without memorizing irrelevant details in past emails. + + +### :owl: Step 3: Serving via RAG + +Once we have a custom user model, Panza can be run locally together with a Retrieval-Augmented Generation (RAG) module. Specifically, this functionality stores past emails in a database and provides a few relevant emails as context for each new query. This allows Panza to better insert specific details, such as a writer's contact information or frequently used Zoom links. + +The overall structure of Panza is as follows: +
+ panza logo +
+ +## Installation + +### Conda +1. Make sure you have a version of [conda](https://docs.anaconda.com/free/miniconda/miniconda-install/) installed. +2. Create a new conda environment named 'panza' (or something else) and activate it: +``` bash +conda create -n panza python=3.10 -y +conda activate panza +``` +3. Install the required packages: +``` bash +pip install . +``` +4. If you want to also finetune models using Panza, you will need to install the additional packages: +``` bash +pip install .[training] +``` + +## TODO: :rocket: Getting started + +To quickly get started with building your own personalized email assistant, follow the steps bellow: + + + + +### Step 0: Download your sent emails + +
+ Expand for detailed download instructions. + + We provide a description for doing this for GMail via Google Takeout. + + 1. Go to [https://takeout.google.com/](https://takeout.google.com/). + 2. Click `Deselect all`. + 3. Find `Mail` section (search for the phrase `Messages and attachments in your Gmail account in MBOX format`). + 4. Select it. + 5. Click on `All Mail data included` and deselect everything except `Sent`. + 6. Scroll to the bottom of the page and click `Next step`. + 7. Click on `Create export`. + 8. Wait for download link to arrive in your inbox. + 9. Download `Sent.mbox` and place it in the `data/` directory. + + For Outlook accounts, we suggest doing this via a Thunderbird plugin for exporting a subset of your email as an MBOX format, such as [this add-on](https://addons.thunderbird.net/en-us/thunderbird/addon/importexporttools-ng/). +
+ +At the end of this step you should have the downloaded emails placed inside `data/Sent.mbox`. + + +### Step 1: Environment configuration + + +Panza is configured through a set of yaml configurations defined in `configs/`. There is a single high-level config under `configs/base.yaml`, and the rest are organized under the main functionalities of the code. +Note that these task-specific configs can, in some cases, be used to override base configs. + Specific use cases, such as hyperparameter tuning, are covered in more detail in `scripts/README.md`. (TODO jen: write this up.) + +1. Data preparation: `configs/data_preparation.yaml`. Additionally, a custom user config must be added under `config/users/` (see below). +1. Finetuning: the main config is in `configs/panza_finetuning.yaml` and the method-specific ones are in `configs/finetuning/` +1. Serving: Serving consists of two parts - a serving infrastructure (that we call 'writer') that runs the LLM and so converts prompts to Panza outputs, and an `interface`, which presents the outputs in a useful form - through a command-line interface, a web interface, a gmail client (TODO:Sean), or in a bulk `.json` format (useful for evaluation). The configs for serving are in `panza_writer.yaml`, and for the interfaces, under `configs/interfaces`. + + +These scripts are described in more detail in `scripts/README.md`, but a few customizations need to happen immediately. +:warning: Before continuing, make sure you complete the following setup: +- Copy `users/default.yaml` to `users/[YOURNAME].yaml`. If this is skipped, perform the following modifications on `users/default.yaml` directly. A useful tip for choosing the name of `[YOURNAME]` is to set it to the output of `whoami`. If you modify the default yaml, you will need specify `user=default` as an extra flag in the succeeding steps. +- In the user config, set the email address and username. The email address should be the sender address in the exported emails. (Panza uses this to edit out responses and other emails sent by a different author in the `.mbox` dump.). The username does not have to link to the email itself - it is simply used as a name for the various data files that will come out of the data preparation process. A handy way to set this is if you set it to be the output of the `whoami` call in your shell. +- Modify the personal prompt in `prompt_preambles/user_preamble.txt` to include some basic information about yourself that Panza can use to customize your emails with your correct full name, address, phone number, etc. + + +Additionally, please perform the following login steps to be able to download the base model. + - Login to Hugging Face to be able to download pretrained models: `huggingface-cli login`. + - [Optional] Login to Weights & Biases to log metrics during training: `wandb login`. Then, set `wandb_disabled=false` in `configs/finetuning/base.yaml`. + + +You are now ready to move to `scripts`. +``` bash +cd scripts +``` + +### Step 2: Extract emails + + +1. Run `CUDA_VISIBLE_DEVICES=X python ./prepare_data.py`.
+ This scripts takes care of all the prerequisites before training (expand for details). + + - Extracts your emails in text format to `data/_clean.jsonl` which you can manually inspect. + - Creates synthetic prompts for your emails as described in the [data playback](#film_projector-step-1-data-playback) section. The results are stored in `data/_clean_summarized.jsonl` and you can inspect the `"summary"` field. + - Splits data into training and test subsets. See `data/train.jsonl` and `data/test.jsonl`. + - Creates a vector database from the embeddings of the training emails which will later be used for *Retrieval-Augmented Generation (RAG)*. See `data/.pkl` and `data/.faiss`. +
+**NB**: if you did not change the default configuration in `user/default.yaml` to reflect your particulars, but instead created a new file, you need to pass an additional flag to the above command: `user=x`, where `x.yaml` is the name of your config file.
+ FAQs. + When running the above script, you may encounter an OutOfMemoryError. If this is the case, you can either: +
    +
  1. Reduce the batch size for the data processing step. This can be found in configs/panza_preparation.yaml. +
  2. Move to a machine that has more memory. +
+
+ +ODO Jen: This doesn't work anymore, because we make the RAG database right away. If you wish to eliminate any emails from the training set (e.g. containing certain personal information), you can simply remove the corresponding rows. + +### Step 3: Train a LLM on your emails + + +We currently support `LLaMA3-8B-Instruct` and `Mistral-Instruct-v0.2` LLMs as base models; the former is the default, but we obtained good results with either model. + +1. [Recommended] For parameter efficient fine-tuning, run `./train_rosa.sh`. +If a larger GPU is available and full-parameter fine-tuning is possible, run `./train_fft.sh`. + +2. We have prepopulated the training configs with parameter values that worked best for us. We recommend you try those first, but you can also experiment with different hyper-parameters by passing extra arguments to the training script, such as `lr`, `lora_lr`, `num_epochs`. All the trained models are saved in the `checkpoints` directory. + +Examples: +``` bash +CUDA_VISIBLE_DEVICES=X ./train_rosa.sh # Will use the default parameters. + +CUDA_VISIBLE_DEVICES=X ./train_rosa.sh finetuning.lr=1e-6 finetuning.rosa_lr=1e-6 finetuning.max_duration=7ep. +``` +
+ FAQs. + The bash scripts that are used to execute the finetuning procedure assume by default that your username is what is returned by the whoami command. This is used to locate the name of the user configs inside the configs/user directory as above. If you directly modified default.yaml, or created another yaml file where the name of that file does not match with the output of whoami, there will be an error. This is an easy fix. You can either: +
    +
  1. Change the name of the yaml file to be the output of whoami. +
  2. You can override the username manually when you launch the bash script by adding user=x where x is the name of the yaml file you created. For example: ./train_rosa.sh user=alonso +
+
+ If you wish to add CUDA_VISIBLE_DEVICES to specify a specific GPU, please add this in the shell script directly by export CUDA_VISIBLE_DEVICES=x where x is the ID of the GPU you wish to use. +

A known issue is that when you fine-tune your model with RAG, the tokenization of the dataset can appear to hang. This is due to a known bug in HF's `map` function when `num_proc>1`. To work around it, you can set `torch.set_num_threads(1)` in `src/panza3/finetuning/train.py`, or set the equivalent parameter in `configs/finetuning/rosa.yaml`.
+ + +### Step 5: Launch Panza! + + +- To run Panza after a full training run, try something like `CUDA_VISIBLE_DEVICES=0 python3 runner.py user=USERNAME interfaces=cli writer/llm=transformers`. +- To run Panza after a RoSA or LoRA training run, replace `writer/llm=transformers` with `writer/llm=peft` TODO Armand: can we fix this? + + +:email: **Have fun with your new email writing assistant!** :email: + + + + +## :microscope: Advanced usage +- [Data Preparation Guide](./scripts/README.md#data-guide) +- [Hyper-Parameter Tuning Guide](./scripts/README.md#hyper-parameter-tuning-guide) +- [Prompt Preambles Tutorial](prompt_preambles/README.md) + +## :woman_technologist: Contributing +If you liked our work and want to contribute to improve the system, please feel free to do so! Make a _fork_ of our repository and once you have made your changes, submit a pull request so that we can review! + +One thing to mention: we want to make sure that we all adhere to the same coding standards, so we have added Black, a code formatter, as a prehook. To ensure that all your files are formatted with Black, do the following: + +1. Install the necessary dependencies +``` +pip install .[contributing] +``` + +2. Run the precommit command +``` +pre-commit install +``` + +3. Continue adding code as usual. All your code will be formatted by Black before commiting! + +## Authors + +Panza was conceived by Nir Shavit and Dan Alistarh and built by the [Distributed Algorithms and Systems group](https://ist.ac.at/en/research/alistarh-group/) at IST Austria. The contributors are (in alphabetical order): + +Dan Alistarh, Eugenia Iofinova, Eldar Kurtic, Ilya Markov, Armand Nicolicioiu, Mahdi Nikdan, Andrei Panferov, and Nir Shavit. + +Contact: dan.alistarh@ist.ac.at + +We thank our collaborators Michael Goin and Tony Wang at NeuralMagic and MIT for their helpful testing and feedback. diff --git a/TEMP_HOW_TO_RUN_INFERENCE.md b/TEMP_HOW_TO_RUN_INFERENCE.md new file mode 100644 index 0000000..da42f9d --- /dev/null +++ b/TEMP_HOW_TO_RUN_INFERENCE.md @@ -0,0 +1,46 @@ +# How to run inference in Panza3 + +There are two backend options: Ollama (no GPU) or Local (with GPU). The dependencies necessary for each backend are different. + +## Step 1: Install Dependencies for Panza + +For Ollama, simply run: +```bash +pip install -e . +``` + +For Local, run: +```bash +pip install -e . +``` +and +```bash +pip install panza_mail[training] +``` + +## Step 2a: Ollama Prerequisites + +If running with Ollama, then Ollama needs to be installed from the [web page](https://ollama.com/). + +Then, you will need to convert your model into a GGUF file. + +## Step 2b: Local Prerequisites + +If running locally, then the Panza model needs to be located in `data`. + +## Step 3: Set configurations + +In the `configs folder` add a user YAML file for yourself in `/user`. + +If running with Ollama, edit the `name` and `gguf` fields in `/writer/llm/ollama.yaml` with a name of your choice and the path to the GGUF file. 
+ +## Step 4: Run Panza + +To run Panza, cd into the `scripts` directory and run: +```bash +python3 runner.py user= interfaces= writer/llm= +``` +For example, to run with Ollama and the CLI interface with the user `test`, run: +```bash +python3 runner.py user=test interfaces=cli writer/llm=ollama +``` \ No newline at end of file diff --git a/configs/base.yaml b/configs/base.yaml new file mode 100644 index 0000000..0a09b04 --- /dev/null +++ b/configs/base.yaml @@ -0,0 +1,9 @@ +defaults: + - user: default + +panza_workspace: ${hydra:runtime.cwd}/../ +checkpoint_dir: ${panza_workspace}/checkpoints +seed: 41 + +embedding_model: "sentence-transformers/all-mpnet-base-v2" +model_precision: bf16 # bf16 or fp32 diff --git a/src/panza/finetuning/configs/rosa_panza.yaml b/configs/finetuning/base.yaml similarity index 50% rename from src/panza/finetuning/configs/rosa_panza.yaml rename to configs/finetuning/base.yaml index f2cd3b7..59eb3fb 100644 --- a/src/panza/finetuning/configs/rosa_panza.yaml +++ b/configs/finetuning/base.yaml @@ -1,50 +1,33 @@ +wandb_disabled: true # We assume that wandb is disabled unless the user has logged on. + max_seq_len: 512 -global_seed: 17 -model_name_or_path: #TODO +global_seed: ${seed} load_path: # set via bash script to be absolute path to your sparse checkpoint precision: amp_bf16 -hf_save_path: ./checkpoints +hf_save_path: ${checkpoint_dir}/models -max_duration: # TODO eval_interval: 1 -seed: ${global_seed} -global_train_batch_size: #TODO -device_train_microbatch_size: 16 -device_eval_batch_size: 16 +global_train_batch_size: 8 +device_train_microbatch_size: 1 +device_eval_batch_size: 1 -run_name: # If left blank, will be read from env var $RUN_NAME +run_name: # If left blank, it will be generated based on configs model: name: hf_causal_lm pretrained: true - pretrained_model_name_or_path: ${model_name_or_path} - max_seq_len: ${max_seq_len} + pretrained_model_name_or_path: ${finetuning.model_name_or_path} + max_seq_len: ${finetuning.max_seq_len} output_hidden_states: true - weight_bias_dtype: #TODO + weight_bias_dtype: ${model_precision} compute_dtype: bf16 -rosa: - lora_r: #TODO - spa_d: #TODO - lora_alpha: 16 - target_modules: 'all-linear' - lora_dropout: 0.05 - impl: auto - spa_store_transpose: true - rosa_dtype: bf16 - spa_num_grads: 1 - grad_acc_mode: mean_squared - mask_load_path: #TODO - mask_save_path: #TODO - terminate_after_mask_generation: #TODO - schedule: #TODO - tokenizer: - name: ${model_name_or_path} + name: ${finetuning.model_name_or_path} kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${finetuning.max_seq_len} train_loader: name: finetuning @@ -52,9 +35,9 @@ train_loader: hf_name: json split: train hf_kwargs: - data_files: #TODO - preprocessing_fn: preprocessing:panza_preprocessing_function - max_seq_len: ${max_seq_len} + data_files: ${user.data_dir}/train.jsonl + preprocessing_fn: panza.finetuning.preprocessing:panza_preprocessing_function + max_seq_len: ${finetuning.max_seq_len} allow_pad_trimming: false decoder_only_format: true shuffle: true @@ -72,7 +55,7 @@ scheduler: optimizer: name: decoupled_adamw - lr: # TODO + lr: 1e-5 betas: - 0.9 - 0.999 diff --git a/configs/finetuning/full.yaml b/configs/finetuning/full.yaml new file mode 100644 index 0000000..bc4c9be --- /dev/null +++ b/configs/finetuning/full.yaml @@ -0,0 +1,29 @@ +defaults: + - base + + +max_duration: 3ep +lr: 1e-5 +batch_size: 8 +eval_interval: 1 +seed: ${seed} +model_name_or_path: "ISTA-DASLab/Meta-Llama-3-8B-Instruct" + +fsdp_config: + sharding_strategy: FULL_SHARD 
+ mixed_precision: FULL + activation_checkpointing: true + activation_checkpointing_reentrant: false + activation_cpu_offload: false + limit_all_gathers: true + verbose: false + +callbacks: + hf_checkpointer: + overwrite: true + precision: # TODO + save_folder: ${finetuning.hf_save_path}/${finetuning.run_name} + save_interval: 1dur + +scheduler: + t_warmup: 20ba diff --git a/configs/finetuning/rosa.yaml b/configs/finetuning/rosa.yaml new file mode 100644 index 0000000..3f614ec --- /dev/null +++ b/configs/finetuning/rosa.yaml @@ -0,0 +1,32 @@ +defaults: + - base + +max_duration: 5ep +lr: 1e-5 +batch_size: 8 +eval_interval: 1 +seed: ${seed} +model_name_or_path: "ISTA-DASLab/Meta-Llama-3-8B-Instruct" + +rosa: + lora_lr: ${finetuning.lr} + lora_r: 8 + spa_d: 0.01 + lora_alpha: 16 + target_modules: 'all-linear' + lora_dropout: 0.05 + impl: auto + spa_store_transpose: true + rosa_dtype: bf16 + spa_num_grads: 1 + grad_acc_mode: mean_squared # 'mean' or 'mean_squared': how to accumulate gradients + mask_load_path: #TODO + mask_save_path: #TODO + terminate_after_mask_generation: #TODO + schedule: #TODO + masks_only: true + +scheduler: + t_warmup: 8ba + +num_cpu_threads: 1 diff --git a/configs/interfaces/cli.yaml b/configs/interfaces/cli.yaml new file mode 100644 index 0000000..1c18948 --- /dev/null +++ b/configs/interfaces/cli.yaml @@ -0,0 +1 @@ +_target_: panza.interface.PanzaCLI \ No newline at end of file diff --git a/configs/interfaces/gui.yaml b/configs/interfaces/gui.yaml new file mode 100644 index 0000000..8a6497d --- /dev/null +++ b/configs/interfaces/gui.yaml @@ -0,0 +1 @@ +_target_: panza.interface.PanzaGUI \ No newline at end of file diff --git a/configs/interfaces/json.yaml b/configs/interfaces/json.yaml new file mode 100644 index 0000000..a91f254 --- /dev/null +++ b/configs/interfaces/json.yaml @@ -0,0 +1,9 @@ +input_file: ${panza_workspace}/data/test.jsonl +batch_size: 8 +use_thread: false +responses_per_prompt: 1 +checkpoint: ${checkpoint} +panza_workspace: ${panza_workspace} +compute_metrics: true +username: ${user.username} +_target_: panza.interface.PanzaJSON \ No newline at end of file diff --git a/configs/panza_finetuning.yaml b/configs/panza_finetuning.yaml new file mode 100644 index 0000000..3cab7a6 --- /dev/null +++ b/configs/panza_finetuning.yaml @@ -0,0 +1,7 @@ + +defaults: + - base + - finetuning: full + # For preprocessing (i.e., assembling the LLM prompt, inherit the defaults + # from the writer (the inference module).) + - writer/prompting/email_prompting@preprocessing.prompting \ No newline at end of file diff --git a/configs/panza_preparation.yaml b/configs/panza_preparation.yaml new file mode 100644 index 0000000..550e732 --- /dev/null +++ b/configs/panza_preparation.yaml @@ -0,0 +1,25 @@ +defaults: + - base + - writer: summary + - writer/prompting/retriever/faiss@retriever + +batch_size: 8 + +email_dump_path: ${user.data_dir}/Sent.mbox +cleaned_emails_path: ${user.data_dir}/${user.username}_emails_clean.jsonl +discarded_emails_dir: ${user.data_dir}/${user.username}/discarded_emails +summarized_emails_path: ${user.data_dir}/${user.username}_emails_clean_summarized.jsonl + +rag_db_dir: ${user.data_dir} + +checkpoint: "microsoft/Phi-3-mini-4k-instruct" +force_extract_clean_emails: false # If false, data will not be recreated if it already exists. + + # Parameters for train-test split, if required. +test_split: 0. +split_type: random # Options are 'random', 'chronological'. + +# Parameters for RAG database. 
+rag_embedding_chunk_size: 3000 +rag_embedding_chunk_overlap: 3000 +rag_embedding_model: "sentence-transformers/all-mpnet-base-v2" \ No newline at end of file diff --git a/configs/panza_writer.yaml b/configs/panza_writer.yaml new file mode 100644 index 0000000..2faf9f2 --- /dev/null +++ b/configs/panza_writer.yaml @@ -0,0 +1,12 @@ +defaults: + - base + - writer: email + - interfaces: + # - gui + # - cli + # - web + - json + +# Either a full path to the checkpoint, or the 'latest' tag, +# Which looks for the latest checkpoint in `checkpoint_dir` +checkpoint: 'latest' \ No newline at end of file diff --git a/configs/user/default.yaml b/configs/user/default.yaml new file mode 100644 index 0000000..8f2b9e6 --- /dev/null +++ b/configs/user/default.yaml @@ -0,0 +1,9 @@ +email_address: "abc@xyz.com" # Change this to your email address! +username: "abc" # This identifies the user in the users directory and the names of the emails files. + +data_dir: ${panza_workspace}/data + +system_preamble_path: ${panza_workspace}/prompt_preambles/system_preamble.txt +user_preamble_path: ${panza_workspace}/prompt_preambles/user_preamble.txt +rag_preamble_path: ${panza_workspace}/prompt_preambles/rag_preamble.txt +thread_preamble_path: ${panza_workspace}/prompt_preambles/thread_preamble.txt diff --git a/configs/writer/email.yaml b/configs/writer/email.yaml new file mode 100644 index 0000000..49de780 --- /dev/null +++ b/configs/writer/email.yaml @@ -0,0 +1,5 @@ +defaults: + - llm: transformers + - prompting: email_prompting + +_target_: panza.writer.PanzaWriter diff --git a/configs/writer/llm/peft.yaml b/configs/writer/llm/peft.yaml new file mode 100644 index 0000000..2c0a892 --- /dev/null +++ b/configs/writer/llm/peft.yaml @@ -0,0 +1,9 @@ +defaults: + - sampling: random + +_target_: panza.llm.PeftLLM +name: ${checkpoint} +checkpoint: ${checkpoint} +device: "cuda" # Alternatively, "cuda" +dtype: "fp32" +load_in_4bit: false diff --git a/configs/writer/llm/sampling/greedy.yaml b/configs/writer/llm/sampling/greedy.yaml new file mode 100644 index 0000000..5854169 --- /dev/null +++ b/configs/writer/llm/sampling/greedy.yaml @@ -0,0 +1,2 @@ +do_sample: false +max_new_tokens: 1024 \ No newline at end of file diff --git a/configs/writer/llm/sampling/random.yaml b/configs/writer/llm/sampling/random.yaml new file mode 100644 index 0000000..e5a954a --- /dev/null +++ b/configs/writer/llm/sampling/random.yaml @@ -0,0 +1,5 @@ +do_sample: true +temperature: 0.7 +top_k: 50 +top_p: 0.7 +max_new_tokens: 1024 \ No newline at end of file diff --git a/configs/writer/llm/transformers.yaml b/configs/writer/llm/transformers.yaml new file mode 100644 index 0000000..305dc25 --- /dev/null +++ b/configs/writer/llm/transformers.yaml @@ -0,0 +1,9 @@ +defaults: + - sampling: random + +_target_: panza.llm.TransformersLLM +name: ${checkpoint} +checkpoint: ${checkpoint} +device: "cuda" +dtype: "fp32" +load_in_4bit: false \ No newline at end of file diff --git a/configs/writer/prompting/email_prompting.yaml b/configs/writer/prompting/email_prompting.yaml new file mode 100644 index 0000000..27b458a --- /dev/null +++ b/configs/writer/prompting/email_prompting.yaml @@ -0,0 +1,13 @@ +defaults: + - retriever: faiss + +_target_: panza.prompting.EmailPromptBuilder + +system_preamble: ${load_preamble:${user.system_preamble_path}} +user_preamble: ${load_user_preamble:${user.user_preamble_path}} +rag_preamble: ${load_preamble:${user.rag_preamble_path}} +thread_preamble: ${load_preamble:${user.thread_preamble_path}} + +number_rag_emails: 0 
+rag_relevance_threshold: 0.2 +number_thread_emails: 0 \ No newline at end of file diff --git a/configs/writer/prompting/retriever/faiss.yaml b/configs/writer/prompting/retriever/faiss.yaml new file mode 100644 index 0000000..fbf35c9 --- /dev/null +++ b/configs/writer/prompting/retriever/faiss.yaml @@ -0,0 +1,5 @@ +_target_: panza.retriever.FaissRetriever +db_path: ${user.data_dir} +index_name: ${user.username} +embedding_model: ${embedding_model} +device: "cpu" \ No newline at end of file diff --git a/configs/writer/prompting/retriever/none.yaml b/configs/writer/prompting/retriever/none.yaml new file mode 100644 index 0000000..8015504 --- /dev/null +++ b/configs/writer/prompting/retriever/none.yaml @@ -0,0 +1 @@ +_target_: panza.retriever.NoneRetriever \ No newline at end of file diff --git a/configs/writer/prompting/summarization_prompting.yaml b/configs/writer/prompting/summarization_prompting.yaml new file mode 100644 index 0000000..98449cc --- /dev/null +++ b/configs/writer/prompting/summarization_prompting.yaml @@ -0,0 +1,3 @@ +_target_: panza.prompting.SummarizationPromptBuilder + +summarization_prompt: ${load_preamble:${panza_workspace}/prompt_preambles/summarization_prompt.txt} diff --git a/configs/writer/summary.yaml b/configs/writer/summary.yaml new file mode 100644 index 0000000..827c4dd --- /dev/null +++ b/configs/writer/summary.yaml @@ -0,0 +1,5 @@ +defaults: + - llm: transformers + - prompting: summarization_prompting + +_target_: panza.writer.PanzaWriter diff --git a/prepare_env.sh b/prepare_env.sh index b6e6972..61e106b 100644 --- a/prepare_env.sh +++ b/prepare_env.sh @@ -1,13 +1,8 @@ # crash in case of error trap 'trap - ERR RETURN; kill -INT $$ ; return' ERR RETURN -conda create --name panza python=3.10 -y -conda activate panza +conda create --name panza_refactor python=3.10 -y +conda activate panza_refactor -conda install pytorch==2.2.2 torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y - -pip install langdetect langchain langchain-community sentence-transformers faiss-cpu fire mauve-text evaluate torchmetrics gradio cmake packaging nltk - -pip install git+https://github.com/IST-DASLab/llm-foundry -pip install git+https://github.com/IST-DASLab/peft-rosa.git@grad_quant -pip install spops_sm_80 +# install dependencies based on pyproject.toml +pip install -e .[training] \ No newline at end of file diff --git a/src/panza/data_preparation/summarization_prompt.txt b/prompt_preambles/summarization_prompt.txt similarity index 100% rename from src/panza/data_preparation/summarization_prompt.txt rename to prompt_preambles/summarization_prompt.txt diff --git a/prompt_preambles/user_preamble.txt b/prompt_preambles/user_preamble.txt index 67b0385..cbc16a5 100644 --- a/prompt_preambles/user_preamble.txt +++ b/prompt_preambles/user_preamble.txt @@ -6,4 +6,4 @@ My address is 123 Main Street, Springfield, IL, USA. My boss's name is Alex Burns. My children's names are Elsa, Anna, and Olaf. I am deeply committed to my hobby of underwater basket weaving, for which -we meet every Thursday at noon. +we meet every Thursday at noon. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9ac7878 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,53 @@ +[project] +name = "panza_mail" +version = "2024.08.14" +description = "A personal email assistant, trained and running on-device." 
+dependencies = [ + "torch==2.2.2", + "omegaconf", + "fastapi", + "uvicorn", + "pydantic", + "python-dotenv", + "hydra-core", + "langchain", + "langchain-community", + "sentence-transformers", + "faiss-cpu", + "gradio", +] + +[project.optional-dependencies] +training = [ + "langdetect", + "fire", + "mauve-text", + "evaluate", + "torchmetrics", + "gradio", + "cmake", + "packaging", + "nltk", + "llm-foundry@git+https://github.com/IST-DASLab/llm-foundry", + "peft@git+https://github.com/IST-DASLab/peft-rosa.git@grad_quant_looser_versioning", + "spops-sm-80", +] +contributing = [ + "pre-commit", +] + +[build-system] +requires = ["setuptools >= 61.0.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.black] +line-length = 100 + +[tool.pytest.ini_options] +pythonpath = ["src"] + +[dev-dependencies] +pytest = "*" \ No newline at end of file diff --git a/scripts/README.md b/scripts/README.md index bc77e78..00e2e91 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -8,21 +8,48 @@ This directory contains all scripts necessary to train and run Panza. We provide * `config.sh` sets the necessary environment variables and other parameters used throughout the Panza workflow. This script should be edited by the user in several places: to set the user's email address (for data preprocessing), to select the LLM used for data summarization and Panza finetuning, and optionally to update the locations the data and models will be stored. #### Data preparation -* `extract_emails.sh` extracts the user's emails from the `.mbox` file and removes any unusable ones (such as email forwards, those that seem to be written in a foreign language, or those that are too short). -* `prepare_dataset.sh` automatically converts emails to training data by using an LLM to write their summaries in the form of prompts; it then splits them into train and test data, and prepares the RAG database. +* `prepare_data.py` does several things: + +1. Extracts the user's emails from the `.mbox` file and removes any unusable ones (such as email forwards, those that seem to be written in a foreign language, or those that are too short). +1. Automatically converts emails to training and test data by using an LLM to write their summaries in the form of prompts. +1. Optionally, splits the summarized into train and test data. This is not done by default because we expect most users to use the default hyperparameters, and therefore have no need for evaluation. To activate this feature, indicate the size of the test split as follows: `python ./prepare_data.py test_split=0.2` +1. Prepares the RAG database. Note that only train data is used for this step. #### Training -* `train_rosa.sh` performs [parameter-efficient training](https://arxiv.org/pdf/2401.04679.pdf), and evaluation. For evaluation, we use a heldout email dataset and compute the BLEU score between the output email and the one originally written by the user. -* `train_fft.sh` performs full-parameter/full-rank training, and then evaluation (as before). _Note that this requires additional computational resources (about 2x)._ +* `train_rosa.sh` performs [parameter-efficient training](https://arxiv.org/pdf/2401.04679.pdf). +* `train_fft.sh` performs full-parameter/full-rank training. _Note that this requires additional computational resources (about 2x)._ + + +#### Inference/Serving + +Serving is done through the `runner` object. To use the runner, the type of model and the type of interface must be specified. 
+ +For interfaces, we offer serving via CLI (command-line inference) and an online GUI (via Gradio), as well as a bulk-serving API via JSON for the JSON, the location of the file defaults to the test data, but can be overridden (see the "evaluation" section, below). + +Currently, we support full-finetuned and parameter-efficienty-finetuned models. These must be set through the `writer-llm` parameter. +* To serve a foundation (i.e., not locally-finetuned) model or a fully-finetuned model, set `writer/llm=transformers` +* To serve a PEFT model, set `writer/llm=peft` + +Thus, a serving command would look something like: + +``` +python runner.py user=[username] interfaces=[cli|gui] writer/llm=[peft|transformers] checkpoint=[checkpoint_loc] +``` + +For the json interface, it would look like: + +``` +python runner.py user=[username] interfaces=json writer/llm=[peft|transformers] checkpoint=[checkpoint_loc] interfaces.input_file=[json_file_loc] +``` + +##### Evaluation -#### Serving -* `run_panza_cli.sh` runs a simple tool in the command line that enables a user to put in prompts and get Panza responses. -* `run_panza_gui.sh` runs a simple tool in the browser that enables a user to put in prompts and get Panza responses. +We think of evaluation as a special form of bulk inference/serving. Thus, like other forms of inference, it is done through a runner, specifically through the `json` interface. -Both of these tools require a link to the model that you wish to use. Running without providing a `MODEL` argument will run inference on the base (non-finetuned) LLM. +A sample command that runs interface over the test set looks like: ``` -./run_panza_gui.sh MODEL= +python runner.py user=jen interfaces=json writer/llm=[peft|transformers] checkpoint=[checkpoint_loc] interfaces.input_file=../data/test.jsonl ``` @@ -32,18 +59,18 @@ Both of these tools require a link to the model that you wish to use. Running wi :bulb: We recommend having between 128 and 1000 sent emails as training targets. Less than 128 might cause the model to overfit, while we haven't found that more than 1000 emails help for the style transfer. However, we encourage you to include as many emails as available in the RAG database, as they will provide the model with additional context. To sub-select training data, you can perform the usual flow with all of your data (export, run `extract_emails.sh` and `prepare_dataset.sh`), and then simply remove all but your target number of rows from the resulting `train.jsonl` in the `data`. -:bulb: To merge data from multiple mailboxes (such as combining your personal and work emails), run `extract_emails.sh` on each `.mbox` file, remembering to change the value of `PANZA_EMAIL_ADDRESS` in `config.sh` for every inbox. Then simply concatenate the resulting `[email_id].clean.jsonl` files to one, and use that file's `email_id` for the `PANZA_EMAIL_ADDRESS` argument in `config.sh` going forward. Make sure that the `prepare_dataset.sh` script is run _after_ the merge. +:bulb: To merge data from multiple mailboxes (such as combining your personal and work emails), run `extract_emails.sh` on each `.mbox` file, remembering to change the value of `user.email_address` and `user.user_name` in `config.sh` for every inbox. Then simply concatenate the resulting `[user.user_name].clean.jsonl` files to one, and use that file's `user.user_name` going forward. Make sure that the `prepare_dataset.sh` script is run _after_ the merge with `force_extract_clean_emails=false`. 
### Hyper-Parameter Tuning Guide To get the most out of Panza, it is essential to find good hyper-parameters for the fine-tuning process. -Specifically the key parameters to consider are the learning rates (`LR` and `LORA_LR`, in the case of RoSA fine-tuning) and (`NUM_EPOCHS`) parameters, whose values should be adjusted based on your amount of data and model in use. +Specifically the key parameters to consider are the learning rates (`trainer.optimizer.lr=0.1` and `trainer.optimizer.rosa.lora_lr`, in the case of RoSA fine-tuning) and (`trainer.optimizer.max_duration`) parameters, whose values should be adjusted based on your amount of data and model in use. Here are some general guidelines for hyper-parameter fine-tuning: * In our experience, a good target for the Perplexity over the training set (displayed during and at the end of the training run) is in the range 1-1.5 (for full fine-tuning) to 2-3 (for RoSA tuning). At that point, Panza should be able to reproduce your writing style quite faithfully. -* To reach this target, you can ajust two parameters: the length of training (`NUM_EPOCHS`) and the learning rates (`LR` for full fine-tuning and `LR` and `LORA_LR` for RoSA). +* To reach this target, you can ajust two parameters: the length of training (`trainer.optimizer.max_duration`) and the learning rates (`trainer.optimizer.lr` for full fine-tuning and `trainer.optimizer.lr` and `trainer.optimizer.rosa.lora_lr` for RoSA). * Specifically, for full fine-tuning we have found 3 training epochs to be sufficient. For RoSA fine-tuning, one usually needs 5-7 epochs for best results. -* Regarding the learning rates, we have already provided stable default values (around 1e-5 for both LLaMA3-8B and Mistral). You may adjust these depending on the amount of your local data. +* Regarding the learning rates, we have already provided stable default values (around 1e-5 for LLaMA3-8B , Phi-3.5-mini, and Mistral). You may adjust these depending on the amount of your local data. * We have found that setting these values too low will yield default "impersonal'' answers (specifically, the same answers as the base model with some context). Setting them too high will lead the model to "overfit" to the user data, to the point where a lot of the latent model "knowledge" is lost. The key to good performance is to find a good middle ground between these two scenarios. diff --git a/scripts/config.sh b/scripts/config.sh deleted file mode 100755 index 08b1479..0000000 --- a/scripts/config.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash - -export PANZA_EMAIL_ADDRESS="firstname.lastname@gmail.com" # Change this to your email address! -export PANZA_USERNAME="${PANZA_EMAIL_ADDRESS%@*}" # Removes everything after @; for the example above, it will be firstname.lastname - -export PANZA_WORKSPACE=$(dirname "$(dirname "$(realpath "$0")")"); -export PANZA_DATA_DIR="$PANZA_WORKSPACE/data" # where data is stored -export PANZA_CHECKPOINTS="$PANZA_WORKSPACE/checkpoints" # where checkpoints are stored -export PANZA_FINETUNE_CONFIGS="$PANZA_WORKSPACE/src/panza/finetuning/configs" # where training configuration details are stored - -export PANZA_PREAMBLES="$PANZA_WORKSPACE/prompt_preambles" # this is where the system prompt and user prompt preambles can be accessed; you will need to edit these -export PANZA_SYSTEM_PREAMBLE_PATH="$PANZA_PREAMBLES/system_preamble.txt" # system prompt -# IMPORTANT: Please edit the user preamble (at the PANZA_USER_PREAMBLE_PATH) if you plan to use it (recommended). 
-export PANZA_USER_PREAMBLE_PATH="$PANZA_PREAMBLES/user_preamble.txt" # a useful preamble to the user instruction, explaining what's going on to the LLM -export PANZA_RAG_PREAMBLE_PATH="$PANZA_PREAMBLES/rag_preamble.txt" # a preamble for the RAG component -export PANZA_THREAD_PREAMBLE_PATH="$PANZA_PREAMBLES/thread_preamble.txt" # a preamble for the RAG component - -export PANZA_SUMMARIZATION_BATCH_SIZE=8 # batch size for summarization. -export PANZA_EVALUATION_BATCH_SIZE=1 # batch size for evaluation. Can safely be set to higher value (e.g., 8) if the GPU has enough capacity. - -export MODEL_PRECISION=bf16 # precision at which the base model is stored; options: bf16, fp32, or '4bit' -# export PANZA_GENERATIVE_MODEL="mistralai/Mistral-7B-Instruct-v0.2" -export PANZA_GENERATIVE_MODEL="ISTA-DASLab/Meta-Llama-3-8B-Instruct" -# export PANZA_GENERATIVE_MODEL="microsoft/Phi-3-mini-4k-instruct" - -lowercased=$(echo "$PANZA_GENERATIVE_MODEL" | tr '[:upper:]' '[:lower:]') -if [[ ${lowercased} == *llama* ]]; then - export MODEL_TYPE=llama3 -elif [[ ${lowercased} == *mistral* ]]; then - export MODEL_TYPE=mistralv2 -elif [[ ${lowercased} == *phi* ]]; then - export MODEL_TYPE=phi3 -else - echo "Model type ${PANZA_GENERATIVE_MODEL} not recognized! Panza only works with Mistral and Llama3 models. Exiting." - exit -fi - -export PANZA_EMBEDDING_MODEL="sentence-transformers/all-mpnet-base-v2" # embedding model for RAG; can be changed, trading off speed for quality - -export PANZA_RAG_RELEVANCE_THRESHOLD=0.2 # emails whose relevance is above this threshold will be presented for RAG - -export PANZA_SEED=42 # the one true seed - -export PANZA_FINETUNE_WITH_PREAMBLE=1 # states whether user and system preambles are used for fine-tuning; on by default -export PANZA_FINETUNE_WITH_RAG=0 # states whether RAG preambles are used for fine-tuning; off by default -export PANZA_FINETUNE_WITH_THREAD=0 # states whether the email thread is used for fine-tuning; off by default -export PANZA_FINETUNE_RAG_NUM_EMAILS=3 # maximum number of emails to use for RAG fine-tuning; 3 by default -export PANZA_FINETUNE_RAG_PROB=0.55 # probability of using RAG context for fine-tuning; 0.5 by default -export PANZA_FINETUNE_RAG_RELEVANCE_THRESHOLD=0.2 # emails whose relevance is above this threshold will be presented for RAG during fine-tuning -export PANZA_FINETUNE_THREAD_NUM_EMAILS=3 # maximum number of emails to use for thread fine-tuning; 3 by default -export PANZA_DISABLE_RAG_INFERENCE=0 # RAG inference is on by default, since it's usually better - -export PANZA_WANDB_DISABLED=True # disable Weights and Biases logging by default - -export PYTHONPATH="$PANZA_WORKSPACE/src:$PYTHONPATH" - -# Optionally, set your HF_HOME and/or TRANSFORMERS_CACHE here. 
-# export HF_HOME= -# export TRANSFORMERS_CACHE= diff --git a/scripts/extract_emails.sh b/scripts/extract_emails.sh deleted file mode 100755 index 23f5300..0000000 --- a/scripts/extract_emails.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -source config.sh - -MBOX_NAME="Sent.mbox" -MBOX_PATH="${PANZA_DATA_DIR}/${MBOX_NAME}" - -python ../src/panza/data_preparation/extract_emails.py \ - --mbox-path=${MBOX_PATH} \ - --output-path=${PANZA_DATA_DIR} \ - --email=${PANZA_EMAIL_ADDRESS} \ \ No newline at end of file diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py new file mode 100644 index 0000000..d279ea5 --- /dev/null +++ b/scripts/prepare_data.py @@ -0,0 +1,168 @@ +import datetime +import json +import logging +import os +import random +import shutil +import time +from typing import List + +import hydra +from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm + +from panza import PanzaWriter # The import also loads custom Hydra resolvers +from panza.entities import Document, Email, SummarizationInstruction +from panza.retriever import DocumentRetriever +from panza.data_preparation.extract_emails import extract_emails +from panza.data_preparation.rag import create_vector_store + +LOGGER = logging.getLogger(__name__) + + +def rename_config_keys(cfg: DictConfig) -> None: + # Disable struct mode to allow modifications + OmegaConf.set_struct(cfg, False) + + cfg.writer.llm.sampling_parameters = cfg.writer.llm.sampling + del cfg.writer.llm.sampling + + cfg.writer.prompt_builder = cfg.writer.prompting + del cfg.writer.prompting + + # Re-enable struct mode to lock down the configuration + OmegaConf.set_struct(cfg, True) + + +def load_documents(data_path: str) -> None: + assert data_path.endswith(".jsonl"), f"Expecting a .jsonl file, but given = {data_path}" + + LOGGER.info(f"--> Reading emails from: {data_path}") + + with open(data_path, "r") as f: + lines = f.readlines() + documents = [Email.deserialize(line.strip(",")) for line in lines] + print(f"--> # emails = {len(documents)}") + + return documents + + +def generate_synthetic_instructions( + documents: List[Document], writer: PanzaWriter, batch_size: int, output_path: str +) -> None: + num_processed_documents = 0 + num_batches = (len(documents) - 1) // batch_size + 1 + start_time = time.time() + with open(output_path, "w") as f: + for i in tqdm(range(0, len(documents), batch_size)): + print(f"--> Processing batch {i // batch_size + 1}/{num_batches}") + batch = documents[i : i + batch_size] + instructions = [ + SummarizationInstruction(instruction=document.email) for document in batch + ] + + summaries = writer.run_batch(instructions) + num_processed_documents += len(summaries) + + for it, summary in enumerate(summaries): + # Considerf adding cleaning and filtering here. + batch[it].summary = summary + + # Write the summarized documents to a file + for document in batch: + f.write(json.dumps(document.serialize())) + f.write("\n") + + elapsed_time = time.time() - start_time + LOGGER.info(f"--> Processed {num_processed_documents} documents in {elapsed_time:.2f} seconds.") + + +def check_if_file_exists(cfg: DictConfig) -> None: + if os.path.exists(cfg.cleaned_emails_path) and not cfg.force_extract_clean_emails: + LOGGER.warning( + f"Cleaned email file already exists, using existing file {cfg.cleaned_emails_path}. " + "If you want to regenerate use the flag force_extract_clean_emails=true." 
+ ) + return True + return False + + +def split_and_write_data(cfg): + if cfg.test_split == 0: + shutil.copy(cfg.summarized_emails_path, os.path.join(cfg.user.data_dir, "train.jsonl")) + # Bad hack - we need test data for the training to work. + shutil.copy(cfg.summarized_emails_path, os.path.join(cfg.user.data_dir, "test.jsonl")) + else: + with open(cfg.summarized_emails_path, "r") as f: + data = f.readlines() + if cfg.split_type == "random": + random.seed(cfg.seed) + random.shuffle(data) + elif cfg.split_type == "chronological": + data = sorted(data, key=lambda x: datetime.datetime.fromisoformat(json.loads(x)["date"])) + else: + raise ValueError("Invalid split type.") + + train_size = int(len(data) * (1 - cfg.test_split)) + + with open(os.path.join(cfg.user.data_dir, "train.jsonl"), "w") as f: + for i in range(train_size): + f.write(data[i]) + + with open(os.path.join(cfg.user.data_dir, "test.jsonl"), "w") as f: + for i in range(train_size, len(data)): + f.write(data[i]) + + +@hydra.main(version_base="1.1", config_path="../configs", config_name="panza_preparation") +def main(cfg: DictConfig) -> None: + LOGGER.info("Running Panza Data Preparation") + LOGGER.info("Configuration: \n%s", OmegaConf.to_yaml(cfg, resolve=True)) + + # Rename config keys to follow class structure + rename_config_keys(cfg) + + # Skip extraction if the cleaned emails already exist + if not check_if_file_exists(cfg): + # Extract the emails from the .mbox file + extract_emails( + cfg.email_dump_path, + cfg.cleaned_emails_path, + [cfg.user.email_address], + cfg.discarded_emails_dir, + ) + + # Instantiate Panza writer + writer: PanzaWriter = hydra.utils.instantiate(cfg.writer) + assert isinstance(writer, PanzaWriter), "Failed to instantiate PanzaWriter" + + # Instantiate retriever + retriever: DocumentRetriever = hydra.utils.instantiate(cfg.retriever) + assert isinstance(retriever, DocumentRetriever), "Failed to instantiate DocumentRetriever" + retriever.set_document_class(Email) + + # Load documents + documents = load_documents(cfg.cleaned_emails_path) + generate_synthetic_instructions( + documents=documents, + writer=writer, + batch_size=cfg.batch_size, + output_path=cfg.summarized_emails_path, + ) + + # Write the test data to test.jsonl, with an optional train-test split + split_and_write_data(cfg) + + # Use only the training data (which might be all the data) for RAG. + create_vector_store( + os.path.join(cfg.user.data_dir, "train.jsonl"), + cfg.rag_embedding_chunk_size, + cfg.rag_embedding_chunk_overlap, + cfg.rag_db_dir, + cfg.user.username, + cfg.rag_embedding_model, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/prepare_data.sh b/scripts/prepare_data.sh new file mode 100755 index 0000000..0207732 --- /dev/null +++ b/scripts/prepare_data.sh @@ -0,0 +1,31 @@ +# Convenience script for data preparation +# All arguments to the python script can be provided +# here exactly in the form they would be passed to the +# python script directly. +# +# Example usage: +# CUDA_VISIBLE_DEVICES=x ./prepare_data.sh user=alonso + +set -e + +vars=() +# Set a default for the required user argument. We'll override it +# later if provided. +vars[1]=$"user=$(whoami)" +idx=2 + +# process input arguments +for argument in "$@" +do + key=$(echo $argument | cut -f1 -d=) + + if [[ $key == user ]]; then + # We already set the default value here; change it now.
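For illustration, here is a minimal, self-contained sketch of the train/test split performed by `split_and_write_data` in `scripts/prepare_data.py` above; the toy records and the `test_split`/`split_type` values are made up for the example and are not taken from any real Panza config.

``` python
import datetime
import json
import random

# Toy stand-ins for the summarized-email records that prepare_data.py writes out (illustrative only).
data = [
    json.dumps({"email": f"email body {i}", "date": f"2024-01-0{i}T09:00:00"}) + "\n"
    for i in range(1, 6)
]

test_split = 0.2  # fraction of emails held out for evaluation
split_type = "chronological"  # or "random"

if split_type == "random":
    random.seed(42)
    random.shuffle(data)
elif split_type == "chronological":
    data = sorted(data, key=lambda x: datetime.datetime.fromisoformat(json.loads(x)["date"]))

# Note the parentheses: a (1 - test_split) fraction of the data is kept for training.
train_size = int(len(data) * (1 - test_split))
train, test = data[:train_size], data[train_size:]
print(len(train), len(test))  # 4 1
```

With `test_split=0` the script instead copies the full summarized file to both `train.jsonl` and `test.jsonl`, since the training code expects a test set to exist.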
+ vars[1]=$argument + else + vars[idx]=$argument + idx+=1 + fi +done + +python ./prepare_data.py ${vars[@]} \ No newline at end of file diff --git a/scripts/prepare_dataset.sh b/scripts/prepare_dataset.sh deleted file mode 100755 index a10e324..0000000 --- a/scripts/prepare_dataset.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash - -source config.sh - -TRAIN_RATIO=0.8 -SPLIT_TYPE="chronological" # random or chronological - -CHUNK_SIZE=3000 -CHUNK_OVERLAP=3000 - -LOAD_IN_4BIT=0 -RUN_FP32=0 - -for ARGUMENT in "$@" -do - KEY=$(echo $ARGUMENT | cut -f1 -d=) - - KEY_LENGTH=${#KEY} - VALUE="${ARGUMENT:$KEY_LENGTH+1}" - - export "$KEY"="$VALUE" -done - -USE_4BIT_QUANT=$([ "${LOAD_IN_4BIT}" = 1 ] && echo "--load-in-4bit" || echo "") -USE_FP32_COMPUTE=$([ "${RUN_FP32}" = 1 ] && echo "--fp32" || echo "") - -# Create synthetic instructions (summaries) for emails -python ../src/panza/data_preparation/summarize_emails.py \ - --path-to-emails="${PANZA_DATA_DIR}/${PANZA_USERNAME}_clean.jsonl" \ - --prompt-file="${PANZA_WORKSPACE}/src/panza/data_preparation/summarization_prompt.txt" \ - --batch-size=${PANZA_SUMMARIZATION_BATCH_SIZE} ${USE_4BIT_QUANT} ${USE_FP32_COMPUTE} && - -# Create train and test splits -python ../src/panza/data_preparation/split_data.py \ - --data-path="${PANZA_DATA_DIR}/${PANZA_USERNAME}_clean_summarized.jsonl" \ - --output-data-dir=${PANZA_DATA_DIR} \ - --train-ratio=${TRAIN_RATIO} \ - --split-type=${SPLIT_TYPE} \ - --seed=${PANZA_SEED} && - -# Create vector store with emails embeddings -python ../src/panza/data_preparation/create_vector_store.py \ - --path-to-emails="${PANZA_DATA_DIR}/train.jsonl" \ - --chunk-size=${CHUNK_SIZE} \ - --chunk-overlap=${CHUNK_OVERLAP} \ - --db-path=${PANZA_DATA_DIR} \ - --index-name=${PANZA_USERNAME} \ - --embedding_model=${PANZA_EMBEDDING_MODEL} diff --git a/scripts/prepare_train_eval.sh b/scripts/prepare_train_eval.sh new file mode 100755 index 0000000..c88782f --- /dev/null +++ b/scripts/prepare_train_eval.sh @@ -0,0 +1,63 @@ +# Convenience script for combining all data preparation, model training +# and model evaluation with json +# All arguments to the python script can be provided +# here exactly in the form they would be passed to the +# python script directly. +# +# Example usage: +# CUDA_VISIBLE_DEVICES=x ./prepare_train_eval.sh user=alonso finetuning=rosa + +set -e + +vars=() +# Set a default for the required user argument. We'll override it +# later if provided. +vars[1]=$"user=$(whoami)" +idx=2 + +# process input arguments +training_mode="tbd" # training_mode to be determined later. +test_split="0" +for argument in "$@" +do + key=$(echo $argument | cut -f1 -d=) + if [[ $key == user ]]; then + # We already set the default value here; change it now. + vars[1]=$argument + echo "Overriding user to be ${argument#*=}" + elif [[ $key == test_split ]]; then + test_split=${argument#*=} + echo "Setting the test_split to ${test_split}" + elif [[ $key == finetuning ]]; then + training_mode=${argument#*=} + echo "Setting finetuning mode to ${training_mode}" + elif [[ $training_mode == "rosa" ]] && [[ $key == finetuning.rosa.masks_only ]];then + echo "The 'finetuning.rosa.masks_only' argument is already set and should not be overridden here; override is ignored." + else + vars[idx]=$argument + idx+=1 + fi +done + +# Step 1. Prepare the data +python ./prepare_data.py ${vars[@]} +# Step 2 & 3 Combined. Determine the type of training to do and evaluate with json. +if [[ $training_mode == "rosa" ]]; then + # First create the masks for RoSA finetuning. 
+ composer ../src/panza/finetuning/train.py \ + finetuning=rosa finetuning.rosa.masks_only=true ${vars[@]} + # Then train the weights. + composer ../src/panza/finetuning/train.py \ + finetuning=rosa finetuning.rosa.masks_only=false ${vars[@]} + if [[ $test_split != "0" ]]; then + echo "Generating json evaluation" + python runner.py interfaces=json writer/llm=peft + fi +elif [[ $training_mode == "full" ]]; then + composer ../src/panza/finetuning/train.py \ + finetuning=full ${vars[@]} + if [[ $test_split != "0" ]]; then + echo "Generating json evaluation" + python runner.py interfaces=json writer/llm=transformers + fi +fi \ No newline at end of file diff --git a/scripts/run_panza.sh b/scripts/run_panza.sh deleted file mode 100755 index 941515a..0000000 --- a/scripts/run_panza.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash - -source config.sh - -MODEL=${PANZA_GENERATIVE_MODEL} # Replace this with the checkpoint you want to use! - -for ARGUMENT in "$@" -do - KEY=$(echo $ARGUMENT | cut -f1 -d=) - - KEY_LENGTH=${#KEY} - VALUE="${ARGUMENT:$KEY_LENGTH+1}" - - export "$KEY"="$VALUE" -done - -USE_RAG=$([ "${PANZA_DISABLE_RAG_INFERENCE}" = "1" ] && echo "" || echo "--use-rag") - -INFERENCE_SCRIPT=${PANZA_WORKSPACE}/src/panza/evaluation/gui_inference.py -python ${INFERENCE_SCRIPT} \ - --model=${MODEL} \ - --system-preamble=${PANZA_SYSTEM_PREAMBLE_PATH} \ - --user-preamble=${PANZA_USER_PREAMBLE_PATH} \ - --rag-preamble=${PANZA_RAG_PREAMBLE_PATH} \ - --embedding-model=${PANZA_EMBEDDING_MODEL} \ - --db-path=${PANZA_DATA_DIR} \ - --index-name=${PANZA_USERNAME} \ - --rag-relevance-threshold=${PANZA_RAG_RELEVANCE_THRESHOLD} \ - ${USE_RAG} diff --git a/scripts/run_panza_cli.sh b/scripts/run_panza_cli.sh deleted file mode 100755 index 5a094db..0000000 --- a/scripts/run_panza_cli.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash - -source config.sh - -MODEL=${PANZA_GENERATIVE_MODEL} # Replace this with the checkpoint you want to use! - -for ARGUMENT in "$@" -do - KEY=$(echo $ARGUMENT | cut -f1 -d=) - - KEY_LENGTH=${#KEY} - VALUE="${ARGUMENT:$KEY_LENGTH+1}" - - export "$KEY"="$VALUE" -done - -USE_RAG=$([ "${PANZA_DISABLE_RAG_INFERENCE}" = "1" ] && echo "" || echo "--use-rag") -USE_4BIT_QUANT=$([ "${MODEL_PRECISION}" = "4bit" ] && echo "--load-in-4bit" || echo "") - -INFERENCE_SCRIPT=${PANZA_WORKSPACE}/src/panza/evaluation/console_interactive_inference.py -python ${INFERENCE_SCRIPT} \ - --model=${MODEL} \ - --system-preamble=${PANZA_SYSTEM_PREAMBLE_PATH} \ - --user-preamble=${PANZA_USER_PREAMBLE_PATH} \ - --rag-preamble=${PANZA_RAG_PREAMBLE_PATH} \ - --embedding-model=${PANZA_EMBEDDING_MODEL} \ - --db-path=${PANZA_DATA_DIR} \ - --index-name=${PANZA_USERNAME} \ - --rag-relevance-threshold=${PANZA_RAG_RELEVANCE_THRESHOLD} \ - ${USE_RAG} \ - ${USE_4BIT_QUANT} diff --git a/scripts/run_panza_gui.sh b/scripts/run_panza_gui.sh deleted file mode 100755 index 911c32c..0000000 --- a/scripts/run_panza_gui.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash - -source config.sh - -MODEL=${PANZA_GENERATIVE_MODEL} # Replace this with the checkpoint you want to use! 
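All of the new convenience scripts (`prepare_data.sh`, `prepare_train_eval.sh`, and the training scripts below) simply collect their `key=value` arguments and forward them to Hydra entry points, where they are resolved as standard Hydra overrides. The sketch below shows the same resolution through Hydra's compose API; it is only illustrative and assumes hydra-core >= 1.1, that it is run from the `scripts/` directory so that `../configs` resolves, and it reuses the illustrative `user=alonso` from the script headers.

``` python
# A minimal sketch of how the scripts' key=value arguments are interpreted as Hydra overrides.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base="1.1", config_path="../configs"):
    cfg = compose(
        config_name="panza_preparation",
        overrides=["user=alonso", "test_split=0.2", "force_extract_clean_emails=true"],
    )
    # cfg.test_split is now 0.2 and the user config group is swapped to the alonso entry.
    print(OmegaConf.to_yaml(cfg))
```

Running `./prepare_data.sh user=alonso test_split=0.2` is equivalent: the shell loops above only gather these tokens into `vars` and pass them through to the Python entry point unchanged.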
- -for ARGUMENT in "$@" -do - KEY=$(echo $ARGUMENT | cut -f1 -d=) - - KEY_LENGTH=${#KEY} - VALUE="${ARGUMENT:$KEY_LENGTH+1}" - - export "$KEY"="$VALUE" -done - -USE_RAG=$([ "${PANZA_DISABLE_RAG_INFERENCE}" = "1" ] && echo "" || echo "--use-rag") -USE_4BIT_QUANT=$([ "${MODEL_PRECISION}" = "4bit" ] && echo "--load-in-4bit" || echo "") - -INFERENCE_SCRIPT=${PANZA_WORKSPACE}/src/panza/evaluation/gui_inference.py -python ${INFERENCE_SCRIPT} \ - --model=${MODEL} \ - --system-preamble=${PANZA_SYSTEM_PREAMBLE_PATH} \ - --user-preamble=${PANZA_USER_PREAMBLE_PATH} \ - --rag-preamble=${PANZA_RAG_PREAMBLE_PATH} \ - --embedding-model=${PANZA_EMBEDDING_MODEL} \ - --db-path=${PANZA_DATA_DIR} \ - --index-name=${PANZA_USERNAME} \ - --rag-relevance-threshold=${PANZA_RAG_RELEVANCE_THRESHOLD} \ - ${USE_RAG} \ - ${USE_4BIT_QUANT} diff --git a/scripts/runner.py b/scripts/runner.py new file mode 100644 index 0000000..445c50b --- /dev/null +++ b/scripts/runner.py @@ -0,0 +1,57 @@ +import logging + +import glob +import hydra +import os +from omegaconf import DictConfig, OmegaConf + +from panza import PanzaWriter # The import also loads custom Hydra resolvers + +LOGGER = logging.getLogger(__name__) + + +def rename_config_keys(cfg: DictConfig) -> None: + # Disable struct mode to allow modifications + OmegaConf.set_struct(cfg, False) + + cfg.writer.llm.sampling_parameters = cfg.writer.llm.sampling + del cfg.writer.llm.sampling + + cfg.writer.prompt_builder = cfg.writer.prompting + del cfg.writer.prompting + + # Re-enable struct mode to lock down the configuration + OmegaConf.set_struct(cfg, True) + + +def set_latest_model(cfg: DictConfig) -> None: + model_files = glob.glob( + f"{cfg.checkpoint_dir}/models/*" + ) # * means all if need specific format then *.csv + latest_file = max(model_files, key=os.path.getctime) + + OmegaConf.set_struct(cfg, False) + cfg.checkpoint = latest_file + OmegaConf.set_struct(cfg, True) + + +@hydra.main(version_base="1.1", config_path="../configs", config_name="panza_writer") +def main(cfg: DictConfig) -> None: + LOGGER.info("Starting Panza Writer") + LOGGER.info("Configuration: \n%s", OmegaConf.to_yaml(cfg, resolve=True)) + + # Rename config keys to follow class structure + rename_config_keys(cfg) + # Find the latest checkpoint, if requested. + set_latest_model(cfg) + + # Instantiate Panza writer + writer: PanzaWriter = hydra.utils.instantiate(cfg.writer) + assert isinstance(writer, PanzaWriter), "Failed to instantiate PanzaWriter" + + # Instantiate interfaces (CLI, GUI, web, etc) as specified in the configuration + hydra.utils.instantiate(cfg.interfaces, writer=writer) + + +if __name__ == "__main__": + main() diff --git a/scripts/runner.sh b/scripts/runner.sh new file mode 100755 index 0000000..d59b467 --- /dev/null +++ b/scripts/runner.sh @@ -0,0 +1,31 @@ +# Convenience script for launching your fine-tuned model. +# All arguments to the python script can be provided +# here exactly in the form they would be passed to the +# python script directly. +# +# Example usage: +# CUDA_VISIBLE_DEVICES=x ./runner.sh user=USERNAME interfaces=cli writer/llm=transformers + +set -e + +vars=() +# Set a default for the required user argument. We'll override it +# later if provided. +vars[1]=$"user=$(whoami)" +idx=2 + +# process input arguments +for argument in "$@" +do + key=$(echo $argument | cut -f1 -d=) + + if [[ $key == user ]]; then + # We already set the default value here; change it now. 
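The `rename_config_keys` helper that appears in both `prepare_data.py` and `runner.py` above only remaps two sub-trees of the writer config, presumably so that the instantiated classes receive the argument names they expect. A small self-contained illustration with made-up field values:

``` python
# Toy illustration of rename_config_keys: lift struct mode, rename two sub-trees, lock it again.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"writer": {"llm": {"sampling": {"temperature": 0.7}}, "prompting": {"style": "panza"}}}
)
OmegaConf.set_struct(cfg, True)  # Hydra-composed configs behave as struct by default

OmegaConf.set_struct(cfg, False)
cfg.writer.llm.sampling_parameters = cfg.writer.llm.sampling
del cfg.writer.llm.sampling
cfg.writer.prompt_builder = cfg.writer.prompting
del cfg.writer.prompting
OmegaConf.set_struct(cfg, True)

print(OmegaConf.to_yaml(cfg))
# writer:
#   llm:
#     sampling_parameters:
#       temperature: 0.7
#   prompt_builder:
#     style: panza
```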
+ vars[1]=$argument + else + vars[idx]=$argument + idx+=1 + fi +done + +python3 runner.py ${vars[@]} \ No newline at end of file diff --git a/scripts/train_fft.sh b/scripts/train_fft.sh index 6938d7c..92f8327 100755 --- a/scripts/train_fft.sh +++ b/scripts/train_fft.sh @@ -1,162 +1,34 @@ -set -e - -source config.sh - -current_user=$(whoami) - -export DATA_PATH=${PANZA_DATA_DIR}/train.jsonl +# Convenience script for running full finetuning. +# All arguments to the python script can be provided +# here exactly in the form they would be passed to the +# python script directly. +# +# Example usage: +# ./train_fft.sh user=alonso trainer.optimizer.lr=0.1 -# hyper-parameters with default values -#export MODEL_PRECISION=bf16 # bf16 or fp32 -export BASE_SAVE_PATH=${PANZA_CHECKPOINTS} # where to store the model -export NUM_EPOCHS=3 -export WARMUP=20 # the learning rate warmup (batches) -export BS=8 -export PER_DEVICE_BS=1 -export SEED=${PANZA_SEED} - -if [[ ${MODEL_TYPE} == llama3 ]]; then - export LR=1e-5 # learning rate -elif [[ ${MODEL_TYPE} == mistralv2 ]]; then - export LR=1e-5 # learning rate -elif [[ ${MODEL_TYPE} == phi3 ]]; then - export LR=1e-5 # learning rate -else - echo "Model type ${MODEL_TYPE} not recognized! Panza only works with mistralv2, llama3 and phi3 models. Exiting." - exit -fi +set -e -export PRETRAINED=${PANZA_GENERATIVE_MODEL} -export CONFIG=${PANZA_FINETUNE_CONFIGS}/fft_panza.yaml +vars=() +# Set a default for the required user argument. We'll override it +# later if provided. +vars[1]=$"user=$(whoami)" +idx=2 -# take all the input arguments and put them in environment variables -# this could override the hyper-parameters defined above -for ARGUMENT in "$@" +# process input arguments +for argument in "$@" do - KEY=$(echo $ARGUMENT | cut -f1 -d=) - - KEY_LENGTH=${#KEY} - VALUE="${ARGUMENT:$KEY_LENGTH+1}" - - export "$KEY"="$VALUE" + key=$(echo $argument | cut -f1 -d=) + + if [[ $key == user ]]; then + # We already set the default value here; change it now. + vars[1]=$argument + elif [[ $key == finetuning ]]; then + echo "The 'finetuning' argument is already set and should not be overridden here; override is ignored." 
+ else + vars[idx]=$argument + idx+=1 + fi done -echo "Using Learning Rate ${LR} for ${MODEL_TYPE} model" - -export WANDB_PROJECT="panza-${PANZA_USERNAME}" - -if [ "$PANZA_FINETUNE_WITH_PREAMBLE" = 1 ]; then - PREAMBLE_STR="-PREAMBLE" - PREPROCESSING_FN=panza.finetuning.preprocessing:panza_preprocessing_function_train_with_preamble -elif [ "$PANZA_FINETUNE_WITH_THREAD" = 1 ]; then - PREAMBLE_STR="-THREAD" - PREPROCESSING_FN=panza.finetuning.preprocessing:panza_preprocessing_function_train_with_thread -else - PREAMBLE_STR="" - PREPROCESSING_FN=panza.finetuning.preprocessing:panza_preprocessing_function -fi - -if [ "$PANZA_FINETUNE_WITH_RAG" = 1 ]; then - RAFT_STR=-RAFT_num${PANZA_FINETUNE_RAG_NUM_EMAILS}_prob${PANZA_FINETUNE_RAG_PROB}_th${PANZA_FINETUNE_RAG_RELEVANCE_THRESHOLD} -else - RAFT_STR="" -fi - -# some post-processing on the inputs -export MAX_DURATION=${NUM_EPOCHS}ep -export RUN_NAME=panza_${PANZA_USERNAME}_${MODEL_TYPE}_${MODEL_PRECISION}-bs${BS}-fft-lr${LR}-epochs${NUM_EPOCHS}-wu${WARMUP}-seed${SEED}${PREAMBLE_STR}${RAFT_STR}-$RANDOM - -# create directories to save the models -mkdir -p ${BASE_SAVE_PATH}/models/ - -TEMP_FILE=$(mktemp) - -if [ "$MODEL_PRECISION" = "bf16" ]; then - export HF_SAVE_PRECISION=bfloat16 -elif [ "$MODEL_PRECISION" = "fp32" ]; then - export HF_SAVE_PRECISION=float32 -else - echo "Unknown model precision $MODEL_PRECISION" - exit 1 -fi - -export WANDB_DISABLED=${PANZA_WANDB_DISABLED} -TRAIN_SCRIPT=${PANZA_WORKSPACE}/src/panza/finetuning/train.py -composer ${TRAIN_SCRIPT} \ - ${CONFIG} \ - model_name_or_path=${PRETRAINED} \ - train_loader.dataset.hf_kwargs.data_files=${DATA_PATH} \ - train_loader.dataset.preprocessing_fn=${PREPROCESSING_FN} \ - max_duration=${MAX_DURATION} \ - run_name=${RUN_NAME} \ - optimizer.lr=${LR} \ - global_train_batch_size=${BS} \ - device_train_microbatch_size=${PER_DEVICE_BS} \ - device_eval_batch_size=${PER_DEVICE_BS} \ - scheduler.t_warmup=${WARMUP}ba \ - model.weight_bias_dtype=${MODEL_PRECISION} \ - global_seed=${SEED} \ - seed=${SEED} \ - callbacks.hf_checkpointer.precision=${HF_SAVE_PRECISION} \ - hf_save_path=${BASE_SAVE_PATH}/models/ 2>&1 | tee "$TEMP_FILE" - -# Extract the wandb run ID from the temp file -WANDB_RUN_ID=$(grep -o 'https://wandb.ai/[^ ]*/runs/[^ ]*' "$TEMP_FILE" | awk -F'/' '{print $NF}' | tail -n 1) - -rm "$TEMP_FILE" - -# move the checkpoint (saved by llm-foundry) to the correct directory -export RUN_SAVE_PATH=${BASE_SAVE_PATH}/models/${RUN_NAME} -export LAST_SAVE_DIR_NAME=$(ls -t ${RUN_SAVE_PATH}/huggingface | head -n 1) -mv ${RUN_SAVE_PATH}/huggingface/${LAST_SAVE_DIR_NAME}/* ${RUN_SAVE_PATH} -rm -rf ${RUN_SAVE_PATH}/huggingface - -echo "find the finetuned model at ${BASE_SAVE_PATH}/models/${RUN_NAME}" - -if [ -z "$WANDB_RUN_ID" ]; then - echo "No wandb run ID found." 
-else - echo "Extracted wandb run ID: $WANDB_RUN_ID" -fi - -# Running BLEU evaluation -EVAL_SCRIPT=${PANZA_WORKSPACE}/src/panza/evaluation/evaluation.py -python ${EVAL_SCRIPT} \ - --model=${BASE_SAVE_PATH}/models/${RUN_NAME} \ - --system-preamble=${PANZA_SYSTEM_PREAMBLE_PATH} \ - --user-preamble=${PANZA_USER_PREAMBLE_PATH} \ - --rag-preamble=${PANZA_RAG_PREAMBLE_PATH} \ - --thread-preamble=${PANZA_THREAD_PREAMBLE_PATH} \ - --golden=${PANZA_DATA_DIR}/test.jsonl \ - --batch-size=${PANZA_EVALUATION_BATCH_SIZE} \ - --wandb-run-id=${WANDB_RUN_ID} - -# Running BLEU evaluation with thread -EVAL_SCRIPT=${PANZA_WORKSPACE}/src/panza/evaluation/evaluation.py -python ${EVAL_SCRIPT} \ - --model=${BASE_SAVE_PATH}/models/${RUN_NAME} \ - --system-preamble=${PANZA_SYSTEM_PREAMBLE_PATH} \ - --user-preamble=${PANZA_USER_PREAMBLE_PATH} \ - --rag-preamble=${PANZA_RAG_PREAMBLE_PATH} \ - --thread-preamble=${PANZA_THREAD_PREAMBLE_PATH} \ - --golden=${PANZA_DATA_DIR}/test.jsonl \ - --batch-size=${PANZA_EVALUATION_BATCH_SIZE} \ - --wandb-run-id=${WANDB_RUN_ID} \ - --use-thread - -# Running BLEU evaluation with RAG -python ${EVAL_SCRIPT} \ - --model=${BASE_SAVE_PATH}/models/${RUN_NAME} \ - --system-preamble=${PANZA_SYSTEM_PREAMBLE_PATH} \ - --user-preamble=${PANZA_USER_PREAMBLE_PATH} \ - --rag-preamble=${PANZA_RAG_PREAMBLE_PATH} \ - --thread-preamble=${PANZA_THREAD_PREAMBLE_PATH} \ - --golden=${PANZA_DATA_DIR}/test.jsonl \ - --batch-size=${PANZA_EVALUATION_BATCH_SIZE} \ - --wandb-run-id=${WANDB_RUN_ID} \ - --embedding-model=${PANZA_EMBEDDING_MODEL} \ - --db-path=${PANZA_DATA_DIR} \ - --index-name=${PANZA_USERNAME} \ - --use-rag - -echo "find the finetuned model at ${BASE_SAVE_PATH}/models/${RUN_NAME}" +composer ../src/panza/finetuning/train.py \ + finetuning=full ${vars[@]} \ No newline at end of file diff --git a/scripts/train_rosa.sh b/scripts/train_rosa.sh index d39d41c..b8c5997 100755 --- a/scripts/train_rosa.sh +++ b/scripts/train_rosa.sh @@ -1,245 +1,41 @@ -set -e - -source config.sh - -current_user=$(whoami) - -export DATA_PATH=${PANZA_DATA_DIR}/train.jsonl +# Convenience script for running RoSA finetuning. +# All arguments to the python script can be provided +# here exactly in the form they would be passed to the +# python script directly. +# +# Example usage: +# ./train_rosa.sh user=alonso trainer.optimizer.lr=0.1 -# hyper-parameters with default values -export MASK_GEN_MODEL_PRECISION=${MODEL_PRECISION} # bf16, fp32, or 4bit -export BASE_SAVE_PATH=${PANZA_CHECKPOINTS} # where to store the checkpoints and generated masks -export NUM_EPOCHS=5 -export WARMUP=8 # the learning rate warmup (batches) -export BS=8 -export PER_DEVICE_BS=1 -export LORA_ALPHA=16 -export SCHEDULE=wl16 # the RoSA schedule -export SPA_NUM_GRADS=1 # number of gradients used for mask generation -export SPA_GRAD_ACC_MODE=mean_squared # 'mean' or 'mean_squared': how to accumulate gradients -export SEED=${PANZA_SEED} - -if [[ ${MODEL_TYPE} == llama3 ]]; then - export LR=1e-5 # learning rate - export LORA_LR=1e-5 # a separate learning rate for the low-rank adapters -elif [[ ${MODEL_TYPE} == mistralv2 ]]; then - export LR=1e-5 # learning rate - export LORA_LR=1e-5 # a separate learning rate for the low-rank adapters -elif [[ ${MODEL_TYPE} == phi3 ]]; then - export LR=1e-5 # learning rate - export LORA_LR=1e-5 # a separate learning rate for the low-rank adapters -else - echo "Model type ${MODEL_TYPE} not recognized! Panza only works with mistralv2, llama3 and phi3 models. Exiting." 
- exit -fi - -# hyper-parameters without default values -export SPA_DENSITY=0.01 # the sparse adapters' density -export LORA_R=8 # the low-rank adapters' rank - -export PRETRAINED=${PANZA_GENERATIVE_MODEL} -export CONFIG=${PANZA_FINETUNE_CONFIGS}/rosa_panza.yaml -export NUM_CPU_THREADS=0 # useful for running of CPU, 0 means default the used by torch +set -e -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" # if not set, default to 0 +vars=() +# Set a default for the required user argument. We'll override it +# later if provided. +vars[1]=$"user=$(whoami)" +idx=2 -# take all the input arguments and put them in environment variables -# this could override the hyper-parameters defined above -for ARGUMENT in "$@" +# process input arguments +for argument in "$@" do - KEY=$(echo $ARGUMENT | cut -f1 -d=) - - KEY_LENGTH=${#KEY} - VALUE="${ARGUMENT:$KEY_LENGTH+1}" - - export "$KEY"="$VALUE" + key=$(echo $argument | cut -f1 -d=) + + if [[ $key == user ]]; then + # We already set the default value here; change it now. + vars[1]=$argument + elif [[ $key == finetuning ]]; then + echo "The 'finetuning' argument is already set and should not be overridden here; override is ignored." + elif [[ $key == finetuning.rosa.masks_only ]]; then + echo "The 'finetuning.rosa.masks_only' argument is already set and should not be overridden here; override is ignored." + else + vars[idx]=$argument + idx+=1 + fi done -echo "Using Learning Rate ${LR} and LoRA LR ${LORA_LR} for ${MODEL_TYPE} model" - -export WANDB_PROJECT="panza-${PANZA_USERNAME}" - -if [ "$PANZA_FINETUNE_WITH_PREAMBLE" = 1 ]; then - PREAMBLE_STR="-PREAMBLE" - PREPROCESSING_FN=panza.finetuning.preprocessing:panza_preprocessing_function_train_with_preamble -else - PREAMBLE_STR="" - PREPROCESSING_FN=panza.finetuning.preprocessing:panza_preprocessing_function -fi - -if [ "$PANZA_FINETUNE_WITH_RAG" = 1 ]; then - RAFT_STR=-RAFT_num${PANZA_FINETUNE_RAG_NUM_EMAILS}_prob${PANZA_FINETUNE_RAG_PROB}_th${PANZA_FINETUNE_RAG_RELEVANCE_THRESHOLD} -else - RAFT_STR="" -fi - -# some post-processing on the inputs -export MAX_DURATION=${NUM_EPOCHS}ep -export RUN_NAME=panza_${PANZA_USERNAME}_${MODEL_TYPE}_${MODEL_PRECISION}-bs${BS}-rosa_${SCHEDULE}_d${SPA_DENSITY}_${SPA_NUM_GRADS}grads_${SPA_GRAD_ACC_MODE}_r${LORA_R}_loralr${LORA_LR}_alpha${LORA_ALPHA}-lr${LR}-epochs${NUM_EPOCHS}-wu${WARMUP}-seed${SEED}${PREAMBLE_STR}${RAFT_STR}-$RANDOM - -# create directories to save the masks and models -mkdir -p ${BASE_SAVE_PATH}/masks/ -mkdir -p ${BASE_SAVE_PATH}/models/ - -if [ "$MODEL_PRECISION" = "bf16" ]; then - export ROSA_DTYPE=bf16 -elif [ "$MODEL_PRECISION" = "4bit" ]; then - export ROSA_DTYPE=fp32 -elif [ "$MODEL_PRECISION" = "fp32" ]; then - export ROSA_DTYPE=fp32 -else - echo "Unknown model precision $MODEL_PRECISION" - exit 1 -fi - -if [[ "$SPA_DENSITY" != "0" ]] -then - # sparse adaptation exists, so we need to generate masks - - if [[ $LORA_R == 0 ]] - then - export SCHEDULE=spa_only - fi - - # no wandb logging for mask generation - export WANDB_DISABLED=True - - # generate the masks and terminate - TRAIN_SCRIPT=${PANZA_WORKSPACE}/src/panza/finetuning/train.py - composer ${TRAIN_SCRIPT} \ - ${CONFIG} \ - model_name_or_path=${PRETRAINED} \ - num_cpu_threads=${NUM_CPU_THREADS} \ - train_loader.dataset.hf_kwargs.data_files=${DATA_PATH} \ - train_loader.dataset.preprocessing_fn=${PREPROCESSING_FN} \ - max_duration=${MAX_DURATION} \ - run_name=${RUN_NAME} \ - optimizer.lr=${LR} \ - global_train_batch_size=${BS} \ - device_train_microbatch_size=${PER_DEVICE_BS} \ - 
device_eval_batch_size=${PER_DEVICE_BS} \ - scheduler.t_warmup=${WARMUP}ba \ - model.weight_bias_dtype=${MASK_GEN_MODEL_PRECISION} \ - rosa.spa_d=${SPA_DENSITY} \ - rosa.spa_num_grads=${SPA_NUM_GRADS} \ - rosa.grad_acc_mode=${SPA_GRAD_ACC_MODE} \ - rosa.lora_r=${LORA_R} \ - rosa.lora_alpha=${LORA_ALPHA} \ - rosa.lora_lr=${LORA_LR} \ - rosa.schedule=${SCHEDULE} \ - rosa.rosa_dtype=${ROSA_DTYPE} \ - global_seed=${SEED} \ - seed=${SEED} \ - hf_save_path=${BASE_SAVE_PATH}/models/ \ - rosa.mask_save_path=${BASE_SAVE_PATH}/masks/${RUN_NAME} \ - rosa.terminate_after_mask_generation=true -fi - -# now we have the masks ready, so let's restart -export MASK_LOAD_PATH=${BASE_SAVE_PATH}/masks/${RUN_NAME} - -# determine the correct RoSA schedule -if [[ "$SPA_DENSITY" != "0" && $LORA_R -ne 0 ]] -then - export SCHEDULE=default -elif [[ $LORA_R -ne 0 ]] -then - export SCHEDULE=lora_only - export MASK_LOAD_PATH= -else - export SCHEDULE=spa_only -fi - -TEMP_FILE=$(mktemp) - -export WANDB_DISABLED=${PANZA_WANDB_DISABLED} -# start the training with both sparse and low-rank adapters active from the outset -TRAIN_SCRIPT=${PANZA_WORKSPACE}/src/panza/finetuning/train.py -composer ${TRAIN_SCRIPT} \ - ${CONFIG} \ - model_name_or_path=${PRETRAINED} \ - num_cpu_threads=${NUM_CPU_THREADS} \ - train_loader.dataset.hf_kwargs.data_files=${DATA_PATH} \ - train_loader.dataset.preprocessing_fn=${PREPROCESSING_FN} \ - max_duration=${MAX_DURATION} \ - run_name=${RUN_NAME} \ - optimizer.lr=${LR} \ - global_train_batch_size=${BS} \ - device_train_microbatch_size=${PER_DEVICE_BS} \ - device_eval_batch_size=${PER_DEVICE_BS} \ - scheduler.t_warmup=${WARMUP}ba \ - model.weight_bias_dtype=${MODEL_PRECISION} \ - rosa.spa_d=${SPA_DENSITY} \ - rosa.spa_num_grads=${SPA_NUM_GRADS} \ - rosa.grad_acc_mode=${SPA_GRAD_ACC_MODE} \ - rosa.lora_r=${LORA_R} \ - rosa.lora_alpha=${LORA_ALPHA} \ - rosa.lora_lr=${LORA_LR} \ - rosa.schedule=${SCHEDULE} \ - rosa.rosa_dtype=${ROSA_DTYPE} \ - global_seed=${SEED} \ - seed=${SEED} \ - hf_save_path=${BASE_SAVE_PATH}/models/ \ - rosa.mask_load_path=${MASK_LOAD_PATH} 2>&1 | tee "$TEMP_FILE" - -# Extract the wandb run ID from the temp file -WANDB_RUN_ID=$(grep -o 'https://wandb.ai/[^ ]*/runs/[^ ]*' "$TEMP_FILE" | awk -F'/' '{print $NF}' | tail -n 1) - -# Clean up -rm "$TEMP_FILE" -rm -rf "$MASK_LOAD_PATH" - -echo "find the adapter at ${BASE_SAVE_PATH}/models/${RUN_NAME}" - -USE_4BIT_QUANT=$([ "${MODEL_PRECISION}" = "4bit" ] && echo "--load-in-4bit" || echo "") - -if [ -z "$WANDB_RUN_ID" ]; then - echo "No wandb run ID found." 
-else - echo "Extracted wandb run ID: $WANDB_RUN_ID" -fi - -# Running BLEU evaluation -EVAL_SCRIPT=${PANZA_WORKSPACE}/src/panza/evaluation/evaluation.py -python ${EVAL_SCRIPT} \ - --model=${BASE_SAVE_PATH}/models/${RUN_NAME} \ - --system-preamble=${PANZA_SYSTEM_PREAMBLE_PATH} \ - --user-preamble=${PANZA_USER_PREAMBLE_PATH} \ - --rag-preamble=${PANZA_RAG_PREAMBLE_PATH} \ - --thread-preamble=${PANZA_THREAD_PREAMBLE_PATH} \ - --golden=${PANZA_DATA_DIR}/test.jsonl \ - --batch-size=${PANZA_EVALUATION_BATCH_SIZE} \ - --wandb-run-id=${WANDB_RUN_ID} \ - ${USE_4BIT_QUANT} - -# Running BLEU evaluation with thread -EVAL_SCRIPT=${PANZA_WORKSPACE}/src/panza/evaluation/evaluation.py -python ${EVAL_SCRIPT} \ - --model=${BASE_SAVE_PATH}/models/${RUN_NAME} \ - --system-preamble=${PANZA_SYSTEM_PREAMBLE_PATH} \ - --user-preamble=${PANZA_USER_PREAMBLE_PATH} \ - --rag-preamble=${PANZA_RAG_PREAMBLE_PATH} \ - --thread-preamble=${PANZA_THREAD_PREAMBLE_PATH} \ - --golden=${PANZA_DATA_DIR}/test.jsonl \ - --batch-size=${PANZA_EVALUATION_BATCH_SIZE} \ - --wandb-run-id=${WANDB_RUN_ID} \ - --use-thread \ - ${USE_4BIT_QUANT} - -# Running BLEU evaluation with RAG -python ${EVAL_SCRIPT} \ - --model=${BASE_SAVE_PATH}/models/${RUN_NAME} \ - --system-preamble=${PANZA_SYSTEM_PREAMBLE_PATH} \ - --user-preamble=${PANZA_USER_PREAMBLE_PATH} \ - --rag-preamble=${PANZA_RAG_PREAMBLE_PATH} \ - --thread-preamble=${PANZA_THREAD_PREAMBLE_PATH} \ - --golden=${PANZA_DATA_DIR}/test.jsonl \ - --batch-size=${PANZA_EVALUATION_BATCH_SIZE} \ - --wandb-run-id=${WANDB_RUN_ID} \ - --embedding-model=${PANZA_EMBEDDING_MODEL} \ - --db-path=${PANZA_DATA_DIR} \ - --index-name=${PANZA_USERNAME} \ - --use-rag \ - ${USE_4BIT_QUANT} +# First create the masks for RoSA finetuning. +composer ../src/panza/finetuning/train.py \ + finetuning=rosa finetuning.rosa.masks_only=true ${vars[@]} -echo "find the adapter at ${BASE_SAVE_PATH}/models/${RUN_NAME}" +# Then train the weights. +composer ../src/panza/finetuning/train.py \ + finetuning=rosa finetuning.rosa.masks_only=false ${vars[@]} diff --git a/src/panza/__init__.py b/src/panza/__init__.py new file mode 100644 index 0000000..bdee682 --- /dev/null +++ b/src/panza/__init__.py @@ -0,0 +1,55 @@ +from omegaconf import OmegaConf + +from .prompting.utils import load_preamble, load_user_preamble + +OmegaConf.register_new_resolver("load_preamble", load_preamble) +OmegaConf.register_new_resolver("load_user_preamble", load_user_preamble) + +from .writer import PanzaWriter + +__all__ = ["PanzaWriter"] + +PANZA_ASCII_LOGO = """ . . . . . . + ... . . . . . . . + . . . . . =%[ :+. . + . . . .~@% +@( + . .. ~<: . . . >@^.@) . + . . :}: . . *@@@@@{^=-. + . ={@@~ . . . .)@@@@@@@@@#= . + *%@@@@^ . . ~^^ >@@@@@@@@@%(=. + (@%@@@@@@@@[ =#@@@@@@@@@@@#= -}@@@@@@]..*= + ^@@@@@@@@@@%.. :<#@@@@@@@@@@@{= ..:}@@@@@-) .. + ~@@@@@@@@@@@-:[@@@@@@@@@@@@@@%+ . =#@@@@@+* .. . + :@@@@@@#^:.^*>#@%#{@@@@@@@@@@@# ..+-.^@@@@@@<-- + .}@@@[+ -( . .<@@@@@@@@@@^ ^%[=)@@@@@@~<( + . .. +@*. *}@@@@@@@@@@@@()>+)@@+=%@@@@@+}{= . + {@@#%@@@@@@@@@@@@@@@%{@@@%+]@@@@@@+ + . =@@@@@@@@@@@@@@@@@@@@@{#{*.(@@@@@@* + . -@][[#@@@@@@@@@@@@@@@@@).=#@@@@@@* + +)}[{}* #..-@@@@@@@@@@@@@@@@@@%>(@@@@@@@^ + . ~[@@@@@@^:=#{#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@<. . + ~#@@@@@@#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@% + .:)@@@@(:~%@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@* + ~{@#<. <@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@#.. .. + == (@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@> . . + . .>@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@] . + . :%@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@(. + . ^@@@@@@@@@@@@@@@@@@@@@@@@@@@@@~ + ... 
~@@@@@@@@@@@@@@@@@@@@@@@@@[~ . . + . . . >@@@@@@@}=:=+^>^+@@@@@@@[: . . +. ..>@@@@@}- +@@@@@]: . .. +. :#<<( -{](> . + . . ..(]>}+~~~=======+}(((~-::... . .. + .:~=+^>)][[}{#%%@@@@@@@@@@@@@@@@@%%#{}}[])<>*+=-:. . + . . . . . .. . . . """ + +PANZA_ASCII_TEXT = """.______ ___ .__ __. ________ ___ +| _ \ / \ | \ | | | / / \ +| |_) | / ^ \ | \| | `---/ / / ^ \ +| ___/ / /_\ \ | . ` | / / / /_\ \ +| | / _____ \ | |\ | / /----./ _____ \ +| _| /__/ \__\ |__| \__| /________/__/ \__\ + """ + +print(PANZA_ASCII_LOGO) +print(PANZA_ASCII_TEXT) diff --git a/src/panza/data_preparation/create_vector_store.py b/src/panza/data_preparation/create_vector_store.py deleted file mode 100644 index 86c4c7f..0000000 --- a/src/panza/data_preparation/create_vector_store.py +++ /dev/null @@ -1,77 +0,0 @@ -import argparse -import json -import time -from typing import List - -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_core.documents import Document - -from panza.utils import rag -from panza.utils.documents import Email - - - -def load_emails(path: str) -> List[Email]: - with open(path, "r") as f: - lines = f.readlines() - - emails = [Email.deserialize(line) for line in lines] - - return emails - - -def process_emails(emails: List[Email], chunk_size: int, chunk_overlap: int) -> List[Document]: - # Convert e-mails to langchain documents - documents = [ - Document(page_content=email.email, metadata={"serialized_email": email.serialize()}) - for email in emails - ] - - # Split long e-mails into text chuncks - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=chunk_size, chunk_overlap=chunk_overlap - ) - documents = text_splitter.split_documents(documents) - - return documents - - -def main(): - parser = argparse.ArgumentParser() - - parser = argparse.ArgumentParser(description="Store emails in a embeddings vector DB.") - parser.add_argument("--path-to-emails", help="Path to the cleaned emails") - parser.add_argument("--chunk-size", type=int, default=3000) - parser.add_argument("--chunk-overlap", type=int, default=3000) - parser.add_argument("--db-path", type=str) - parser.add_argument("--index-name", type=str) - parser.add_argument( - "--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2" - ) - - args = parser.parse_args() - - # Load emails - emails = load_emails(args.path_to_emails) - print(f"Loaded {len(emails)} emails.") - - # Process emails - documents = process_emails(emails, args.chunk_size, args.chunk_overlap) - print(f"Obtained {len(documents)} text chuncks.") - - # Initialize embeddings model - embeddings_model = rag.get_embeddings_model(args.embedding_model) - - # Create vector DB - print("Creating vector DB...") - start = time.time() - db = rag.create_vector_db(documents, embeddings_model) - print(f"Vector DB created in {time.time() - start} seconds.") - - # Save vector DB to disk - db.save_local(folder_path=args.db_path, index_name=args.index_name) - print(f"Vector DB index {args.index_name} saved to {args.db_path}.") - - -if __name__ == "__main__": - main() diff --git a/src/panza/data_preparation/extract_emails.py b/src/panza/data_preparation/extract_emails.py index 12f1d79..e9ead5e 100644 --- a/src/panza/data_preparation/extract_emails.py +++ b/src/panza/data_preparation/extract_emails.py @@ -1,10 +1,11 @@ -import argparse import json import mailbox import re from email.utils import parsedate_to_datetime +from email.message import Message +from mailbox import mboxMessage from os import makedirs -from os.path import join +from os.path 
import join, dirname import langdetect @@ -19,6 +20,8 @@ SHORT_EMAIL_THRESHOLD = 10 # words +FORWARDED_MESSAGE_TAG = "---------- Forwarded message ---------" + def extract_only_plain_text(msg_part): if msg_part.get_content_type() == "text/plain": @@ -28,7 +31,7 @@ def extract_only_plain_text(msg_part): def skip_forwarded_messages(plain_text): - if "---------- Forwarded message ---------" in plain_text: + if FORWARDED_MESSAGE_TAG in plain_text: DISCARDED_EMAILS["forwarded"].append(plain_text) return "" else: @@ -42,7 +45,7 @@ def remove_date_time(email_body): match = pattern.search(email_body) if match: - return (email_body[:match.start()] + email_body[match.end():]).strip() + return (email_body[: match.start()] + email_body[match.end() :]).strip() else: return email_body @@ -61,17 +64,17 @@ def count_words(s): def extract_by_quote_level(text): # Split the text into lines - lines = text.split('\n') + lines = text.split("\n") # Dictionary to store lines by quote level grouped_lines = {} for line in lines: # Count the number of '>' at the start of the line - quote_level = len(re.match(r'^>*', line).group()) + quote_level = len(re.match(r"^>*", line).group()) # Remove leading '>' and spaces - clean_line = re.sub(r'^>*\s*', '', line) + clean_line = re.sub(r"^>*\s*", "", line) # Add the clean line to the appropriate group if quote_level not in grouped_lines: @@ -99,7 +102,7 @@ def filter_message(msg): email_with_thread = [remove_date_time(an_email) for an_email in email_with_thread] main_email = email_with_thread.pop(0) - email_with_thread.reverse() # chronological order + email_with_thread.reverse() # chronological order # check length before detecting language if count_words(main_email) < SHORT_EMAIL_THRESHOLD: @@ -121,20 +124,10 @@ def filter_message(msg): return (main_email.strip(), [an_email.strip() for an_email in email_with_thread]) -def main(): - parser = argparse.ArgumentParser(description="Process an MBOX file for PANZA project.") - parser.add_argument("--mbox-path", help="Path to the MBOX file.") - parser.add_argument("--output-path", help="Path to the directory to save the output files.") - parser.add_argument( - "--email", - action="append", - help="Email address(es) to filter the messages. Use the argument multiple times for multiple emails.", - ) - parser.add_argument("--save-discarded-emails", action="store_true") - args = parser.parse_args() +def extract_emails(mailbox_path, output_path, email_addresses, save_discarded_emails_path): - MBOX_PATH = args.mbox_path - EMAIL = args.email + MBOX_PATH = mailbox_path + EMAIL = email_addresses mbox = mailbox.mbox(MBOX_PATH) n_emails = len(mbox) @@ -142,20 +135,38 @@ def main(): print(f"--> processing {i}/{n_emails} <--") # Filter messages sent from your email address if message["from"] and any(email in message["from"] for email in EMAIL): - date = parsedate_to_datetime(message["Date"]).isoformat() + if message["Date"]: + date = parsedate_to_datetime(message["Date"]).isoformat() + else: + print("Date was not found in the email. 
Skipping.") + continue if message.is_multipart(): for part in message.walk(): filtered_msg = filter_message(part) if filtered_msg is not None: print(filtered_msg) main_email, thread = filtered_msg - CLEAN_EMAILS.append({"email": main_email, "thread": thread, "subject": message["Subject"], "date": date}) + CLEAN_EMAILS.append( + { + "email": main_email, + "thread": thread, + "subject": message["Subject"], + "date": date, + } + ) else: filtered_msg = filter_message(message) if filtered_msg is not None: print(filtered_msg) main_email, thread = filtered_msg - CLEAN_EMAILS.append({"email": main_email, "thread": thread, "subject": message["Subject"], "date": date}) + CLEAN_EMAILS.append( + { + "email": main_email, + "thread": thread, + "subject": message["Subject"], + "date": date, + } + ) print(f"\n---> [Cleaning stats] <---") print(f"# clean emails = {len(CLEAN_EMAILS)}") @@ -171,26 +182,27 @@ def main(): first_email = EMAIL[0] username = first_email[: first_email.find("@")] - makedirs(args.output_path, exist_ok=True) + makedirs(dirname(output_path), exist_ok=True) # Save clean emails - with open(join(args.output_path, username + "_clean.jsonl"), "w", encoding="utf-8") as f: + with open(join(output_path), "w", encoding="utf-8") as f: for item in CLEAN_EMAILS: json_record = json.dumps(item) f.write(json_record + "\n") # Save discarded emails - if args.save_discarded_emails: - makedirs(join(args.output_path, "discarded"), exist_ok=True) + if save_discarded_emails_path and save_discarded_emails_path != "": + print(f"\n---> Processing Discarded Emails <---") + makedirs(save_discarded_emails_path, exist_ok=True) for k, v in DISCARDED_EMAILS.items(): - output_path = join( - args.output_path, "discarded", username + "_discarded_" + k + ".jsonl" - ) + print(f"--> processing {k} emails <--") + output_path = join(save_discarded_emails_path, f"{username}_discarded_{k}.jsonl") with open(output_path, "w", encoding="utf-8") as f: - for item in v: + discarded_emails = len(v) + for i, item in enumerate(v): + print("\n\n\n\n\===========================") + if type(item) is Message or type(item) is mboxMessage: + item = item.get_payload() + print(f"--> processing {i}/{discarded_emails} <--") json_record = json.dumps(item) f.write(json_record + "\n") - - -if __name__ == "__main__": - main() diff --git a/src/panza/data_preparation/rag.py b/src/panza/data_preparation/rag.py new file mode 100644 index 0000000..ee03367 --- /dev/null +++ b/src/panza/data_preparation/rag.py @@ -0,0 +1,120 @@ +import copy +import json +import time +from abc import ABC +from dataclasses import asdict, dataclass +from datetime import datetime +from typing import Dict, List, Optional, Union + +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.vectorstores import FAISS +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VectorStore +from langchain.text_splitter import RecursiveCharacterTextSplitter + + +@dataclass(kw_only=True) +class Email(ABC): + email: str + subject: str + thread: List[str] + summary: Optional[str] = None + date: datetime + + def serialize(self) -> dict: + dictionary = asdict(self) + dictionary["date"] = self.date.isoformat() + return dictionary + + @classmethod + def deserialize(cls, data: Union[str, Dict]) -> "Email": + if isinstance(data, str): + dictionary = json.loads(data) + elif isinstance(data, dict): + dictionary = copy.deepcopy(data) + else: + raise ValueError(f"Cannot deserialize data 
of type {type(data)}. Must be str or dict.") + dictionary["date"] = datetime.fromisoformat(dictionary["date"]) + return cls(**dictionary) + + +def get_embeddings_model(model_name) -> Embeddings: + embeddings_model = HuggingFaceEmbeddings( + model_name=model_name, + model_kwargs={"device": "cpu"}, + encode_kwargs={"normalize_embeddings": False}, + ) + return embeddings_model + + +def create_vector_db(docs: List[Document], embeddings_model: Embeddings) -> VectorStore: + db = FAISS.from_documents(docs, embeddings_model) + return db + + +def load_vector_db_from_disk( + folder_path: str, index_name: str, embeddings_model: Embeddings +) -> VectorStore: + try: + db = FAISS.load_local( + folder_path=folder_path, + embeddings=embeddings_model, + index_name=index_name, + allow_dangerous_deserialization=True, # Allows pickle deserialization + ) + print("Faiss index loaded ") + return db + except Exception as e: + print("FAISS index loading failed \n", e) + + +def load_emails(path: str) -> List[Email]: + with open(path, "r") as f: + lines = f.readlines() + + emails = [Email.deserialize(line) for line in lines] + + return emails + + +def process_emails(emails: List[Email], chunk_size: int, chunk_overlap: int) -> List[Document]: + # Convert e-mails to langchain documents + documents = [ + Document(page_content=email.email, metadata={"serialized_email": email.serialize()}) + for email in emails + ] + + # Split long e-mails into text chunks + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, chunk_overlap=chunk_overlap + ) + documents = text_splitter.split_documents(documents) + + return documents + + +def create_vector_store( + path_to_emails, chunk_size, chunk_overlap, db_path, index_name, embedding_model +): + """Create FAISS vector database for search and retrieval.""" + # Load emails + emails = load_emails(path_to_emails) + print(f"Loaded {len(emails)} emails.") + + # Process emails + documents = process_emails(emails, chunk_size, chunk_overlap) + print(f"Obtained {len(documents)} text chunks.") + + # Initialize embeddings model + embeddings_model = get_embeddings_model(embedding_model) + + # Create vector DB + print("Creating vector DB...") + start = time.time() + db = create_vector_db(documents, embeddings_model) + print(f"Vector DB created in {time.time() - start} seconds.") + + # Save vector DB to disk + db.save_local(folder_path=db_path, index_name=index_name) + print(f"Vector DB index {index_name} saved to {db_path}.") diff --git a/src/panza/entities/__init__.py b/src/panza/entities/__init__.py new file mode 100644 index 0000000..99266aa --- /dev/null +++ b/src/panza/entities/__init__.py @@ -0,0 +1,4 @@ +from .document import Document, Email +from .instruction import EmailInstruction, Instruction, SummarizationInstruction + +__all__ = ["Document", "Email", "EmailInstruction", "Instruction", "SummarizationInstruction"] diff --git a/src/panza/entities/document.py b/src/panza/entities/document.py new file mode 100644 index 0000000..6bce09b --- /dev/null +++ b/src/panza/entities/document.py @@ -0,0 +1,75 @@ +import copy +import json +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass, field +from datetime import datetime +from typing import Dict, List, Optional, Union + +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_core.documents import Document as LangchainDocument + + +@dataclass +class Document(ABC): + summary: Optional[str] = None + + @abstractmethod + def serialize(self) -> dict: + """Convert the document 
to a dictionary that can be serialized to JSON.""" + pass + + @classmethod + @abstractmethod + def deserialize(cls, data: Union[str, Dict]) -> "Document": + """Convert a serialized document into a Document object.""" + pass + + @staticmethod + @abstractmethod + def process( + documents: List["Document"], chunk_size: int, chunk_overlap: int + ) -> List[LangchainDocument]: + """Prepare documents for storage.""" + pass + + +@dataclass(kw_only=True) +class Email(Document): + email: str + subject: str + thread: List[str] = field(default_factory=list) + date: datetime + + def serialize(self) -> dict: + dictionary = asdict(self) + dictionary["date"] = self.date.isoformat() + return dictionary + + @classmethod + def deserialize(cls, data: Union[str, Dict]) -> "Email": + if isinstance(data, str): + dictionary = json.loads(data) + elif isinstance(data, dict): + dictionary = copy.deepcopy(data) + else: + raise ValueError(f"Cannot deserialize data of type {type(data)}. Must be str or dict.") + dictionary["date"] = datetime.fromisoformat(dictionary["date"]) + return cls(**dictionary) + + @staticmethod + def process(documents: List["Email"], chunk_size, chunk_overlap) -> List[Document]: + # Convert e-mails to langchain documents + documents = [ + LangchainDocument( + page_content=email.email, metadata={"serialized_document": email.serialize()} + ) + for email in documents + ] + + # Split long e-mails into text chuncks + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, chunk_overlap=chunk_overlap + ) + documents = text_splitter.split_documents(documents) + + return documents diff --git a/src/panza/entities/instruction.py b/src/panza/entities/instruction.py new file mode 100644 index 0000000..f622329 --- /dev/null +++ b/src/panza/entities/instruction.py @@ -0,0 +1,21 @@ +from abc import ABC +from dataclasses import dataclass, field +from typing import List + +from ..llm import ChatHistoryType + + +@dataclass +class Instruction(ABC): + instruction: str + past_messages: ChatHistoryType = field(default_factory=list) + + +@dataclass(kw_only=True) +class EmailInstruction(Instruction): + thread: List[str] = field(default_factory=list) + + +@dataclass(kw_only=True) +class SummarizationInstruction(Instruction): + pass diff --git a/src/panza/evaluation/base_inference.py b/src/panza/evaluation/base_inference.py deleted file mode 100644 index adc79bb..0000000 --- a/src/panza/evaluation/base_inference.py +++ /dev/null @@ -1,174 +0,0 @@ -import argparse -import os -import sys - -import torch -from peft import AutoPeftModelForCausalLM -from transformers import (AutoModelForCausalLM, AutoTokenizer, - BitsAndBytesConfig) - -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - -from panza.utils import prompting -from panza.utils.documents import Email - -sys.path.pop(0) - - -def get_base_inference_args_parser(): - parser = argparse.ArgumentParser() - - parser.add_argument("--model", default=None) - parser.add_argument("--system-preamble", type=str, default=None) - parser.add_argument("--user-preamble", type=str, default=None) - parser.add_argument("--rag-preamble", type=str, default=None) - parser.add_argument("--thread-preamble", type=str, default=None) - parser.add_argument("--best", action="store_true", default=False) - parser.add_argument("--temperature", type=float, default=0.7) - parser.add_argument("--top-k", type=int, default=50) - parser.add_argument("--top-p", type=float, default=0.7) - parser.add_argument("--max-new-tokens", type=int, default=1024) - 
parser.add_argument("--use-rag", action="store_true", default=False) - parser.add_argument("--rag-relevance-threshold", type=float, default=0.2) - parser.add_argument( - "--embedding-model", type=str, default="sentence-transformers/all-mpnet-base-v2" - ) - parser.add_argument("--db-path", type=str, default=None) - parser.add_argument("--index-name", type=str, default=None) - parser.add_argument("--rag-num-emails", type=int, default=7) - parser.add_argument("--device", type=str, default="cuda:0") - parser.add_argument("--dtype", type=str, default="bf16") - parser.add_argument("--nthreads", type=int, default=None) - parser.add_argument("--load-in-4bit", default=False, action="store_true") - - return parser - - -def load_model_and_tokenizer(model_path, device, dtype, load_in_4bit): - assert dtype in [None, "fp32", "bf16"] - if device == "cpu": - assert dtype == "fp32", "CPU only supports fp32, please specify --dtype fp32" - dtype = None if dtype is None else (torch.float32 if dtype == "fp32" else torch.bfloat16) - - quant_config = ( - BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=dtype, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - ) - if load_in_4bit - else None - ) - - if os.path.exists(os.path.join(model_path, "adapter_config.json")): - print("found an adapter.") - if load_in_4bit: - model = AutoPeftModelForCausalLM.from_pretrained( - model_path, device_map=device, quantization_config=quant_config, trust_remote_code=True - ) - else: - model = AutoPeftModelForCausalLM.from_pretrained( - model_path, torch_dtype=dtype, device_map=device, trust_remote_code=True - ) - model = model.merge_and_unload() - else: - if load_in_4bit: - model = AutoModelForCausalLM.from_pretrained( - model_path, device_map=device, quantization_config=quant_config, trust_remote_code=True - ) - else: - model = AutoModelForCausalLM.from_pretrained( - model_path, torch_dtype=dtype, device_map=device, trust_remote_code=True - ) - - tokenizer = AutoTokenizer.from_pretrained( - model_path, model_max_length=model.config.max_position_embeddings - ) - tokenizer.padding_side = "left" - tokenizer.pad_token = tokenizer.eos_token - - return model, tokenizer - - -def run_inference( - instructions, - model, - tokenizer, - system_preamble, - user_preamble, - rag_preamble, - rag_relevance_threshold, - rag_num_emails, - thread_preamble, - use_rag, - db, - max_new_tokens, - best, - temperature, - top_k, - top_p, - device, -): - batch = [] - prompts = [] - for instruction, thread in instructions: - relevant_emails = [] - if use_rag: - assert db is not None, "RAG requires a database to be provided." 
- re = db._similarity_search_with_relevance_scores( - instruction, k=rag_num_emails - ) - relevant_emails = [ - Email.deserialize(r[0].metadata["serialized_email"]) - for r in re - if r[1] >= rag_relevance_threshold - ] - - prompt = prompting.create_prompt( - instruction, system_preamble, user_preamble, rag_preamble, relevant_emails, thread_preamble, thread, - ) - prompts.append(prompt) - messages = [{"role": "user", "content": prompt}] - batch.append(messages) - - encodeds = tokenizer.apply_chat_template( - batch, - return_tensors="pt", - add_generation_prompt=True, - padding=True, - truncation=True, - return_dict=True, - ) - model_inputs = encodeds.to(device) - - if best: - generated_ids = model.generate( - **model_inputs, - max_new_tokens=max_new_tokens, - do_sample=False, - num_beams=1, - pad_token_id=tokenizer.pad_token_id, - ) - else: - generated_ids = model.generate( - **model_inputs, - max_new_tokens=max_new_tokens, - do_sample=True, - temperature=temperature, - top_k=top_k, - top_p=top_p, - pad_token_id=tokenizer.pad_token_id, - ) - - outputs = tokenizer.batch_decode(generated_ids) - - # Clean outputs - _, prompt_end_wrapper, _, response_end_wrapper = prompting.get_model_special_tokens( - model.name_or_path - ) - outputs = [ - output.split(prompt_end_wrapper)[-1].split(response_end_wrapper)[0] for output in outputs - ] - - return prompts, outputs diff --git a/src/panza/evaluation/console_interactive_inference.py b/src/panza/evaluation/console_interactive_inference.py deleted file mode 100644 index 92bbd38..0000000 --- a/src/panza/evaluation/console_interactive_inference.py +++ /dev/null @@ -1,66 +0,0 @@ -import os -import sys - -import torch - -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - -from panza.evaluation import base_inference -from panza.utils import prompting, rag - -sys.path.pop(0) - - -def main(): - parser = base_inference.get_base_inference_args_parser() - args = parser.parse_args() - - print("Running inference with args:", args) - - if args.nthreads is not None: - torch.set_num_threads(args.nthreads) - - print("Loading model ", args.model) - model, tokenizer = base_inference.load_model_and_tokenizer(args.model, args.device, args.dtype, load_in_4bit=args.load_in_4bit) - - if args.use_rag: - embeddings_model = rag.get_embeddings_model(args.embedding_model) - db = rag.load_vector_db_from_disk(args.db_path, args.index_name, embeddings_model) - - system_preamble, user_preamble, rag_preamble, _ = prompting.load_all_preambles( - args.system_preamble, args.user_preamble, args.rag_preamble, args.thread_preamble - ) - - while True: - user_input = input("Enter another request (or 'quit' to exit): ") - - if user_input.lower() == "quit": - print("Exiting...") - break - - prompts, outputs = base_inference.run_inference( - instructions=[(user_input, None)], - model=model, - tokenizer=tokenizer, - system_preamble=system_preamble, - user_preamble=user_preamble, - rag_preamble=rag_preamble, - rag_relevance_threshold=args.rag_relevance_threshold, - rag_num_emails=args.rag_num_emails, - thread_preamble=None, - use_rag=args.use_rag, - db=db if args.use_rag else None, - max_new_tokens=args.max_new_tokens, - best=args.best, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, - device=args.device, - ) - - print("Processed input:", prompts[0]) - print("Generated email", outputs[0]) - - -if __name__ == "__main__": - main() diff --git a/src/panza/evaluation/evaluation.py b/src/panza/evaluation/evaluation.py deleted file mode 100644 index 
5882f65..0000000 --- a/src/panza/evaluation/evaluation.py +++ /dev/null @@ -1,194 +0,0 @@ -# We conduct evaluations with three scores. -# The BLEU score is frequently used to evaluate translations and compares n-grams in a 'golden' -# translation to those in a candidate translation. Multiple golden translations are possible. -# The ROUGE score is frequently used for translation and summarization; it also looks at -# n-gram similarity. It is actually several scores, since precision, recall, and F1 score are -# reported separately. -# The MAUVE score measures distribution similarity (in the sense of KL-divergence) between the -# targets and outputs, and is not computed on a per-example basis. The similarity is computed -# in the latent space of an LLM, by default GPT-2. - -import json -import os -import re -import string -import sys - -import numpy as np -import torch -import wandb -from evaluate import load -from torchmetrics.text.bleu import BLEUScore -from torchmetrics.text.rouge import ROUGEScore - -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - -from panza.evaluation import base_inference -from panza.utils import prompting, rag - -sys.path.pop(0) - - -def main(): - parser = base_inference.get_base_inference_args_parser() - parser.add_argument("--responses-per-prompt", type=int, default=1) - parser.add_argument("--golden", type=str, default=None) - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--use-thread", action="store_true", default=False) - parser.add_argument("--wandb-run-id", type=str, default=None) - args = parser.parse_args() - - rouge = ROUGEScore() - # This library computes the BLEU score components separately. We do not use a length penalty. - bleu1 = BLEUScore(n_gram=1) - bleu2 = BLEUScore(n_gram=2) - bleu3 = BLEUScore(n_gram=3) - bleu4 = BLEUScore(n_gram=4) - mauve = load('mauve') - - if args.nthreads is not None: - torch.set_num_threads(args.nthreads) - - print("Loading model ", args.model) - model, tokenizer = base_inference.load_model_and_tokenizer(args.model, args.device, args.dtype, load_in_4bit=args.load_in_4bit) - - if args.use_rag: - embeddings_model = rag.get_embeddings_model(args.embedding_model) - db = rag.load_vector_db_from_disk(args.db_path, args.index_name, embeddings_model) - - system_preamble, user_preamble, rag_preamble, thread_preamble = prompting.load_all_preambles( - args.system_preamble, args.user_preamble, args.rag_preamble, args.thread_preamble - ) - - with open(args.golden, "r") as f: - golden_lines = [json.loads(l) for l in f.readlines()] - - grouped_golden = {} - for entry in golden_lines: - if entry["summary"] in grouped_golden: - grouped_golden[entry["summary"]]["templates"].append(entry["email"]) - else: - grouped_golden[entry["summary"]] = {} - grouped_golden[entry["summary"]]["templates"] = [(entry["email"])] - grouped_golden[entry["summary"]]["thread"] = entry["thread"] - - print("Evaluating with batch size", args.batch_size) - - results = {} - all_results = [] - prompt_scores = {} - outputs_logs = {} - grouped_golden = list(grouped_golden.items()) - for i in range(0, len(grouped_golden), args.batch_size): - batch = grouped_golden[i:i + args.batch_size] - prompts = [item[0] for item in batch] - if args.use_thread: - threads = [item[1]["thread"] for item in batch] - golden_responses = [item[1]["templates"] for item in batch] - - #prompt_scores = [[] for _ in range(len(prompts))] - for _ in range(args.responses_per_prompt): - if args.use_thread: - instructions = 
list(zip(prompts, threads)) - else: - instructions = list(zip(prompts, [None]*len(prompts))) - - full_prompts, outputs = base_inference.run_inference( - instructions=instructions, - model=model, - tokenizer=tokenizer, - system_preamble=system_preamble, - user_preamble=user_preamble, - rag_preamble=rag_preamble, - rag_relevance_threshold=args.rag_relevance_threshold, - rag_num_emails=args.rag_num_emails, - thread_preamble=thread_preamble, - use_rag=args.use_rag, - db=db if args.use_rag else None, - max_new_tokens=args.max_new_tokens, - best=args.best, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, - device=args.device, - ) - - # Remove some boilerplate added by instruction-tuned models w/out finetuning. - outputs = [o.replace("Here is the email:\n", "") for o in outputs] - outputs = [re.sub(r'SUBJECT:.*\n', "", o) for o in outputs] - outputs = [re.sub(r'Subject:.*\n', "", o) for o in outputs] - outputs = [re.sub(r'E-MAIL CONTENT:.*\n', "", o) for o in outputs] - for j, prompt in enumerate(prompts): - # We clean up the strings for the BLEU and ROUGE scores. - punc_table = str.maketrans({key: None for key in string.punctuation}) - golden = [" ".join(x.translate(punc_table).lower().split()) for x in golden_responses[j]] - candidate = " ".join(outputs[j].translate(punc_table).lower().split()) - - rouge_score = rouge(outputs[j], golden_responses[j]) - bleu_score = np.mean([bleu([candidate], [golden]) for bleu in [bleu1, bleu2, bleu3, bleu4]]) - rouge_score = rouge(candidate, golden) - if prompt not in prompt_scores.keys(): - prompt_scores[prompt] = {"prompt": prompt, "full_prompt": full_prompts[j], - "golden" : golden_responses[j], "output": [outputs[j]], - "BLEU": [bleu_score.item()]} - for score, value in rouge_score.items(): - prompt_scores[prompt][score] = [value.item()] - else: - prompt_scores[prompt]["output"].append(outputs[j]) - prompt_scores[prompt]["BLEU"].append(bleu_score.item()) - for score, value in rouge_score.items(): - prompt_scores[prompt][score].append(value.item()) - - print("\n-----------\n", "PROMPT:\n", prompt, "\n\nOUTPUT:\n", outputs[j], "\n\nBLEU SCORE:\n", bleu_score, "\n\nROUGE SCORE:\n", rouge_score) - - - means = {} - mins = {} - score_names = [k for k in prompt_scores.values().__iter__().__next__().keys() if 'BLEU' in k or 'rouge' in k] - - for k in score_names: - means[k] = np.mean([v for scores in prompt_scores.values() for v in scores[k] ]) - mins[k] = np.min([v for scores in prompt_scores.values() for v in scores[k] ]) - - # To compute the MAUVE score, we need equal-length flat arrays of - # outputs and goldens. If we have multiple outputs per prompt, we - # output them all, with the same golden prompt. - # TODO: not sure if it would be better to randomly sample from the - # outputs in this case. - # TODO: consider handling the case where there are also multiple golden - # queries per output. (We don't use this for anything now). 
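Context for the deleted `evaluation.py`: it scored each generated email with BLEU averaged over 1- to 4-grams and ROUGE computed on punctuation-stripped, lower-cased text, plus a corpus-level MAUVE score. Below is a minimal sketch of that per-example recipe, using `torchmetrics` as the removed script did; the example strings are purely illustrative.

```python
import string

import numpy as np
from torchmetrics.text.bleu import BLEUScore
from torchmetrics.text.rouge import ROUGEScore


def clean(text: str) -> str:
    # Strip punctuation and normalize case/whitespace, as the removed script did.
    table = str.maketrans({c: None for c in string.punctuation})
    return " ".join(text.translate(table).lower().split())


golden = ["thanks for the update, see you tomorrow"]  # one or more reference emails
candidate = "Thanks for the update -- see you tomorrow!"

golden_clean = [clean(g) for g in golden]
candidate_clean = clean(candidate)

# BLEU is reported as the mean of BLEU-1 through BLEU-4 over the cleaned strings.
bleus = [BLEUScore(n_gram=n) for n in range(1, 5)]
bleu = np.mean([b([candidate_clean], [golden_clean]).item() for b in bleus])

# ROUGE returns a dict of precision/recall/F1 tensors per ROUGE variant.
rouge = ROUGEScore()(candidate_clean, golden_clean)

print(f"BLEU (mean 1-4): {bleu:.3f}")
print({k: v.item() for k, v in rouge.items()})
```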
- flattened_golden = [] - flattened_outputs = [] - for prompt_info in prompt_scores.values(): - flattened_golden += ([prompt_info["golden"][0]])*len(prompt_info['output']) - flattened_outputs += prompt_info['output'] - mauve_score = mauve.compute(predictions=flattened_outputs, references=flattened_golden) - print("MAUVE score", mauve_score) - means["MAUVE"] = mauve_score.mauve - print("Mean scores across all prompts: ", {f" {k}: {v}" for k, v in means.items()}) - - - # Optionally, update wandb run with eval scores - if args.use_thread: - setting_str = "THREAD-" - elif args.use_rag: - setting_str = "RAG-" - else: - setting_str = "" - - if args.wandb_run_id: - with wandb.init(id=args.wandb_run_id, resume=True): - wandb.log({f"EVAL/{k}-{setting_str}mean": v for k, v in means.items()}) - wandb.log({f"EVAL/{k}-{setting_str}min": v for k, v in mins.items()}) - else: - print({f"EVAL/{k}-{setting_str}mean": v for k, v in means.items()}) - print({f"EVAL/{k}-{setting_str}min": v for k, v in mins.items()}) - - with open(os.path.join(args.model, f"{setting_str}eval_responses.txt"), 'w') as f: - json.dump(prompt_scores, f, ensure_ascii=False, indent=4) - - with open(os.path.join(args.model, f"{setting_str}eval_summary.txt"), 'w') as f: - json.dump({"means": means, "mins": mins}, f, ensure_ascii=False, indent=4) - -if __name__ == "__main__": - main() diff --git a/src/panza/evaluation/gui_inference.py b/src/panza/evaluation/gui_inference.py deleted file mode 100644 index 24d0b90..0000000 --- a/src/panza/evaluation/gui_inference.py +++ /dev/null @@ -1,89 +0,0 @@ -import argparse -import os -import sys - -import gradio as gr -import torch - -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - -from panza.evaluation import base_inference -from panza.utils import prompting, rag - -sys.path.pop(0) - - -def get_execute(model, tokenizer, system_preamble, user_preamble, rag_preamble, db, args): - - def execute(prompt): - prompts, outputs = base_inference.run_inference( - instructions=[(prompt, None)], - model=model, - tokenizer=tokenizer, - system_preamble=system_preamble, - user_preamble=user_preamble, - rag_preamble=rag_preamble, - rag_relevance_threshold=args.rag_relevance_threshold, - rag_num_emails=args.rag_num_emails, - thread_preamble=None, - use_rag=args.use_rag, - db=db if args.use_rag else None, - max_new_tokens=args.max_new_tokens, - best=args.best, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, - device=args.device, - ) - print("Prompt\n", prompts[0]) - print("Output\n", outputs[0]) - yield outputs[0] - - return execute - - -def main(): - parser = base_inference.get_base_inference_args_parser() - parser.add_argument("--host", type=str, default=None) - parser.add_argument("--port", type=int, default=8001) - args = parser.parse_args() - - print("Running inference with args:", args) - - if args.nthreads is not None: - torch.set_num_threads(args.nthreads) - - print("Loading model ", args.model) - model, tokenizer = base_inference.load_model_and_tokenizer(args.model, args.device, args.dtype, load_in_4bit=args.load_in_4bit) - - if args.use_rag: - embeddings_model = rag.get_embeddings_model(args.embedding_model) - db = rag.load_vector_db_from_disk(args.db_path, args.index_name, embeddings_model) - - system_preamble, user_preamble, rag_preamble, _ = prompting.load_all_preambles( - args.system_preamble, args.user_preamble, args.rag_preamble, args.thread_preamble - ) - - with gr.Blocks() as panza: - gr.Markdown("# Panza\n") - inputbox = 
gr.Textbox(label="Input", placeholder="Enter text and press ENTER") - outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model") - inputbox.submit( - get_execute( - model=model, - tokenizer=tokenizer, - system_preamble=system_preamble, - user_preamble=user_preamble, - rag_preamble=rag_preamble, - db=db if args.use_rag else None, - args=args, - ), - [inputbox], - [outputbox], - ) - - panza.queue().launch(server_name=args.host, server_port=args.port, share=True) - - -if __name__ == "__main__": - main() diff --git a/src/panza/finetuning/configs/fft_panza.yaml b/src/panza/finetuning/configs/fft_panza.yaml deleted file mode 100644 index 7c2e128..0000000 --- a/src/panza/finetuning/configs/fft_panza.yaml +++ /dev/null @@ -1,93 +0,0 @@ -max_seq_len: 512 -global_seed: 17 -model_name_or_path: #TODO - -load_path: # set via bash script to be absolute path to your sparse checkpoint -precision: amp_bf16 -hf_save_path: ./checkpoints - -max_duration: # TODO -eval_interval: 1 -# eval_first: false -seed: ${global_seed} - -global_train_batch_size: #TODO -device_train_microbatch_size: 16 -device_eval_batch_size: 16 - -run_name: # If left blank, will be read from env var $RUN_NAME - -model: - name: hf_causal_lm - pretrained: true - pretrained_model_name_or_path: ${model_name_or_path} - max_seq_len: ${max_seq_len} - output_hidden_states: true - weight_bias_dtype: #TODO - compute_dtype: bf16 - -tokenizer: - name: ${model_name_or_path} - kwargs: - model_max_length: ${max_seq_len} - -train_loader: - name: finetuning - dataset: - hf_name: json - split: train - hf_kwargs: - data_files: #TODO - preprocessing_fn: preprocessing:panza_preprocessing_function - max_seq_len: ${max_seq_len} - allow_pad_trimming: false - decoder_only_format: true - shuffle: true - drop_last: false - num_workers: 8 - pin_memory: false - prefetch_factor: 2 - persistent_workers: true - timeout: 0 - -scheduler: - name: linear_decay_with_warmup - t_warmup: 20ba - alpha_f: 0 - -optimizer: - name: decoupled_adamw - lr: # TODO - betas: - - 0.9 - - 0.999 - eps: 1.0e-8 - weight_decay: 0.0 - -fsdp_config: - sharding_strategy: FULL_SHARD - mixed_precision: FULL - activation_checkpointing: true - activation_checkpointing_reentrant: false - activation_cpu_offload: false - limit_all_gathers: true - verbose: false - -progress_bar: false -log_to_console: true -console_log_interval: 1ba - -callbacks: - speed_monitor: - window_size: 10 - lr_monitor: { } - memory_monitor: { } - runtime_estimator: { } - hf_checkpointer: - overwrite: true - precision: # TODO - save_folder: ${hf_save_path}/${run_name} - save_interval: 1dur - -loggers: - wandb: { } diff --git a/src/panza/finetuning/configs/mistral_7b_fft_panza.yaml b/src/panza/finetuning/configs/mistral_7b_fft_panza.yaml deleted file mode 100644 index 1874b7e..0000000 --- a/src/panza/finetuning/configs/mistral_7b_fft_panza.yaml +++ /dev/null @@ -1,107 +0,0 @@ -# This config trains lora and spa the whole time, which means it restarts the training after grad collection. 
- -max_seq_len: 512 -global_seed: 17 -model_name_or_path: #TODO - -load_path: # set via bash script to be absolute path to your sparse checkpoint -precision: amp_bf16 -hf_save_path: ./checkpoints - -max_duration: # TODO -eval_interval: 1 -# eval_first: false -seed: ${global_seed} - -global_train_batch_size: #TODO -# for mpt-7b dense: -# 4 x A100_80GB = "device_train_microbatch_size: 12" -# 8 x A6000_48GB = "device_train_microbatch_size: 6" - -# for mpt-7b sparse (with masks): -# 8 x A6000_48GB = "device_train_microbatch_size: 4" -device_train_microbatch_size: 16 -device_eval_batch_size: 16 - -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME - -model: - name: hf_causal_lm - pretrained: true - pretrained_model_name_or_path: ${model_name_or_path} - max_seq_len: ${max_seq_len} - output_hidden_states: true - weight_bias_dtype: #TODO - compute_dtype: bf16 - # config_overrides: - # attn_config: - # attn_impl: torch - # Set this to `true` if using `train_loader.dataset.packing_ratio` below - # attn_uses_sequence_id: true - -# Tokenizer -tokenizer: - name: ${model_name_or_path} - kwargs: - model_max_length: ${max_seq_len} - -# Dataloaders -train_loader: - name: finetuning - dataset: - hf_name: json - split: train - hf_kwargs: - data_files: #TODO - preprocessing_fn: preprocessing:panza_preprocessing_function - max_seq_len: ${max_seq_len} - allow_pad_trimming: false - decoder_only_format: true - shuffle: true - drop_last: false - num_workers: 8 - pin_memory: false - prefetch_factor: 2 - persistent_workers: true - timeout: 0 - -# Optimization -scheduler: - name: linear_decay_with_warmup - t_warmup: 20ba - alpha_f: 0 - -optimizer: - name: decoupled_adamw - lr: # TODO - betas: - - 0.9 - - 0.999 - eps: 1.0e-8 - weight_decay: 0.0 - -# FSDP -fsdp_config: - sharding_strategy: FULL_SHARD - mixed_precision: FULL - activation_checkpointing: true - activation_checkpointing_reentrant: false - activation_cpu_offload: false - limit_all_gathers: true - verbose: false - -# Logging -progress_bar: false -log_to_console: true -console_log_interval: 1ba - -callbacks: - speed_monitor: - window_size: 10 - lr_monitor: { } - memory_monitor: { } - runtime_estimator: { } - -loggers: - wandb: { } diff --git a/src/panza/finetuning/configs/mistral_7b_rosa_panza.yaml b/src/panza/finetuning/configs/mistral_7b_rosa_panza.yaml deleted file mode 100644 index fa93a7b..0000000 --- a/src/panza/finetuning/configs/mistral_7b_rosa_panza.yaml +++ /dev/null @@ -1,113 +0,0 @@ -# This config trains lora and spa the whole time, which means it restarts the training after grad collection. 
- -max_seq_len: 512 -global_seed: 17 -model_name_or_path: #TODO - -load_path: # set via bash script to be absolute path to your sparse checkpoint -precision: amp_bf16 -hf_save_path: ./checkpoints - -max_duration: # TODO -eval_interval: 1 -# eval_first: false -seed: ${global_seed} - -global_train_batch_size: #TODO -# for mpt-7b dense: -# 4 x A100_80GB = "device_train_microbatch_size: 12" -# 8 x A6000_48GB = "device_train_microbatch_size: 6" - -# for mpt-7b sparse (with masks): -# 8 x A6000_48GB = "device_train_microbatch_size: 4" -device_train_microbatch_size: 16 -device_eval_batch_size: 16 - -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME - -model: - name: hf_causal_lm - pretrained: true - pretrained_model_name_or_path: ${model_name_or_path} - max_seq_len: ${max_seq_len} - output_hidden_states: true - weight_bias_dtype: #TODO - compute_dtype: bf16 - # config_overrides: - # attn_config: - # attn_impl: torch - # Set this to `true` if using `train_loader.dataset.packing_ratio` below - # attn_uses_sequence_id: true - -rosa: - lora_r: #TODO - spa_d: #TODO - lora_alpha: 16 - target_modules: 'all-linear' - lora_dropout: 0.05 - impl: auto - spa_store_transpose: true - rosa_dtype: bf16 - spa_num_grads: 1 - grad_acc_mode: mean_squared - mask_load_path: #TODO - mask_save_path: #TODO - terminate_after_mask_generation: #TODO - schedule: #TODO - -# Tokenizer -tokenizer: - name: ${model_name_or_path} - kwargs: - model_max_length: ${max_seq_len} - -# Dataloaders -train_loader: - name: finetuning - dataset: - hf_name: json - split: train - hf_kwargs: - data_files: #TODO - preprocessing_fn: preprocessing:panza_preprocessing_function - max_seq_len: ${max_seq_len} - allow_pad_trimming: false - decoder_only_format: true - shuffle: true - drop_last: false - num_workers: 8 - pin_memory: false - prefetch_factor: 2 - persistent_workers: true - timeout: 0 - -# Optimization -scheduler: - name: linear_decay_with_warmup - t_warmup: 20ba - alpha_f: 0 - -optimizer: - name: decoupled_adamw - lr: # TODO - betas: - - 0.9 - - 0.999 - eps: 1.0e-8 - weight_decay: 0.0 - -# Logging -progress_bar: false -log_to_console: true -console_log_interval: 1ba - -callbacks: - speed_monitor: - window_size: 10 - lr_monitor: { } - memory_monitor: { } - runtime_estimator: { } - -loggers: - wandb: { } diff --git a/src/panza/finetuning/configs/rosa_panza_colab.yaml b/src/panza/finetuning/configs/rosa_panza_colab.yaml deleted file mode 100644 index 0c292b1..0000000 --- a/src/panza/finetuning/configs/rosa_panza_colab.yaml +++ /dev/null @@ -1,95 +0,0 @@ -max_seq_len: 512 -global_seed: 17 -model_name_or_path: #TODO - -load_path: # set via bash script to be absolute path to your sparse checkpoint -precision: fp32 -hf_save_path: ./checkpoints - -max_duration: # TODO -eval_interval: 1 -seed: ${global_seed} - -global_train_batch_size: #TODO -device_train_microbatch_size: 16 -device_eval_batch_size: 16 - -run_name: # If left blank, will be read from env var $RUN_NAME - -model: - name: hf_causal_lm - pretrained: true - pretrained_model_name_or_path: ${model_name_or_path} - max_seq_len: ${max_seq_len} - output_hidden_states: true - weight_bias_dtype: #TODO - compute_dtype: fp32 - -rosa: - lora_r: #TODO - spa_d: #TODO - lora_alpha: 16 - target_modules: 'all-linear' - lora_dropout: 0.05 - impl: auto - spa_store_transpose: true - rosa_dtype: fp32 - spa_num_grads: 1 - grad_acc_mode: mean_squared - grad_4bit_accum: true - mask_load_path: #TODO - mask_save_path: #TODO - terminate_after_mask_generation: #TODO - schedule: #TODO - 
-tokenizer: - name: ${model_name_or_path} - kwargs: - model_max_length: ${max_seq_len} - -train_loader: - name: finetuning - dataset: - hf_name: json - split: train - hf_kwargs: - data_files: #TODO - preprocessing_fn: preprocessing:panza_preprocessing_function - max_seq_len: ${max_seq_len} - allow_pad_trimming: false - decoder_only_format: true - shuffle: true - drop_last: false - num_workers: 8 - pin_memory: false - prefetch_factor: 2 - persistent_workers: true - timeout: 0 - -scheduler: - name: linear_decay_with_warmup - t_warmup: 20ba - alpha_f: 0 - -optimizer: - name: decoupled_adamw - lr: # TODO - betas: - - 0.9 - - 0.999 - eps: 1.0e-8 - weight_decay: 0.0 - -progress_bar: false -log_to_console: true -console_log_interval: 1ba - -callbacks: - speed_monitor: - window_size: 10 - lr_monitor: { } - memory_monitor: { } - runtime_estimator: { } - -loggers: - wandb: { } diff --git a/src/panza/finetuning/preprocessing.py b/src/panza/finetuning/preprocessing.py index 1e29219..26e499b 100644 --- a/src/panza/finetuning/preprocessing.py +++ b/src/panza/finetuning/preprocessing.py @@ -1,143 +1,47 @@ import os -import random -from typing import Dict, List, Tuple +from typing import Dict -from langchain_core.documents import Document +import hydra +from omegaconf import OmegaConf +from transformers import AutoConfig, AutoTokenizer -from panza.utils import prompting, rag -from panza.utils.documents import Email +from panza.entities import EmailInstruction -SYSTEM_PREAMBLE_PATH = os.environ.get("PANZA_SYSTEM_PREAMBLE_PATH") -USER_PREAMBLE_PATH = os.environ.get("PANZA_USER_PREAMBLE_PATH") +PREPROCESSING_CONFIG_FILE = os.environ.get("PANZA_PREPROCESSING_CONFIG") +if PREPROCESSING_CONFIG_FILE: + preprocessing_config = OmegaConf.load(PREPROCESSING_CONFIG_FILE) + prompt_builder = hydra.utils.instantiate(preprocessing_config.prompting) -SYSTEM_PREAMBLE = prompting.load_preamble(SYSTEM_PREAMBLE_PATH) -USER_PREAMBLE = prompting.load_user_preamble(USER_PREAMBLE_PATH) + # Load tokenizer. The trust_remote_code parameter is necessary to load Phi-3.5. 
+ config = AutoConfig.from_pretrained(preprocessing_config.model, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + preprocessing_config.model, model_max_length=config.max_position_embeddings + ) -PANZA_GENERATIVE_MODEL = os.environ.get("PANZA_GENERATIVE_MODEL") -PROMPT_START_WRAPPER, PROMPT_END_WRAPPER, RESPONSE_START_WRAPPER, RESPONSE_END_WRAPPER = ( - prompting.get_model_special_tokens(PANZA_GENERATIVE_MODEL) -) -PANZA_FINETUNE_WITH_RAG = int(os.environ.get("PANZA_FINETUNE_WITH_RAG")) == 1 -if PANZA_FINETUNE_WITH_RAG: - EMBEDDINGS_MODEL = os.environ.get("PANZA_EMBEDDING_MODEL") - DB_PATH = os.environ.get("PANZA_DATA_DIR") - INDEX_NAME = os.environ.get("PANZA_USERNAME") - EMBEDDINGS_MODEL = rag.get_embeddings_model(EMBEDDINGS_MODEL) - DB = rag.load_vector_db_from_disk(DB_PATH, INDEX_NAME, EMBEDDINGS_MODEL) - RAG_PREAMBLE_PATH = os.environ.get("PANZA_RAG_PREAMBLE_PATH") - RAG_PREAMBLE = prompting.load_preamble(RAG_PREAMBLE_PATH) - RAG_NUM_EMAILS = int(os.environ.get("PANZA_FINETUNE_RAG_NUM_EMAILS")) - RAG_PROB = float(os.environ.get("PANZA_FINETUNE_RAG_PROB")) - RAG_RELEVANCE_THRESHOLD = float(os.environ.get("PANZA_FINETUNE_RAG_RELEVANCE_THRESHOLD")) - PANZA_SEED = int(os.environ.get("PANZA_SEED")) - random.seed(PANZA_SEED) - -PANZA_FINETUNE_WITH_THREAD = int(os.environ.get("PANZA_FINETUNE_WITH_THREAD")) == 1 -if PANZA_FINETUNE_WITH_THREAD: - THREAD_PREAMBLE_PATH = os.environ.get("PANZA_THREAD_PREAMBLE_PATH") - THREAD_PREAMBLE = prompting.load_preamble(THREAD_PREAMBLE_PATH) - THREAD_NUM_EMAILS = int(os.environ.get("PANZA_FINETUNE_THREAD_NUM_EMAILS")) - -r"""Example custom preprocessing function. - -This is here to help illustrate the way to set up finetuning -on a local dataset. One step of that process is to create -a preprocessing function for your dataset, and that is what -is done below. Check out the LLM Finetuning section of -`../README.md` for more context. - -For this example, we're going to pretend that our local dataset -is `./train.jsonl`. - -Note: this dataset is actually a copy of one of our ARC-Easy -multiple-choice ICL eval datasets. And you would never actually -train on eval data! ... But this is just a demonstration. - -Every example within the dataset has the format: -{ - 'query': , - 'choices': [, , ...], - 'gold': # index of correct choice -} - -To enable finetuning, we want to turn this into a prompt/response -format. 
We'll structure prompts and responses like this: -{ - 'prompt': \nOptions:\n - \n - \nAnswer: , - 'response': -} -""" - - -def filter_relevant_emails(relevant_emails_with_score: List[Tuple[Email, float]]) -> List[Email]: - # Random chance to not include any relevant emails - p = random.random() - if p > RAG_PROB: - relevant_emails = [] - print("Skip RAG") - return relevant_emails - - if not relevant_emails: - print("Relevant emails not found.") - return [] - - print("Don't skip") - relevant_emails = [r["email"] for r in relevant_emails if r["score"] >= RAG_RELEVANCE_THRESHOLD] - relevant_emails = [Document(page_content=email, metadata={}) for email in relevant_emails] - relevant_emails = relevant_emails[:RAG_NUM_EMAILS] - print(f"Found {len(relevant_emails)} relevant emails.") - return relevant_emails - - -def panza_preprocessing_function(inp: Dict) -> Dict: +def panza_preprocessing_function(inputs: Dict) -> Dict: try: - prompt_raw = inp["summary"].split("\n\nInstruction: ")[-1] - return { - "prompt": PROMPT_START_WRAPPER + prompt_raw + PROMPT_END_WRAPPER, - "response": RESPONSE_START_WRAPPER + inp["email"] + RESPONSE_END_WRAPPER, - } - except Exception as e: - raise ValueError(f"Unable to extract prompt/response from {inp}") from e + prompt_raw = inputs["summary"].split("\n\nInstruction: ")[-1] + instruction = EmailInstruction(instruction=prompt_raw, thread=inputs.get("thread", [])) + prompt = prompt_builder.build_prompt(instruction) + # Generate the full conversation + conversation = [ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": inputs["email"]}, + ] + chat_prompt = tokenizer.apply_chat_template(conversation, tokenize=False) -def panza_preprocessing_function_train_with_preamble(inp: Dict) -> Dict: - try: - prompt_raw = inp["summary"].split("\n\nInstruction: ")[-1] - if PANZA_FINETUNE_WITH_RAG: - relevant_emails_with_score = inp.get("relevant_emails", []) - relevant_emails_with_score = [ - (Email.deserialize(email), score) for (email, score) in relevant_emails_with_score - ] - relevant_emails = filter_relevant_emails(relevant_emails_with_score) - prompt = prompting.create_prompt( - prompt_raw, SYSTEM_PREAMBLE, USER_PREAMBLE, RAG_PREAMBLE, relevant_emails - ) - print(prompt) - else: - prompt = prompting.create_prompt(prompt_raw, SYSTEM_PREAMBLE, USER_PREAMBLE) - return { - "prompt": PROMPT_START_WRAPPER + prompt + PROMPT_END_WRAPPER, - "response": RESPONSE_START_WRAPPER + inp["email"] + RESPONSE_END_WRAPPER, - } - except Exception as e: - raise ValueError(f"Unable to extract prompt/response from {inp}") from e + # Identify the index where the response begins + response_begin_index = chat_prompt.index(inputs["email"]) + # Split the full prompt into prompt and response + prompt = chat_prompt[:response_begin_index] + response = chat_prompt[response_begin_index:] -def panza_preprocessing_function_train_with_thread(inp: Dict) -> Dict: - try: - prompt_raw = inp["summary"].split("\n\nInstruction: ")[-1] - if PANZA_FINETUNE_WITH_THREAD: - thread = inp.get("thread", []) - thread = thread[:THREAD_NUM_EMAILS] - prompt = prompting.create_prompt( - prompt_raw, SYSTEM_PREAMBLE, USER_PREAMBLE, thread_preamble=THREAD_PREAMBLE, thread_emails=thread - ) - else: - prompt = prompting.create_prompt(prompt_raw, SYSTEM_PREAMBLE, USER_PREAMBLE) return { - "prompt": PROMPT_START_WRAPPER + prompt + PROMPT_END_WRAPPER, - "response": RESPONSE_START_WRAPPER + inp["email"] + RESPONSE_END_WRAPPER, + "prompt": prompt, + "response": response, } except Exception as e: - raise 
ValueError(f"Unable to extract prompt/response from {inp}") from e + raise ValueError(f"Unable to extract prompt/response from {inputs}") from e diff --git a/src/panza/finetuning/train.py b/src/panza/finetuning/train.py index 82c7c23..8c7b9e5 100644 --- a/src/panza/finetuning/train.py +++ b/src/panza/finetuning/train.py @@ -1,206 +1,325 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 + import copy import gc +import glob import logging import os -import sys +import shutil +import tempfile import time import warnings +from pathlib import Path from typing import Any, Dict, List, Optional, Union +import spops import torch -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig - -from composer.optim import DecoupledAdamW - -from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy, - InContextLearningLMAccuracy, - InContextLearningLMExpectedCalibrationError, - InContextLearningMCExpectedCalibrationError, - InContextLearningMultipleChoiceAccuracy, - InContextLearningQAAccuracy, - LanguageCrossEntropy, LanguagePerplexity) - -from llmfoundry.models.utils import init_empty_weights - -from transformers import PreTrainedTokenizerBase, AutoModelForCausalLM, BitsAndBytesConfig - -from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP - -from llmfoundry import ComposerHFCausalLM - -import os, sys -from peft.tuners.rosa import RosaModel, RosaScheduler, RosaConfig -from peft import get_peft_model - from composer import Trainer from composer.core.callback import Callback -from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler, - cyclic_schedule) +from composer.metrics.nlp import ( + InContextLearningCodeEvalAccuracy, + InContextLearningLMAccuracy, + InContextLearningLMExpectedCalibrationError, + InContextLearningMCExpectedCalibrationError, + InContextLearningMultipleChoiceAccuracy, + InContextLearningQAAccuracy, + LanguageCrossEntropy, + LanguagePerplexity, +) +from composer.optim import DecoupledAdamW +from composer.profiler import JSONTraceHandler, Profiler, TraceHandler, cyclic_schedule from composer.utils import dist, get_device, reproducibility +from datasets import disable_caching +from llmfoundry import ComposerHFCausalLM +from llmfoundry.eval.metrics.nlp import InContextLearningMetric +from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP +from llmfoundry.models.utils import init_empty_weights +from llmfoundry.utils import find_mosaicml_logger, log_train_analytics, maybe_create_mosaicml_logger from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from peft import get_peft_model +from peft.tuners.rosa import RosaConfig, RosaModel, RosaScheduler from rich.traceback import install - -from llmfoundry.eval.metrics.nlp import InContextLearningMetric -from llmfoundry.utils import (find_mosaicml_logger, log_train_analytics, - maybe_create_mosaicml_logger) +from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedTokenizerBase install() from llmfoundry.callbacks import AsyncEval from llmfoundry.data.dataloader import build_dataloader from llmfoundry.layers_registry import ffns_with_megablocks -from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, - build_algorithm, build_callback, - build_composer_model, build_evaluators, - build_logger, build_optimizer, - build_scheduler, build_tokenizer) -from llmfoundry.utils.config_utils import (log_config, pop_config, - process_init_device, - 
update_batch_size_info) +from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, + build_algorithm, + build_callback, + build_evaluators, + build_logger, + build_optimizer, + build_scheduler, + build_tokenizer, +) +from llmfoundry.utils.config_utils import ( + log_config, + pop_config, + process_init_device, + update_batch_size_info, +) from llmfoundry.utils.registry_utils import import_file +import hydra +from omegaconf import DictConfig, OmegaConf + +from panza import PanzaWriter # The import also loads custom Hydra resolvers + log = logging.getLogger(__name__) def validate_config(cfg: DictConfig): """Validates compatible model and dataloader selection.""" loaders = [cfg.train_loader] - if 'eval_loader' in cfg: + if "eval_loader" in cfg: eval_loader = cfg.eval_loader if isinstance(eval_loader, ListConfig): for loader in eval_loader: if loader.label is None: raise ValueError( - 'When specifying multiple evaluation datasets, each one must include the \ - `label` attribute.') + "When specifying multiple evaluation datasets, each one must include the \ + `label` attribute." + ) loaders.append(loader) else: loaders.append(eval_loader) for loader in loaders: - if loader.name == 'text': - if cfg.model.name == 'hf_t5': + if loader.name == "text": + if cfg.model.name == "hf_t5": raise ValueError( - f'Model type "{cfg.model.name}" is not supported when using the "text " ' +\ - f'dataloader. Only finetuning is supported.') + f'Model type "{cfg.model.name}" is not supported when using the "text " ' + + f"dataloader. Only finetuning is supported." + ) - if 'icl_tasks' in cfg: - if cfg.model.name == 'hf_t5': + if "icl_tasks" in cfg: + if cfg.model.name == "hf_t5": raise ValueError( 'ICL evaluation does not currently support Encoder-Decoder models, such as "hf_t5".' ) - if (cfg.model.get('fc_type', 'torch') != 'te' and 'te' not in cfg.model.get( - 'ffn_config', {}).get('ffn_type', 'mptmlp') and - 'fp8' in cfg.precision): + if ( + cfg.model.get("fc_type", "torch") != "te" + and "te" not in cfg.model.get("ffn_config", {}).get("ffn_type", "mptmlp") + and "fp8" in cfg.precision + ): warnings.warn( "fp8 only supported for te.Linear layers. Either set `cfg.model.fc_typ='te'` or " - + - "`cfg.model.ffn_config.ffn_type='te_ln_mlp'` to enable layers using fp8 precision." + + "`cfg.model.ffn_config.ffn_type='te_ln_mlp'` to enable layers using fp8 precision." ) - if (cfg.model.get('fc_type', 'torch') == 'te' or - 'te' in cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp')): - fsdp_config = cfg.get('fsdp_config', None) - act_ckpt = fsdp_config.get('activation_checkpointing', False) - act_ckpt_reentrant = fsdp_config.get( - 'activation_checkpointing_reentrant', False) + if cfg.model.get("fc_type", "torch") == "te" or "te" in cfg.model.get("ffn_config", {}).get( + "ffn_type", "mptmlp" + ): + fsdp_config = cfg.get("fsdp_config", None) + act_ckpt = fsdp_config.get("activation_checkpointing", False) + act_ckpt_reentrant = fsdp_config.get("activation_checkpointing_reentrant", False) if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True: warnings.warn( - '`te.Linear` layers do not support activation_checkpointing with ' - + '`activation_checkpointing_reentrant = True`. ' + - 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.' + "`te.Linear` layers do not support activation_checkpointing with " + + "`activation_checkpointing_reentrant = True`. " + + "Setting cfg.fsdp_config.activation_checkpointing_reentrant=False." 
) cfg.fsdp_config.activation_checkpointing_reentrant = False - if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp': + if cfg.model.get("ffn_config", {}).get("ffn_type", "mptmlp") == "te_ln_mlp": warnings.warn( - '`te.LayerNormMLP` requires has issues with torch._dynamo. ' + - 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.' + "`te.LayerNormMLP` requires has issues with torch._dynamo. " + + "Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager." ) torch._dynamo.config.suppress_errors = True # type: ignore (third-party) - if cfg.model.get('load_in_8bit', False): - raise ValueError( - '`load_in_8bit` is only supported for evaluation rather than training.' - ) + if cfg.model.get("load_in_8bit", False): + raise ValueError("`load_in_8bit` is only supported for evaluation rather than training.") - if cfg.model.get('ffn_config', {}).get('ffn_type', - 'mptmlp') in ffns_with_megablocks: - moe_world_size = cfg.model.get('ffn_config', - {}).get('moe_world_size', 1) - use_orig_params = cfg.get('fsdp_config', - {}).get('use_orig_params', True) + if cfg.model.get("ffn_config", {}).get("ffn_type", "mptmlp") in ffns_with_megablocks: + moe_world_size = cfg.model.get("ffn_config", {}).get("moe_world_size", 1) + use_orig_params = cfg.get("fsdp_config", {}).get("use_orig_params", True) if moe_world_size > 1 and not use_orig_params: raise ValueError( - f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.' + f"MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`." ) +def create_run_name(cfg: DictConfig) -> str: + # export RUN_NAME=panza_${PANZA_USERNAME}_${MODEL_TYPE}_${MODEL_PRECISION}-bs${BS}-fft-lr${LR}-epochs${NUM_EPOCHS}-wu${WARMUP}-seed${SEED}${PREAMBLE_STR}${RAFT_STR} + + run_name = f"panza_{cfg.user.username}" + + model_name = cfg.finetuning.model_name_or_path.split("/")[-1] + run_name += f"-{model_name}" + + run_name += f"-{cfg.model_precision}" + run_name += f"-bs{cfg.finetuning.batch_size}" + + if hasattr(cfg.finetuning, "rosa"): + run_name += "-rosa" + else: + run_name += "-fft" + + run_name += f"-lr{cfg.finetuning.lr}" + run_name += f"-{cfg.finetuning.max_duration}" + run_name += f"-seed{cfg.finetuning.seed}" + + return run_name + + +def override_rosa_schedule(cfg: DictConfig, mask_generation=False) -> None: + # Disable struct mode to allow modifications + rosa_cfg = cfg.finetuning.rosa + OmegaConf.set_struct(rosa_cfg, False) + + mask_path = str(Path(cfg.checkpoint_dir) / "masks" / cfg.finetuning.run_name) + + if mask_generation: + rosa_cfg.schedule = "wl16" if rosa_cfg.lora_r != 0 else "spa_only" + rosa_cfg.mask_load_path = None + rosa_cfg.mask_save_path = mask_path + rosa_cfg.terminate_after_mask_generation = True + rosa_cfg.mask_gen_model_precision = "amp_bf16" + else: + if rosa_cfg.spa_d > 0 and rosa_cfg.lora_r != 0: + rosa_cfg.schedule = "default" + elif rosa_cfg.lora_r != 0: + rosa_cfg.schedule = "lora_only" + rosa_cfg.mask_load_path = None + else: + rosa_cfg.schedule = "spa_only" + + rosa_cfg.mask_load_path = mask_path + rosa_cfg.mask_save_path = None + rosa_cfg.terminate_after_mask_generation = None + + # Re-enable struct mode to lock down the configuration + OmegaConf.set_struct(rosa_cfg, True) + + +def create_checkpoint_dirs(cfg: DictConfig) -> None: + # Create model directory + os.makedirs(os.path.join(cfg.checkpoint_dir, "models"), exist_ok=True) + + # Create mask directory + if hasattr(cfg.finetuning, "rosa"): + 
os.makedirs(os.path.join(cfg.checkpoint_dir, "masks"), exist_ok=True) + + +def get_hf_save_precision(cfg: DictConfig) -> str: + if cfg.model_precision == "bf16": + return "bfloat16" + elif cfg.model_precision == "fp32": + return "float32" + else: + raise ValueError(f"Unsupported model_precision: {cfg.model_precision}") + + +def get_rosa_dtype(cfg: DictConfig) -> str: + if cfg.model_precision == "bf16": + return "bf16" + elif cfg.model_precision == "fp32": + return "fp32" + elif cfg.model_precision == "4bit": + return "fp32" + else: + raise ValueError(f"Unsupported model_precision: {cfg.model_precision}") + + +def override_config(cfg: DictConfig) -> None: + # Disable struct mode to allow modifications + OmegaConf.set_struct(cfg, False) + + if not cfg.finetuning.run_name: + cfg.finetuning.run_name = create_run_name(cfg) + + if hasattr(cfg.finetuning, "rosa"): + cfg.finetuning.rosa.rosa_dtype = get_rosa_dtype(cfg) + if cfg.finetuning.rosa.spa_d != 0: + override_rosa_schedule(cfg, mask_generation=cfg.finetuning.rosa.masks_only) + else: + cfg.finetuning.callbacks.hf_checkpointer.precision = get_hf_save_precision(cfg) + + # Re-enable struct mode to lock down the configuration + OmegaConf.set_struct(cfg, True) + + +def save_config_to_yaml(cfg: DictConfig) -> str: + cfg = OmegaConf.to_container(cfg, resolve=True) + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".yaml") as temp_file: + OmegaConf.save(config=cfg, f=temp_file.name) + return temp_file.name + + def build_composer_peft_model( - model_config: str, rosa_config: Dict[str, Any], - tokenizer: PreTrainedTokenizerBase, is_fsdp: bool = False) -> ComposerHFCausalLM: + model_config: str, + rosa_config: Dict[str, Any], + tokenizer: PreTrainedTokenizerBase, + is_fsdp: bool = False, +) -> ComposerHFCausalLM: # 1) loads a hf model, 2) adds peft modules, 3) wraps it in a ComposerHFCausalLM. 
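For orientation, `save_config_to_yaml` above is one half of a handoff: `main` writes the resolved `preprocessing` sub-config to a temporary YAML, exports its path as `PANZA_PREPROCESSING_CONFIG`, and the updated `preprocessing.py` loads it at import time inside the dataloader workers. A condensed sketch of both sides follows; the config contents, including the `_target_` class path, are illustrative placeholders, not the actual Panza values.

```python
import os
import tempfile

from omegaconf import OmegaConf

# Producer side (train.py): dump the resolved preprocessing sub-config to a temp file.
preprocessing_cfg = OmegaConf.create(
    {
        "model": "mistralai/Mistral-7B-Instruct-v0.2",  # set from finetuning.model_name_or_path
        "prompting": {"_target_": "panza.prompting.EmailPromptBuilder"},  # hypothetical target
    }
)
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".yaml") as tmp:
    OmegaConf.save(config=preprocessing_cfg, f=tmp.name)
    os.environ["PANZA_PREPROCESSING_CONFIG"] = tmp.name

# Consumer side (finetuning/preprocessing.py, at import time in each worker):
loaded = OmegaConf.load(os.environ["PANZA_PREPROCESSING_CONFIG"])
print(loaded.model)
# The prompt builder is then constructed with hydra.utils.instantiate(loaded.prompting).
```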
- print('Building model from HuggingFace checkpoint...') + print("Building model from HuggingFace checkpoint...") - weight_bias_dtype = model_config.get('weight_bias_dtype', None) - if weight_bias_dtype == '4bit': + weight_bias_dtype = model_config.get("weight_bias_dtype", None) + if weight_bias_dtype == "4bit": compute_dtype = torch.bfloat16 quant_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type='nf4', + bnb_4bit_quant_type="nf4", ) - elif weight_bias_dtype == 'bf16': - assert weight_bias_dtype == 'bf16', 'Only bf16 is supported for now' - compute_dtype = torch.bfloat16 - quant_config = None + elif weight_bias_dtype == "bf16": + compute_dtype = torch.bfloat16 + quant_config = None else: - assert weight_bias_dtype == 'fp32' + assert weight_bias_dtype == "fp32" compute_dtype = torch.float32 quant_config = None with init_empty_weights(include_buffers=False): model = AutoModelForCausalLM.from_pretrained( model_config.pretrained_model_name_or_path, - device_map='cpu' if quant_config is None else 'auto', + device_map="cpu" if quant_config is None else "auto", torch_dtype=compute_dtype, # load_in_4bit=weight_bias_dtype == '4bit', quantization_config=quant_config, trust_remote_code=True, use_auth_token=True, use_cache=False, - attn_implementation='eager' + attn_implementation="eager", ) - print('Model built!') + print("Model built!") if rosa_config is not None: - print('Building RoSA config...') + print("Building RoSA config...") config = RosaConfig( - r=rosa_config['lora_r'], - d=rosa_config['spa_d'], - lora_alpha=rosa_config.get('lora_alpha', 16), - target_modules=rosa_config.get('target_modules', 'all-linear'), - lora_dropout=rosa_config.get('lora_dropout', 0.05), - impl=rosa_config.get('impl', 'auto'), - spa_store_transpose=rosa_config.get('spa_store_transpose', True), - rosa_dtype=rosa_config.get('rosa_dtype', True), - spa_num_grads=rosa_config.get('spa_num_grads', 1), - grad_acc_mode=rosa_config.get('grad_acc_mode', 'mean_squared'), - grad_4bit_accum=rosa_config.get('grad_4bit_accum', False), - mask_load_path=rosa_config.get('mask_load_path', None), - mask_save_path=rosa_config.get('mask_save_path', None), - terminate_after_mask_generation=rosa_config.get('terminate_after_mask_generation', False), - schedule=rosa_config.get('schedule', 'df'), + r=rosa_config["lora_r"], + d=rosa_config["spa_d"], + lora_alpha=rosa_config.get("lora_alpha", 16), + target_modules=rosa_config.get("target_modules", "all-linear"), + lora_dropout=rosa_config.get("lora_dropout", 0.05), + impl=rosa_config.get("impl", "auto"), + spa_store_transpose=rosa_config.get("spa_store_transpose", True), + rosa_dtype=rosa_config.get("rosa_dtype", True), + spa_num_grads=rosa_config.get("spa_num_grads", 1), + grad_acc_mode=rosa_config.get("grad_acc_mode", "mean_squared"), + grad_4bit_accum=rosa_config.get("grad_4bit_accum", False), + mask_load_path=rosa_config.get("mask_load_path", None), + mask_save_path=rosa_config.get("mask_save_path", None), + terminate_after_mask_generation=rosa_config.get( + "terminate_after_mask_generation", False + ), + schedule=rosa_config.get("schedule", "df"), bias="none", task_type="CAUSAL_LM", ) - print('Adding RoSA modules...') + # raise ValueError(config) + print("Adding RoSA modules...") model = get_peft_model(model, config) - print('RoSA modules added!') + print("RoSA modules added!") train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()] eval_metrics = [ @@ -211,7 +330,7 @@ def 
build_composer_peft_model( InContextLearningQAAccuracy(), InContextLearningCodeEvalAccuracy(), InContextLearningLMExpectedCalibrationError(), - InContextLearningMCExpectedCalibrationError() + InContextLearningMCExpectedCalibrationError(), ] model = HuggingFaceModelWithFSDP( @@ -220,31 +339,48 @@ def build_composer_peft_model( tokenizer=tokenizer, metrics=train_metrics, eval_metrics=eval_metrics, - init_device='cpu', - peft_config=None + init_device="cpu", + peft_config=None, ) - - # model = ComposerHFCausalLM(model, tokenizer) - # model = ModelComposerHFCausalLM(model, tokenizer) return model + +@hydra.main(version_base="1.1", config_path="../../../configs", config_name="panza_finetuning") def main(cfg: DictConfig) -> Trainer: + override_config(cfg) + + # Resolve all interpolation variables as early as possible + om.resolve(cfg) + + # The preprocessing config is saved to a temporary directory + # and accessed through an environment variable. Note that this + # happens separately for each process (however, a collision should) + # not be a problem, since the configs are the same. + OmegaConf.set_struct(cfg, False) + cfg.preprocessing.model = cfg.finetuning.model_name_or_path + preprocessing_yaml = save_config_to_yaml(cfg.preprocessing) + + environment = os.environ + environment["WANDB_PROJECT"] = f"panza-{cfg.user.username}" + environment["WANDB_DISABLED"] = str(int(cfg.finetuning.wandb_disabled)) + environment["PANZA_PREPROCESSING_CONFIG"] = preprocessing_yaml + + cfg = cfg.finetuning + + # Make the config editable for popping. + OmegaConf.set_struct(cfg, False) + # Run user provided code if specified - code_paths = pop_config(cfg, - 'code_paths', - must_exist=False, - default_value=[], - convert=True) + code_paths = pop_config(cfg, "code_paths", must_exist=False, default_value=[], convert=True) # Import any user provided code for code_path in code_paths: import_file(code_path) # Filter deprecation warning from torch internal usage warnings.filterwarnings( - action='ignore', + action="ignore", category=UserWarning, - message= - 'torch.distributed.*_base is a private function and will be deprecated.*' + message="torch.distributed.*_base is a private function and will be deprecated.*", ) # Check for incompatibilities between the model and data loaders @@ -258,32 +394,31 @@ def main(cfg: DictConfig) -> Trainer: cuda_alloc_conf = [] # Get max split size mb - max_split_size_mb: Optional[int] = cfg.pop('max_split_size_mb', None) + max_split_size_mb: Optional[int] = cfg.pop("max_split_size_mb", None) if max_split_size_mb is not None: - cuda_alloc_conf.append(f'max_split_size_mb:{max_split_size_mb}') + cuda_alloc_conf.append(f"max_split_size_mb:{max_split_size_mb}") # Expandable segments - if cfg.pop('expandable_segments', False): - cuda_alloc_conf.append('expandable_segments:True') + if cfg.pop("expandable_segments", False): + cuda_alloc_conf.append("expandable_segments:True") if len(cuda_alloc_conf) > 0: - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ','.join(cuda_alloc_conf) + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(cuda_alloc_conf) # Set CUDA lazy loading # This can save a bit of memory if not all modules are needed - cuda_load_lazy: bool = cfg.pop('cuda_load_lazy', False) + cuda_load_lazy: bool = cfg.pop("cuda_load_lazy", False) if cuda_load_lazy: - os.environ['CUDA_MODULE_LOADING'] = 'LAZY' + os.environ["CUDA_MODULE_LOADING"] = "LAZY" # Set seed first - seed: int = pop_config(cfg, 'seed', must_exist=True) + seed: int = pop_config(cfg, "seed", must_exist=True) 
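The preprocessing module that consumes this temporary config now builds training pairs from the tokenizer's chat template instead of hand-rolled prompt/response wrapper tokens: the gold email is appended as the assistant turn, the conversation is rendered untokenized, and the string is split at the first occurrence of the email. A minimal sketch of that split, assuming an illustrative instruct model whose template reproduces the email text verbatim:

```python
from transformers import AutoTokenizer

# Illustrative model id; in the diff the name comes from finetuning.model_name_or_path.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

prompt = "Instruction: Write a short email confirming the Friday meeting."
email = "Hi Anna,\n\nConfirming that we are on for Friday at 10am.\n\nBest,\nJen"

conversation = [
    {"role": "user", "content": prompt},
    {"role": "assistant", "content": email},
]
chat_prompt = tokenizer.apply_chat_template(conversation, tokenize=False)

# Split the rendered conversation at the start of the gold email, mirroring
# panza_preprocessing_function: everything before is the prompt, the rest
# (the email plus any closing special tokens) is the response to learn.
split = chat_prompt.index(email)
example = {"prompt": chat_prompt[:split], "response": chat_prompt[split:]}
print(example["prompt"])
print(example["response"])
```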
reproducibility.seed_all(seed) # Initialize pytorch distributed training process groups - dist_timeout: Union[int, float] = pop_config(cfg, - 'dist_timeout', - must_exist=False, - default_value=600.0) + dist_timeout: Union[int, float] = pop_config( + cfg, "dist_timeout", must_exist=False, default_value=600.0 + ) dist.initialize_dist(get_device(None), timeout=dist_timeout) # Get global and device batch size information from distributed/single node setting @@ -291,244 +426,175 @@ def main(cfg: DictConfig) -> Trainer: logged_cfg.update(cfg, merge=True) # Mandatory model training configs - model_config: DictConfig = pop_config(cfg, 'model', must_exist=True) - tokenizer_config: Dict[str, Any] = pop_config(cfg, - 'tokenizer', - must_exist=True, - convert=True) - optimizer_config: Dict[str, Any] = pop_config(cfg, - 'optimizer', - must_exist=True, - convert=True) - scheduler_config: Dict[str, Any] = pop_config(cfg, - 'scheduler', - must_exist=True, - convert=True) - train_loader_config: DictConfig = pop_config(cfg, - 'train_loader', - must_exist=True) + model_config: DictConfig = pop_config(cfg, "model", must_exist=True) + tokenizer_config: Dict[str, Any] = pop_config(cfg, "tokenizer", must_exist=True, convert=True) + optimizer_config: Dict[str, Any] = pop_config(cfg, "optimizer", must_exist=True, convert=True) + scheduler_config: Dict[str, Any] = pop_config(cfg, "scheduler", must_exist=True, convert=True) + train_loader_config: DictConfig = pop_config(cfg, "train_loader", must_exist=True) # Optional fsdp data, fine-tuning, and eval configs - fsdp_config: Optional[Dict[str, Any]] = pop_config(cfg, - 'fsdp_config', - must_exist=False, - default_value=None, - convert=True) - - ds_config: Optional[Dict[str, Any]] = pop_config(cfg, - 'ds_config', - must_exist=False, - default_value=None, - convert=True) - - rosa_config: Optional[Dict[str, Any]] = pop_config(cfg, - 'rosa', - must_exist=False, - default_value=None, - convert=True) - - hf_save_path: Union[int, str] = pop_config(cfg, - 'hf_save_path', - must_exist=True) - + fsdp_config: Optional[Dict[str, Any]] = pop_config( + cfg, "fsdp_config", must_exist=False, default_value=None, convert=True + ) + + ds_config: Optional[Dict[str, Any]] = pop_config( + cfg, "ds_config", must_exist=False, default_value=None, convert=True + ) + + rosa_config: Optional[Dict[str, Any]] = pop_config( + cfg, "rosa", must_exist=False, default_value=None, convert=True + ) + + hf_save_path: Union[int, str] = pop_config(cfg, "hf_save_path", must_exist=True) + eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config( - cfg, 'eval_loader', must_exist=False, default_value=None) - icl_tasks_config: Optional[Union[ListConfig, - str]] = pop_config(cfg, - 'icl_tasks', - must_exist=False, - default_value=None) - eval_gauntlet_config: Optional[Union[DictConfig, - str]] = pop_config(cfg, - 'eval_gauntlet', - must_exist=False, - default_value=None) - icl_subset_num_batches: Optional[int] = pop_config(cfg, - 'icl_subset_num_batches', - must_exist=False, - default_value=None) - icl_seq_len: Optional[int] = pop_config(cfg, - 'icl_seq_len', - must_exist=False, - default_value=None) + cfg, "eval_loader", must_exist=False, default_value=None + ) + icl_tasks_config: Optional[Union[ListConfig, str]] = pop_config( + cfg, "icl_tasks", must_exist=False, default_value=None + ) + eval_gauntlet_config: Optional[Union[DictConfig, str]] = pop_config( + cfg, "eval_gauntlet", must_exist=False, default_value=None + ) + icl_subset_num_batches: Optional[int] = pop_config( + cfg, 
"icl_subset_num_batches", must_exist=False, default_value=None + ) + icl_seq_len: Optional[int] = pop_config( + cfg, "icl_seq_len", must_exist=False, default_value=None + ) # Optional logging, evaluation and callback configs - logger_configs: Optional[DictConfig] = pop_config(cfg, - 'loggers', - must_exist=False, - default_value=None, - convert=True) - callback_configs: Optional[DictConfig] = pop_config(cfg, - 'callbacks', - must_exist=False, - default_value=None, - convert=True) - algorithm_configs: Optional[DictConfig] = pop_config(cfg, - 'algorithms', - must_exist=False, - default_value=None) + logger_configs: Optional[DictConfig] = pop_config( + cfg, "loggers", must_exist=False, default_value=None, convert=True + ) + callback_configs: Optional[DictConfig] = pop_config( + cfg, "callbacks", must_exist=False, default_value=None, convert=True + ) + algorithm_configs: Optional[DictConfig] = pop_config( + cfg, "algorithms", must_exist=False, default_value=None + ) # Mandatory hyperparameters for training - device_train_batch_size: int = pop_config(cfg, - 'device_train_batch_size', - must_exist=True) - device_eval_batch_size: int = pop_config(cfg, - 'device_eval_batch_size', - must_exist=True) - max_duration: Union[int, str] = pop_config(cfg, - 'max_duration', - must_exist=True) - eval_interval: Union[int, str] = pop_config(cfg, - 'eval_interval', - default_value=1, - must_exist=False) - precision: str = pop_config(cfg, 'precision', must_exist=True) - max_seq_len: int = pop_config(cfg, 'max_seq_len', must_exist=True) + device_train_batch_size: int = pop_config(cfg, "device_train_batch_size", must_exist=True) + device_eval_batch_size: int = pop_config(cfg, "device_eval_batch_size", must_exist=True) + max_duration: Union[int, str] = pop_config(cfg, "max_duration", must_exist=True) + eval_interval: Union[int, str] = pop_config( + cfg, "eval_interval", default_value=1, must_exist=False + ) + precision: str = pop_config(cfg, "precision", must_exist=True) + max_seq_len: int = pop_config(cfg, "max_seq_len", must_exist=True) # Optional parameters will be set to default values if not specified. 
- default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name: str = pop_config(cfg, - 'run_name', - must_exist=False, - default_value=default_run_name) - save_folder: Optional[str] = pop_config(cfg, - 'save_folder', - must_exist=False, - default_value=None) - is_state_dict_sharded: bool = (fsdp_config.get('state_dict_type', 'full') - == 'sharded') if fsdp_config else False + default_run_name: str = os.environ.get("RUN_NAME", "llm") + run_name: str = pop_config(cfg, "run_name", must_exist=False, default_value=default_run_name) + save_folder: Optional[str] = pop_config( + cfg, "save_folder", must_exist=False, default_value=None + ) + is_state_dict_sharded: bool = ( + (fsdp_config.get("state_dict_type", "full") == "sharded") if fsdp_config else False + ) save_latest_filename: str = pop_config( cfg, - 'save_latest_filename', + "save_latest_filename", must_exist=False, - default_value='latest-sharded-rank{rank}' - if is_state_dict_sharded else 'latest-rank{rank}.pt') - save_overwrite: bool = pop_config(cfg, - 'save_overwrite', - must_exist=False, - default_value=False) - save_weights_only: bool = pop_config(cfg, - 'save_weights_only', - must_exist=False, - default_value=False) + default_value=( + "latest-sharded-rank{rank}" if is_state_dict_sharded else "latest-rank{rank}.pt" + ), + ) + save_overwrite: bool = pop_config(cfg, "save_overwrite", must_exist=False, default_value=False) + save_weights_only: bool = pop_config( + cfg, "save_weights_only", must_exist=False, default_value=False + ) save_filename: str = pop_config( - cfg, - 'save_filename', - must_exist=False, - default_value='ep{epoch}-ba{batch}-rank{rank}.pt') - save_interval: Union[str, int] = pop_config(cfg, - 'save_interval', - must_exist=False, - default_value='1000ba') + cfg, "save_filename", must_exist=False, default_value="ep{epoch}-ba{batch}-rank{rank}.pt" + ) + save_interval: Union[str, int] = pop_config( + cfg, "save_interval", must_exist=False, default_value="1000ba" + ) save_num_checkpoints_to_keep: int = pop_config( - cfg, 'save_num_checkpoints_to_keep', must_exist=False, default_value=-1) - progress_bar = pop_config(cfg, - 'progress_bar', - must_exist=False, - default_value=False) - log_to_console: bool = pop_config(cfg, - 'log_to_console', - must_exist=False, - default_value=True) - python_log_level: Optional[str] = pop_config(cfg, - 'python_log_level', - must_exist=False, - default_value='debug') - console_log_interval: Union[int, str] = pop_config(cfg, - 'console_log_interval', - must_exist=False, - default_value='1ba') + cfg, "save_num_checkpoints_to_keep", must_exist=False, default_value=-1 + ) + progress_bar = pop_config(cfg, "progress_bar", must_exist=False, default_value=False) + log_to_console: bool = pop_config(cfg, "log_to_console", must_exist=False, default_value=True) + python_log_level: Optional[str] = pop_config( + cfg, "python_log_level", must_exist=False, default_value="debug" + ) + console_log_interval: Union[int, str] = pop_config( + cfg, "console_log_interval", must_exist=False, default_value="1ba" + ) device_train_microbatch_size: Union[str, int] = pop_config( - cfg, - 'device_train_microbatch_size', - must_exist=False, - default_value='auto') - eval_subset_num_batches: int = pop_config(cfg, - 'eval_subset_num_batches', - must_exist=False, - default_value=-1) - eval_first: bool = pop_config(cfg, - 'eval_first', - must_exist=False, - default_value=False) - load_path: str = pop_config(cfg, - 'load_path', - must_exist=False, - default_value=None) - load_weights_only: bool = pop_config(cfg, - 
'load_weights_only', - must_exist=False, - default_value=False) - load_strict_model_weights: bool = pop_config(cfg, - 'load_strict_model_weights', - must_exist=False, - default_value=True) - load_ignore_keys: Optional[List[str]] = pop_config(cfg, - 'load_ignore_keys', - must_exist=False, - default_value=None) - save_ignore_keys: Optional[List[str]] = pop_config(cfg, - 'save_ignore_keys', - must_exist=False, - default_value=None) - compile_config: Optional[Dict[str, Any]] = pop_config(cfg, - 'compile_config', - must_exist=False, - default_value=None) - metadata: Optional[Dict[str, str]] = pop_config(cfg, - 'metadata', - must_exist=False, - default_value=None, - convert=True) - should_log_config: bool = pop_config(cfg, - 'log_config', - must_exist=False, - default_value=True) - - num_cpu_threads: Optional[int] = cfg.pop('num_cpu_threads', 0) + cfg, "device_train_microbatch_size", must_exist=False, default_value="auto" + ) + eval_subset_num_batches: int = pop_config( + cfg, "eval_subset_num_batches", must_exist=False, default_value=-1 + ) + eval_first: bool = pop_config(cfg, "eval_first", must_exist=False, default_value=False) + load_path: str = pop_config(cfg, "load_path", must_exist=False, default_value=None) + load_weights_only: bool = pop_config( + cfg, "load_weights_only", must_exist=False, default_value=False + ) + load_strict_model_weights: bool = pop_config( + cfg, "load_strict_model_weights", must_exist=False, default_value=True + ) + load_ignore_keys: Optional[List[str]] = pop_config( + cfg, "load_ignore_keys", must_exist=False, default_value=None + ) + save_ignore_keys: Optional[List[str]] = pop_config( + cfg, "save_ignore_keys", must_exist=False, default_value=None + ) + compile_config: Optional[Dict[str, Any]] = pop_config( + cfg, "compile_config", must_exist=False, default_value=None + ) + metadata: Optional[Dict[str, str]] = pop_config( + cfg, "metadata", must_exist=False, default_value=None, convert=True + ) + should_log_config: bool = pop_config(cfg, "log_config", must_exist=False, default_value=True) + + num_cpu_threads: Optional[int] = cfg.pop("num_cpu_threads", 0) if num_cpu_threads > 0: - print(f'Setting number of CPU threads to {num_cpu_threads}') - import spops + print(f"Setting number of CPU threads to {num_cpu_threads}") torch.set_num_threads(num_cpu_threads) spops.set_num_threads(num_cpu_threads) # Enable autoresume from model checkpoints if possible autoresume_default: bool = False - if logged_cfg.get('run_name', None) is not None \ - and save_folder is not None \ - and not save_overwrite \ - and not save_weights_only: + if ( + logged_cfg.get("run_name", None) is not None + and save_folder is not None + and not save_overwrite + and not save_weights_only + ): autoresume_default = True - if cfg.get('autoresume') is None and autoresume_default: - log.info('As run_name, save_folder, and save_latest_filename are set, \ - changing autoresume default to True...') + if cfg.get("autoresume") is None and autoresume_default: + log.info( + "As run_name, save_folder, and save_latest_filename are set, \ + changing autoresume default to True..." + ) - autoresume: bool = pop_config(cfg, - 'autoresume', - must_exist=False, - default_value=autoresume_default) + autoresume: bool = pop_config( + cfg, "autoresume", must_exist=False, default_value=autoresume_default + ) # Pop known unused parameters that are used as interpolation variables or # created by update_batch_size_info. 
- pop_config(cfg, 'data_local', must_exist=False) - pop_config(cfg, 'data_remote', must_exist=False) - pop_config(cfg, 'global_seed', must_exist=False) - pop_config(cfg, 'global_train_batch_size', must_exist=False) - pop_config(cfg, 'n_gpus', must_exist=False) - pop_config(cfg, 'device_train_grad_accum', must_exist=False) + pop_config(cfg, "data_local", must_exist=False) + pop_config(cfg, "data_remote", must_exist=False) + pop_config(cfg, "global_seed", must_exist=False) + pop_config(cfg, "global_train_batch_size", must_exist=False) + pop_config(cfg, "n_gpus", must_exist=False) + pop_config(cfg, "device_train_grad_accum", must_exist=False) - assert fsdp_config is None or ds_config is None, 'fsdp and deepspeed are not supported together' + assert fsdp_config is None or ds_config is None, "fsdp and deepspeed are not supported together" # Warn users for unused parameters for key in cfg: warnings.warn( - f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary.' + f"Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary." ) # Warn if fsdp is enabled but user only has 1 GPU if dist.get_world_size() == 1 and fsdp_config is not None: - warnings.warn( - 'FSDP is not applicable for single-GPU training. Reverting to DDP.') + warnings.warn("FSDP is not applicable for single-GPU training. Reverting to DDP.") fsdp_config = None # set logging level @@ -536,33 +602,32 @@ def main(cfg: DictConfig) -> Trainer: logging.basicConfig( # Example of format string # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here - format= - f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' + format=f"%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s" ) - logging.getLogger('llmfoundry').setLevel( - python_log_level.upper()) # Foundry module - logging.getLogger(__name__).setLevel( - python_log_level.upper()) # Train script + logging.getLogger("llmfoundry").setLevel(python_log_level.upper()) # Foundry module + logging.getLogger(__name__).setLevel(python_log_level.upper()) # Train script # Initialize context init_context = process_init_device(model_config, fsdp_config) - logged_cfg.update({'fsdp_config': fsdp_config}, merge=True) + logged_cfg.update({"fsdp_config": fsdp_config}, merge=True) # Build tokenizer - log.info('Building tokenizer...') - tokenizer_name = tokenizer_config['name'] - tokenizer_kwargs = tokenizer_config.get('kwargs', {}) + log.info("Building tokenizer...") + tokenizer_name = tokenizer_config["name"] + tokenizer_kwargs = tokenizer_config.get("kwargs", {}) + tokenizer_kwargs["num_proc"] = 1 tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) # Scheduler - scheduler_name: str = scheduler_config.pop('name') + scheduler_name: str = scheduler_config.pop("name") scheduler = build_scheduler(scheduler_name, scheduler_config) # Loggers - loggers = [ - build_logger(str(name), logger_cfg) - for name, logger_cfg in logger_configs.items() - ] if logger_configs else [] + loggers = ( + [build_logger(str(name), logger_cfg) for name, logger_cfg in logger_configs.items()] + if logger_configs + else [] + ) mosaicml_logger = find_mosaicml_logger(loggers) if mosaicml_logger is None: @@ -573,7 +638,7 @@ def main(cfg: DictConfig) -> Trainer: if metadata is not None: # Flatten the metadata for logging - logged_cfg.pop('metadata', None) + logged_cfg.pop("metadata", None) logged_cfg.update(metadata, 
merge=True) if mosaicml_logger is not None: mosaicml_logger.log_metrics(metadata) @@ -581,84 +646,90 @@ def main(cfg: DictConfig) -> Trainer: # Profiling profiler: Optional[Profiler] = None - profiler_cfg: Optional[DictConfig] = pop_config(cfg, - 'profiler', - must_exist=False, - convert=False, - default_value=None) + profiler_cfg: Optional[DictConfig] = pop_config( + cfg, "profiler", must_exist=False, convert=False, default_value=None + ) if profiler_cfg: - profiler_schedule_cfg: Dict = pop_config(profiler_cfg, - 'schedule', - must_exist=True, - convert=True) + profiler_schedule_cfg: Dict = pop_config( + profiler_cfg, "schedule", must_exist=True, convert=True + ) profiler_schedule = cyclic_schedule(**profiler_schedule_cfg) # Only support json trace handler profiler_trace_handlers: List[TraceHandler] = [] - profiler_trace_cfg: Optional[Dict] = pop_config(profiler_cfg, - 'json_trace_handler', - must_exist=False, - default_value=None, - convert=True) + profiler_trace_cfg: Optional[Dict] = pop_config( + profiler_cfg, "json_trace_handler", must_exist=False, default_value=None, convert=True + ) if profiler_trace_cfg: - profiler_trace_handlers.append( - JSONTraceHandler(**profiler_trace_cfg)) - profiler = Profiler(**profiler_cfg, - trace_handlers=profiler_trace_handlers, - schedule=profiler_schedule) + profiler_trace_handlers.append(JSONTraceHandler(**profiler_trace_cfg)) + profiler = Profiler( + **profiler_cfg, trace_handlers=profiler_trace_handlers, schedule=profiler_schedule + ) # Callbacks - callbacks: List[Callback] = [ - build_callback(str(name), callback_cfg, om.to_container(logged_cfg)) - for name, callback_cfg in callback_configs.items() - ] if callback_configs else [] + callbacks: List[Callback] = ( + [ + build_callback(str(name), callback_cfg, om.to_container(logged_cfg)) + for name, callback_cfg in callback_configs.items() + ] + if callback_configs + else [] + ) use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks) - print('ROSA CONFIG', rosa_config) + print("ROSA CONFIG", rosa_config) # Build Model - print('Initializing model...') + print("Initializing model...") with init_context: - assert fsdp_config is None or rosa_config is None, 'fsdp is cuurently not supported with RoSA' - model = build_composer_peft_model(model_config, rosa_config, tokenizer, is_fsdp=fsdp_config is not None) + assert ( + fsdp_config is None or rosa_config is None + ), "fsdp is currently not supported with RoSA" + model = build_composer_peft_model( + model_config, rosa_config, tokenizer, is_fsdp=fsdp_config is not None + ) if rosa_config is not None: assert isinstance(model.model.base_model, RosaModel) - + # Algorithms - algorithms = [ - build_algorithm(str(name), algorithm_cfg) - for name, algorithm_cfg in algorithm_configs.items() - ] if algorithm_configs else [] + algorithms = ( + [ + build_algorithm(str(name), algorithm_cfg) + for name, algorithm_cfg in algorithm_configs.items() + ] + if algorithm_configs + else [] + ) if rosa_config is not None: algorithms.append(RosaScheduler(model.model.base_model)) - + # Dataloaders - log.info('Building train loader...') + log.info("Building train loader...") try: + disable_caching() train_loader = build_dataloader( train_loader_config, tokenizer, device_train_batch_size, ) except Exception as e: + # Surface dataloader construction failures to the logger before re-raising. if mosaicml_logger is not None: mosaicml_logger.log_exception(e) raise e if mosaicml_logger is not None: - mosaicml_logger.log_metrics({'data_validated': time.time()}) + mosaicml_logger.log_metrics({"data_validated": time.time()}) ##
Evaluation if use_async_eval: evaluators = [] if eval_first: - warnings.warn( - 'AsyncEval callback does not support eval_first=True. Ignoring.' - ) + warnings.warn("AsyncEval callback does not support eval_first=True. Ignoring.") eval_first = False else: - log.info('Building eval loader...') + log.info("Building eval loader...") eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len evaluators, _, eval_gauntlet_callback = build_evaluators( eval_loader_config, @@ -673,79 +744,76 @@ def main(cfg: DictConfig) -> Trainer: callbacks.append(eval_gauntlet_callback) if mosaicml_logger is not None: - log_train_analytics(mosaicml_logger, model_config, train_loader_config, - eval_loader_config, callback_configs, - tokenizer_name, load_path, icl_tasks_config, - eval_gauntlet_config) - # # Build Model - # log.info('Initializing model...') - # model = build_composer_model( - # name=model_config.name, - # cfg=model_config, - # tokenizer=tokenizer, - # init_context=init_context, - # master_weights_dtype=model_config.get('master_weights_dtype', None), - # ) - + log_train_analytics( + mosaicml_logger, + model_config, + train_loader_config, + eval_loader_config, + callback_configs, + tokenizer_name, + load_path, + icl_tasks_config, + eval_gauntlet_config, + ) # Log number of parameters - if hasattr(model, 'n_total_params'): + if hasattr(model, "n_total_params"): n_params = model.n_total_params - n_trainable_params = n_params # TODO: we currently assume all parameters are trainable. + n_trainable_params = n_params # We currently assume all parameters are trainable. else: n_params = sum(p.numel() for p in model.parameters()) - n_trainable_params = sum( - p.numel() for p in model.parameters() if p.requires_grad) - if hasattr(model, 'n_active_params'): + n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + if hasattr(model, "n_active_params"): n_active_params = model.n_active_params else: n_active_params = n_params - logged_cfg.update({ - 'n_params': n_params, - 'n_active_params': n_active_params, - 'n_trainable_params': n_trainable_params, - }) + logged_cfg.update( + { + "n_params": n_params, + "n_active_params": n_active_params, + "n_trainable_params": n_trainable_params, + } + ) # Optimizer - optimizer_name: str = optimizer_config.pop('name') - if rosa_config is None or 'lora_lr' not in rosa_config: + optimizer_name: str = optimizer_config.pop("name") + if rosa_config is None or "lora_lr" not in rosa_config: optimizer = build_optimizer(model, optimizer_name, optimizer_config) else: print(f'Using a different learning rate for lora params {rosa_config["lora_lr"]}') - assert optimizer_name == 'decoupled_adamw' + assert optimizer_name == "decoupled_adamw" lora_params = [] other_params = [] for name, param in model.named_parameters(): - if any([k in name for k in ['rosa_A', 'rosa_B', 'rosa_embedding_A', 'rosa_embedding_B']]): + if any( + [k in name for k in ["rosa_A", "rosa_B", "rosa_embedding_A", "rosa_embedding_B"]] + ): lora_params.append(param) else: other_params.append(param) - print(f'Found {len(lora_params)} lora params and {len(other_params)} other params') - params = [ - {'params': other_params}, - {'params': lora_params, 'lr': rosa_config['lora_lr']} - ] + print(f"Found {len(lora_params)} lora params and {len(other_params)} other params") + params = [{"params": other_params}, {"params": lora_params, "lr": rosa_config["lora_lr"]}] optimizer = DecoupledAdamW(params, **optimizer_config) - - # Now add the eval metrics try: if eval_loader_config is not None and 
not use_async_eval: eval_metrics = model.get_metrics(is_train=False) non_icl_metrics = [ - metric_name for metric_name, metric in eval_metrics.items() + metric_name + for metric_name, metric in eval_metrics.items() if not isinstance(metric, InContextLearningMetric) ] - evaluators = add_metrics_to_eval_loaders(evaluators, - non_icl_metrics) + evaluators = add_metrics_to_eval_loaders(evaluators, non_icl_metrics) except Exception as e: if mosaicml_logger is not None: mosaicml_logger.log_exception(e) raise e # Build the Trainer - log.info('Building trainer...') + log.info("Building trainer...") + dtypes = {x.dtype for x in model.parameters()} + log.info(f"Model parameter dtypes: {dtypes}") trainer = Trainer( run_name=run_name, seed=seed, @@ -787,7 +855,7 @@ def main(cfg: DictConfig) -> Trainer: ) if should_log_config: - log.info('Logging config') + log.info("Logging config") log_config(logged_cfg) torch.cuda.empty_cache() gc.collect() @@ -796,55 +864,35 @@ def main(cfg: DictConfig) -> Trainer: if eval_first and trainer.state.timestamp.batch.value == 0: trainer.eval() - log.info('Starting training...') + log.info("Starting training...") trainer.fit() - # if rosa is enabled, save the model manually, since + # Hacky solution for moving the model checkpoint from the + # subdirectory that the HF writer wrote it into, and into + # our desired and expected location. Only needed for full + # (not low-rank) finetuning. + if rosa_config is None and torch.distributed.get_rank() == 0: + path_to_save = os.path.join(hf_save_path, run_name) + hf_output_path = os.path.join(path_to_save, "huggingface") + for filename in glob.glob(os.path.join(hf_output_path, "*", "*")): + shutil.copy(filename, path_to_save) + shutil.rmtree(os.path.join(hf_output_path)) + + # if rosa is enabled, save the model manually, since # llm-foundry's checkpointing doesn't work properly with RoSA if rosa_config is not None: - assert fsdp_config is None, 'fsdp is cuurently not supported with RoSA' + assert fsdp_config is None, "fsdp is currently not supported with RoSA" path_to_save = os.path.join(hf_save_path, run_name) - print(f'saving the model to {path_to_save}') + print(f"saving the model to {path_to_save}") if torch.distributed.get_rank() == 0: - model.model.save_pretrained(path_to_save, is_main_process=True, state_dict=model.model.state_dict()) + model.model.save_pretrained( + path_to_save, is_main_process=True, state_dict=model.model.state_dict() + ) tokenizer.save_pretrained(path_to_save) - # print('Saving directly into HF-friendly format') - - # path_to_save = os.path.join(hf_save_path, run_name) - # print('saving the model.') - # if fsdp_config is None: - # model.model.save_pretrained(path_to_save, is_main_process=torch.distributed.get_rank() == 0, state_dict=model.model.state_dict()) - # else: - # with FSDP.summon_full_params(model.model, writeback=False, rank0_only=True, offload_to_cpu=True): - # model_to_save = model.model - # model_to_save.save_pretrained(path_to_save, state_dict=model_to_save.state_dict()) - - # if torch.distributed.get_rank() == 0: - # tokenizer.save_pretrained(path_to_save) - - # # NOTE: for some reason the saving code above would create empty pytorch_model.bin file, so we delete it manually - # # TODO: figure out why this happens - # if torch.distributed.get_rank() == 0 and os.path.exists(os.path.join(path_to_save, "pytorch_model.bin")): - # tmp = torch.load(os.path.join(path_to_save, "pytorch_model.bin")) - # if not tmp: # empty dict, remove it - # os.remove(os.path.join(path_to_save, "pytorch_model.bin")) - - log.info('Done.') + 
log.info("Done.") return trainer -if __name__ == '__main__': - yaml_path, args_list = sys.argv[1], sys.argv[2:] - - # Disable resolving environment variables through omegaconf. - om.clear_resolver('oc.env') - - # Load yaml and cli arguments. - with open(yaml_path) as f: - yaml_cfg = om.load(f) - cli_cfg = om.from_cli(args_list) - cfg = om.merge(yaml_cfg, cli_cfg) - om.resolve(cfg) - assert isinstance(cfg, DictConfig) - main(cfg) +if __name__ == "__main__": + main() diff --git a/src/panza/interface/__init__.py b/src/panza/interface/__init__.py new file mode 100644 index 0000000..3af9505 --- /dev/null +++ b/src/panza/interface/__init__.py @@ -0,0 +1,5 @@ +from .cli import PanzaCLI +from .gui import PanzaGUI +from .json import PanzaJSON + +__all__ = ["PanzaCLI", "PanzaGUI", "PanzaJSON"] diff --git a/src/panza/interface/cli.py b/src/panza/interface/cli.py new file mode 100644 index 0000000..43e5656 --- /dev/null +++ b/src/panza/interface/cli.py @@ -0,0 +1,15 @@ +from panza.entities.instruction import EmailInstruction, Instruction +from panza.writer import PanzaWriter + + +class PanzaCLI: + def __init__(self, writer: PanzaWriter, **kwargs): + self.writer = writer + while True: + user_input = input("Enter a command: ") + if user_input == "exit": + break + else: + instruction: Instruction = EmailInstruction(user_input) + stream = self.writer.run(instruction, stream=True) + stream.end() diff --git a/src/panza/interface/gui.py b/src/panza/interface/gui.py new file mode 100644 index 0000000..185c278 --- /dev/null +++ b/src/panza/interface/gui.py @@ -0,0 +1,30 @@ +from panza.entities.instruction import EmailInstruction, Instruction +from panza.writer import PanzaWriter +import gradio as gr + + +class PanzaGUI: + def __init__(self, writer: PanzaWriter, **kwargs): + self.writer = writer + with gr.Blocks() as panza: + gr.Markdown("# Panza\n") + inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER") + outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model") + inputbox.submit( + self.get_execute(), + [inputbox], + [outputbox], + ) + + panza.queue().launch(server_name="localhost", server_port=5002, share=True) + + def get_execute(self): + def execute(input): + instruction: Instruction = EmailInstruction(input) + stream = self.writer.run(instruction, stream=True) + output = "" + for chunk in stream: + output += chunk + yield output + + return execute diff --git a/src/panza/interface/gui_b.py b/src/panza/interface/gui_b.py new file mode 100644 index 0000000..62ab1db --- /dev/null +++ b/src/panza/interface/gui_b.py @@ -0,0 +1,31 @@ +from panza.entities.instruction import EmailInstruction, Instruction +from panza.writer import PanzaWriter +import gradio as gr + + +class PanzaGUI: + def __init__(self, writer: PanzaWriter, **kwargs): + self.writer = writer + with gr.Blocks() as panza: + gr.Markdown("# Panza\n") + inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER") + outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model") + inputbox.submit( + self.get_execute(), + [inputbox], + [outputbox], + ) + + panza.queue().launch(server_name="localhost", server_port=5003, share=True) + + def get_execute(self): + def execute(input): + instruction: Instruction = EmailInstruction(input) + stream = self.writer.run(instruction, stream=False) + # output = "" + # for chunk in stream: + # output += chunk + # yield stream.end() + yield stream + + return execute diff --git a/src/panza/interface/json.py 
b/src/panza/interface/json.py new file mode 100644 index 0000000..4169c12 --- /dev/null +++ b/src/panza/interface/json.py @@ -0,0 +1,203 @@ +from panza.entities.instruction import EmailInstruction +from panza.writer import PanzaWriter + +import json +import numpy as np +import os +import re +from tqdm import tqdm + +from evaluate import load +from torchmetrics.text.bleu import BLEUScore +from torchmetrics.text.rouge import ROUGEScore +import string +import nltk + +# Ensure that the nltk tokenizer data has been downloaded so that the script does not fail. +try: + nltk.data.find("tokenizers/punkt_tab") +except LookupError: + print("punkt_tab was not downloaded. Downloading.") + nltk.download("punkt_tab") + +punc_table = str.maketrans({key: None for key in string.punctuation}) +rouge = ROUGEScore() +bleu1 = BLEUScore(n_gram=1) +bleu2 = BLEUScore(n_gram=2) +bleu3 = BLEUScore(n_gram=3) +bleu4 = BLEUScore(n_gram=4) +mauve = load("mauve") + + +def compute_rouge_scores(predictions, goldens): + goldens = [" ".join(x.translate(punc_table).lower().split()) for x in goldens] + candidates = [ + " ".join(prediction.translate(punc_table).lower().split()) for prediction in predictions + ] + scores = [ + {k: v.item() for k, v in rouge(candidate, goldens).items()} for candidate in candidates + ] + return scores + + +def compute_bleu_scores(predictions, goldens): + goldens = [" ".join(x.translate(punc_table).lower().split()) for x in goldens] + candidates = [ + " ".join(prediction.translate(punc_table).lower().split()) for prediction in predictions + ] + bleu_scores = [ + np.mean([bleu([candidate], [goldens]) for bleu in [bleu1, bleu2, bleu3, bleu4]]) + for candidate in candidates + ] + return [s.item() for s in bleu_scores] + + +def compute_mauve_score(predictions, goldens): + predictions = [ + prediction for nested_prediction in predictions for prediction in nested_prediction + ] + goldens = [golden for nested_golden in goldens for golden in nested_golden] + mauve_score = mauve.compute(predictions=predictions, references=goldens) + return mauve_score + + +class PanzaJSON: + def compose_output_folder(self, json_path, checkpoint, panza_workspace, username): + if os.path.isdir(checkpoint): + # Presumably this is a Panza-trained model; go ahead + # and put the json output into the same folder. + output_dir = checkpoint + else: + # Assume that this is a huggingface model identified by its hf handle. + # We don't want to populate the cached model folder, so instead + # we create a folder in the Panza workspace to put the output. + output_dir = os.path.join( + panza_workspace, "checkpoints", "models", checkpoint, username + ) + os.makedirs(output_dir, exist_ok=True) + filename_no_ext = os.path.splitext(os.path.basename(json_path))[0] + return os.path.join(output_dir, f"{filename_no_ext}_outputs.json") + + def assemble_responses(self, prompts_json, batch_size, use_thread, responses_per_prompt): + + with open(prompts_json, "r") as f: + golden_lines = [json.loads(l) for l in f.readlines()] + + # Group json lines together by prompt to avoid weirdness in + # eval metric computation. In case golden responses are provided, + # all goldens are used as alternatives for BLEU and ROUGE scores; + # the first one provided is used for MAUVE. + grouped_golden = {} + has_goldens = False + for entry in golden_lines: + # 'summary' is the name of the 'prompt' field, i.e., the one to group on. 
+ if entry["summary"] in grouped_golden: + if "email" in entry: + has_goldens = True + grouped_golden[entry["summary"]]["goldens"].append(entry["email"]) + else: + grouped_golden[entry["summary"]] = {} + if "email" in entry: + has_goldens = True + grouped_golden[entry["summary"]]["goldens"] = [(entry["email"])] + grouped_golden[entry["summary"]]["thread"] = entry["thread"] + # Convert dict to list of (k, v) pairs to batch through it. + grouped_golden = list(grouped_golden.items()) + + all_responses = [] + for i in tqdm(range(0, len(grouped_golden), batch_size)): + batch = grouped_golden[i : i + batch_size] + prompts = [item[0] for item in batch] + if use_thread: + threads = [item[1]["thread"] for item in batch] + golden_responses = [item[1]["goldens"] for item in batch] + + responses = [ + { + "prompt": p, + "full_prompt": None, + "thread": None if not use_thread else threads[i], + "golden_responses": golden_responses[i], + "panza_responses": [], + } + for i, p in enumerate(prompts) + ] + for _ in range(responses_per_prompt): + if use_thread: + instructions = list(zip(prompts, threads)) + else: + instructions = list(zip(prompts, [[]] * len(prompts))) + + outputs, full_prompts = self.writer.run_batch( + [ + EmailInstruction(user_input[0], thread=user_input[1]) + for user_input in instructions + ], + return_prompt=True, + ) + + # Remove some boilerplate added by instruction-tuned models w/out finetuning. + outputs = [o.replace("Here is the email:\n", "") for o in outputs] + outputs = [re.sub(r"SUBJECT:.*\n", "", o) for o in outputs] + outputs = [re.sub(r"Subject:.*\n", "", o) for o in outputs] + outputs = [re.sub(r"E-MAIL CONTENT:.*\n", "", o) for o in outputs] + + for i, r in enumerate(responses): + r["full_prompt"] = full_prompts[i] + r["panza_responses"].append(outputs[i]) + all_responses += responses + return all_responses, has_goldens + + def do_compute_metrics(self, all_responses): + for response in all_responses: + response["scores"] = {} + response["scores"]["BLEU"] = compute_bleu_scores( + response["panza_responses"], response["golden_responses"] + ) + response["scores"]["ROUGE"] = compute_rouge_scores( + response["panza_responses"], response["golden_responses"] + ) + rouge_categories = all_responses[0]["scores"]["ROUGE"][0].keys() + aggregate_metrics = { + "BLEU": np.mean([s for r in all_responses for s in r["scores"]["BLEU"]]), + "ROUGE": { + cat: np.mean([s[cat] for r in all_responses for s in r["scores"]["ROUGE"]]) + for cat in rouge_categories + }, + "MAUVE": compute_mauve_score( + [r["panza_responses"] for r in all_responses], + [r["golden_responses"] for r in all_responses], + ).mauve, + } + print("########## Aggregated quality metrics ##########\n") + print(json.dumps(aggregate_metrics, indent=2)) + return {"responses": all_responses, "aggregate_metrics": aggregate_metrics} + + def __init__( + self, + writer: PanzaWriter, + checkpoint: str, + panza_workspace: str, + input_file: str, + batch_size: int, + use_thread: bool, + responses_per_prompt: int, + compute_metrics: bool, + username: str, + ): + self.writer = writer + responses, has_goldens = self.assemble_responses( + input_file, batch_size, use_thread, responses_per_prompt + ) + if compute_metrics: + if has_goldens: + responses = self.do_compute_metrics(responses) + else: + print( + "Warning: metrics requested but no golden labels given!", + "\nDumping responses without computing metrics.", + ) + + output_path = self.compose_output_folder(input_file, checkpoint, panza_workspace, username) + with open(output_path, "w") as 
f: + json.dump(responses, f, indent=4, sort_keys=True) diff --git a/src/panza/llm/__init__.py b/src/panza/llm/__init__.py new file mode 100644 index 0000000..477919b --- /dev/null +++ b/src/panza/llm/__init__.py @@ -0,0 +1,4 @@ +from .base import LLM, ChatHistoryType, MessageType +from .local import PeftLLM, TransformersLLM + +__all__ = ["LLM", "ChatHistoryType", "MessageType", "TransformersLLM", "PeftLLM"] diff --git a/src/panza/llm/base.py b/src/panza/llm/base.py new file mode 100644 index 0000000..6293881 --- /dev/null +++ b/src/panza/llm/base.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod +from typing import Dict, Iterator, List, Literal + +MessageType = Dict[Literal["role", "content"], str] +ChatHistoryType = List[MessageType] + + +class LLM(ABC): + def __init__(self, name: str, sampling_parameters: Dict): + self.name = name + self.sampling_parameters = sampling_parameters + + @abstractmethod + def chat(self, messages: ChatHistoryType | List[ChatHistoryType]) -> List[str]: + pass + + @abstractmethod + def chat_stream(self, messages: ChatHistoryType) -> Iterator[str]: + pass diff --git a/src/panza/llm/local.py b/src/panza/llm/local.py new file mode 100644 index 0000000..f887e39 --- /dev/null +++ b/src/panza/llm/local.py @@ -0,0 +1,168 @@ +from abc import abstractmethod +from typing import Any, Dict, Iterator, List, Type + +import torch + +_MISSING_LIBRARIES = [] + +try: + from peft import AutoPeftModelForCausalLM +except ImportError: + AutoPeftModelForCausalLM = None + _MISSING_LIBRARIES.append("peft") + +try: + from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer +except ImportError: + AutoModelForCausalLM = None + AutoTokenizer = None + _MISSING_LIBRARIES.append("transformers") + +try: + from transformers import BitsAndBytesConfig +except ImportError: + BitsAndBytesConfig = None + _MISSING_LIBRARIES.append("bitsandbytes") + + +from .base import LLM, ChatHistoryType + + +class LocalLLM(LLM): + def __init__( + self, + name: str, + checkpoint: str, + device: str, + sampling_parameters: Dict, + dtype: str, + load_in_4bit: bool, + ): + self._check_installation() + + super().__init__(name, sampling_parameters) + self.checkpoint = checkpoint + self.device = device + + assert dtype in [None, "fp32", "bf16"] + if device == "cpu": + assert dtype == "fp32", "CPU only supports fp32, please specify --dtype fp32" + dtype = None if dtype is None else (torch.float32 if dtype == "fp32" else torch.bfloat16) + self.dtype = dtype + + self.load_in_4bit = load_in_4bit + self.quantization_config = ( + BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=dtype, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + if load_in_4bit + else None + ) + + self._load_model_and_tokenizer() + + def chat(self, messages: ChatHistoryType | List[ChatHistoryType]) -> List[str]: + encodeds = self.tokenizer.apply_chat_template( + messages, + return_tensors="pt", + add_generation_prompt=True, + padding=True, + truncation=True, + return_dict=True, + ) + model_inputs = encodeds.to(self.device) + + generated_ids = self.model.generate( + **model_inputs, + **self.sampling_parameters, + pad_token_id=self.tokenizer.pad_token_id, + ) + + prompt_length = encodeds["input_ids"].shape[1] + outputs = self.tokenizer.batch_decode( + generated_ids[:, prompt_length:], skip_special_tokens=True + ) + + return outputs + + def chat_stream(self, messages: ChatHistoryType) -> Iterator[str]: + if isinstance(messages[0], (list, tuple)) or hasattr(messages[0], "messages"): + raise 
TypeError("chat_stream does not support batched messages.") + + streamer = TextStreamer(self.tokenizer) + encodeds = self.tokenizer.apply_chat_template( + messages, + return_tensors="pt", + add_generation_prompt=True, + padding=True, + truncation=True, + return_dict=True, + ) + model_inputs = encodeds.to(self.device) + + self.model.generate( + **model_inputs, + streamer=streamer, + **self.sampling_parameters, + pad_token_id=self.tokenizer.pad_token_id, + ) + + return streamer + + def _check_installation(self) -> None: + if AutoModelForCausalLM is None or AutoTokenizer is None: + raise ImportError( + "transformers is not installed. Please install it with `pip install transformers`." + ) + + if BitsAndBytesConfig is None: + from transformers import __version__ as version + + raise ImportError( + f"transformers {version} does not support 4-bit quantization. Please upgrade to a newer version." + ) + + def _load_model_and_tokenizer_with_constructor(self, model_class: Type[Any]) -> None: + if self.load_in_4bit: + self.model = model_class.from_pretrained( + self.checkpoint, + device_map=self.device, + quantization_config=self.quantization_config, + trust_remote_code=True, + ) + else: + self.model = model_class.from_pretrained( + self.checkpoint, + torch_dtype=self.dtype, + device_map=self.device, + trust_remote_code=True, + ) + + self.tokenizer = AutoTokenizer.from_pretrained( + self.checkpoint, model_max_length=self.model.config.max_position_embeddings + ) + self.tokenizer.padding_side = "left" + self.tokenizer.pad_token = self.tokenizer.eos_token + + @abstractmethod + def _load_model_and_tokenizer(self) -> None: + pass + + +class TransformersLLM(LocalLLM): + def _load_model_and_tokenizer(self): + self._load_model_and_tokenizer_with_constructor(AutoModelForCausalLM) + + +class PeftLLM(LocalLLM): + def _check_installation(self) -> None: + super()._check_installation() + if AutoPeftModelForCausalLM is None: + raise ImportError("peft is not installed.") + + def _load_model_and_tokenizer(self) -> None: + self._load_model_and_tokenizer_with_constructor(AutoPeftModelForCausalLM) + self.model = self.model.merge_and_unload() diff --git a/src/panza/prompting/__init__.py b/src/panza/prompting/__init__.py new file mode 100644 index 0000000..ab79ffd --- /dev/null +++ b/src/panza/prompting/__init__.py @@ -0,0 +1,5 @@ +from .base import PromptBuilder +from .email_prompting import EmailPromptBuilder +from .summarization_prompting import SummarizationPromptBuilder + +__all__ = ["PromptBuilder", "EmailPromptBuilder", "SummarizationPromptBuilder"] diff --git a/src/panza/prompting/base.py b/src/panza/prompting/base.py new file mode 100644 index 0000000..713b5f1 --- /dev/null +++ b/src/panza/prompting/base.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod + +from ..entities import Instruction +from ..retriever import DocumentRetriever + + +class PromptBuilder(ABC): + def __init__(self, retriever: DocumentRetriever): + self.retriever = retriever + + @abstractmethod + def build_prompt(self, instruction: Instruction) -> str: + pass diff --git a/src/panza/prompting/email_prompting.py b/src/panza/prompting/email_prompting.py new file mode 100644 index 0000000..3a0e227 --- /dev/null +++ b/src/panza/prompting/email_prompting.py @@ -0,0 +1,138 @@ +from typing import List, Tuple + +from ..entities import Email, EmailInstruction +from ..retriever import DocumentRetriever +from .base import PromptBuilder +from .utils import load_preamble, load_user_preamble + + +class EmailPromptBuilder(PromptBuilder): + def __init__( 
+ self, + retriever: DocumentRetriever, + system_preamble: str, + user_preamble: str, + rag_preamble: str, + thread_preamble: str, + number_rag_emails: int, + rag_relevance_threshold: float, + number_thread_emails: int, + ): + self.retriever = retriever + self.system_preamble = system_preamble + self.user_preamble = user_preamble + self.rag_preamble = rag_preamble + self.thread_preamble = thread_preamble + self.number_rag_emails = number_rag_emails + self.rag_relevance_threshold = rag_relevance_threshold + self.number_thread_emails = number_thread_emails + + self.retriever.set_document_class(Email) + + def _create_rag_preamble_from_emails(self, emails: List[Email]) -> str: + rag_context = self._create_rag_context_from_emails(emails) + return self.rag_preamble.format(rag_context=rag_context) + + def _create_rag_context_from_emails(self, emails: List[Email]) -> str: + """Creates a RAG context from a list of relevant e-mails. + + The e-mails are formatted as follows: + + E-MAIL CONTENT: + <e-mail1 content> + + --- + + E-MAIL CONTENT: + <e-mail2 content> + + --- + ... + """ + + rag_context = "" + for email in emails: + rag_context += f"E-MAIL CONTENT:\n{email.email}\n\n---\n\n" + + return rag_context + + def _create_threading_preamble(self, thread: List[str]) -> str: + threading_context = self._create_threading_context(thread) + return self.thread_preamble.format(threading_context=threading_context) + + def _create_threading_context(self, thread: List[str]) -> str: + """Creates a threading context from a list of relevant e-mails. + + The e-mails are formatted as follows: + + <e-mail1> + + --- + + <e-mail2> + + --- + ... + """ + + threading_context = "" + for email in thread: + threading_context += f"{email}\n\n---\n\n" + + return threading_context + + @staticmethod + def load_all_preambles( + system_preamble_path: str, + user_preamble_path: str, + rag_preamble_path: str, + thread_preamble_path: str, + ) -> Tuple[str, str, str, str]: + """Load all preambles from file.""" + system_preamble = load_preamble(system_preamble_path) if system_preamble_path else "" + user_preamble = load_user_preamble(user_preamble_path) if user_preamble_path else "" + rag_preamble = load_preamble(rag_preamble_path) if rag_preamble_path else "" + thread_preamble = load_preamble(thread_preamble_path) if thread_preamble_path else "" + return system_preamble, user_preamble, rag_preamble, thread_preamble + + def build_prompt( + self, + instruction: EmailInstruction, + ) -> str: + + if self.number_rag_emails and not self.rag_preamble: + raise ValueError("RAG preamble format must be provided if RAG is used.") + + if self.number_thread_emails and not self.thread_preamble: + raise ValueError("Thread preamble format must be provided if thread is used.") + + if self.number_rag_emails > 0: + relevant_emails = self.retriever.retrieve( + instruction.instruction, self.number_rag_emails, self.rag_relevance_threshold + ) + rag_prompt = self._create_rag_preamble_from_emails(relevant_emails).strip() + else: + rag_prompt = "" + + if self.number_thread_emails > 0: + thread_prompt = self._create_threading_preamble( + instruction.thread[: self.number_thread_emails] + ).strip() + else: + thread_prompt = "" + + system_preamble = self.system_preamble.strip() + user_preamble = self.user_preamble.strip() + + prompt = "" + if system_preamble: + prompt += f"{system_preamble}\n\n" + if user_preamble: + prompt += f"{user_preamble}\n\n" + if rag_prompt: + prompt += f"{rag_prompt}\n\n" + if thread_prompt: + prompt += f"{thread_prompt}\n\n" + prompt += f"Instruction: {instruction.instruction}" + + return 
prompt diff --git a/src/panza/prompting/summarization_prompting.py b/src/panza/prompting/summarization_prompting.py new file mode 100644 index 0000000..e1e28fb --- /dev/null +++ b/src/panza/prompting/summarization_prompting.py @@ -0,0 +1,19 @@ +from ..entities import SummarizationInstruction +from .base import PromptBuilder + + +class SummarizationPromptBuilder(PromptBuilder): + def __init__( + self, + summarization_prompt: str, + ): + self.summarization_prompt = summarization_prompt + + def build_prompt( + self, + instruction: SummarizationInstruction, + ) -> str: + + prompt = self.summarization_prompt.format(email=instruction.instruction).strip() + + return prompt diff --git a/src/panza/prompting/utils.py b/src/panza/prompting/utils.py new file mode 100644 index 0000000..ec3daee --- /dev/null +++ b/src/panza/prompting/utils.py @@ -0,0 +1,19 @@ +def load_preamble(path: str) -> str: + with open(path, "r") as file: + return file.read().strip() + + +def load_user_preamble(path: str) -> str: + # The user preamble must be edited by the user in order to work as intended. + # Here, we perform additional checks to make sure that that happened; if not, + # We issue a warning to the user. + with open(path, "r") as file: + lines = [l for l in file.readlines() if not l.strip().startswith("#")] + preamble = "".join(lines) + if "CHANGE ME" in preamble: + print( + "*" * 66 + + "\n* WARNING: User prompt preamble not customized. *\n* Please edit the preamble at prompt_preambles/user_preamble.txt *\n" + + "*" * 66 + ) + return preamble diff --git a/src/panza/retriever/__init__.py b/src/panza/retriever/__init__.py new file mode 100644 index 0000000..780a4ca --- /dev/null +++ b/src/panza/retriever/__init__.py @@ -0,0 +1,5 @@ +from .base import DocumentRetriever +from .faiss import FaissRetriever +from .none import NoneRetriever + +__all__ = ["DocumentRetriever", "FaissRetriever", "NoneRetriever"] diff --git a/src/panza/retriever/base.py b/src/panza/retriever/base.py new file mode 100644 index 0000000..d0e3e37 --- /dev/null +++ b/src/panza/retriever/base.py @@ -0,0 +1,27 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple + +from ..entities.document import Document + + +class DocumentRetriever(ABC): + @abstractmethod + def retrieve(self, query: str, k: int, score: Optional[float] = None) -> List[Document]: + pass + + @abstractmethod + def retrieve_with_score( + self, query: str, k: int, score: Optional[float] = None + ) -> List[Tuple[Document, float]]: + pass + + @abstractmethod + def store(self, documents: List[Document]): + pass + + @abstractmethod + def save_db_to_disk(self): + pass + + def set_document_class(self, document_class: type[Document]): + self.document_class = document_class diff --git a/src/panza/retriever/faiss.py b/src/panza/retriever/faiss.py new file mode 100644 index 0000000..898851c --- /dev/null +++ b/src/panza/retriever/faiss.py @@ -0,0 +1,97 @@ +import logging +from typing import List, Optional, Tuple + +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.vectorstores import FAISS +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VectorStore + +from ..entities.document import Document +from .base import DocumentRetriever + +LOGGER = logging.getLogger(__name__) + + +class FaissRetriever(DocumentRetriever): + def __init__( + self, + db_path: str, + index_name: str, + embedding_model: str, + device: str, + document_class: Optional[type[Document]] = None, + ) -> None: + + 
self.db_path = db_path + self.index_name = index_name + self.model_name = embedding_model + self.device = device + self.document_class = document_class + + self.embedding_model = self._get_embeddings_model(self.model_name, self.device) + self.db = self._load_vector_db_from_disk( + self.db_path, self.index_name, self.embedding_model + ) + + def _get_embeddings_model(self, model_name: str, device: str) -> Embeddings: + embeddings_model = HuggingFaceEmbeddings( + model_name=model_name, + model_kwargs={"device": device}, + encode_kwargs={"normalize_embeddings": False}, + ) + return embeddings_model + + def _load_vector_db_from_disk( + self, db_path: str, index_name: str, embeddings_model: Embeddings + ) -> VectorStore: + try: + db = FAISS.load_local( + folder_path=db_path, + embeddings=embeddings_model, + index_name=index_name, + allow_dangerous_deserialization=True, # Allows pickle deserialization + ) + LOGGER.info(f"Loaded Faiss index {index_name} from {db_path}.") + return db + except Exception as e: + LOGGER.error(f"Failed to load Faiss index {index_name} from {db_path}. Error: {e}") + + def retrieve(self, query: str, k: int, score: Optional[float] = None) -> List[Document]: + results = self.retrieve_with_score(query, k, score) + results = [r[0] for r in results] + return results + + def retrieve_with_score( + self, query: str, k: int, score: Optional[float] = None + ) -> List[Tuple[Document, float]]: + + results = self.db._similarity_search_with_relevance_scores(query, k=k) + + # Filter by score + if score is not None: + results = [r for r in results if r[1] >= score] + + # Deserialize metadata + results = [ + (self.document_class.deserialize(r[0].metadata["serialized_email"]), r[1]) + for r in results + ] + + return results + + def store(self, documents: List[Document], chunk_size: int, chunk_overlap: int): + documents = self.document_class.process( + documents=documents, chunk_size=chunk_size, chunk_overlap=chunk_overlap + ) + db = FAISS.from_documents(documents, self.embedding_model) + + if self.db: + self.db.merge_from(db) + else: + LOGGER.info(f"Creating new Faiss index {self.index_name} in {self.db_path}.") + self.db = db + + def save_db_to_disk(self): + # Save vector DB to disk + self.db.save_local(folder_path=self.db_path, index_name=self.index_name) + logging.info(f"Vector DB index {self.index_name} saved to {self.db_path}.") diff --git a/src/panza/retriever/none.py b/src/panza/retriever/none.py new file mode 100644 index 0000000..f310da0 --- /dev/null +++ b/src/panza/retriever/none.py @@ -0,0 +1,29 @@ +import logging +from typing import List, Optional, Tuple + +from ..entities.document import Document +from .base import DocumentRetriever + +LOGGER = logging.getLogger(__name__) + + +class NoneRetriever(DocumentRetriever): + def __init__( + self, + document_class: Optional[type[Document]] = None, + ) -> None: + self.document_class = document_class + + def retrieve(self, query: str, k: int, score: Optional[float] = None) -> List[Document]: + return [] + + def retrieve_with_score( + self, query: str, k: int, score: Optional[float] = None + ) -> List[Tuple[Document, float]]: + return [] + + def store(self, documents: List[Document], chunk_size: int, chunk_overlap: int): + pass + + def save_db_to_disk(self): + pass diff --git a/src/panza/utils/prompting.py b/src/panza/utils/prompting.py index 6096b79..ed3ec9f 100644 --- a/src/panza/utils/prompting.py +++ b/src/panza/utils/prompting.py @@ -36,9 +36,7 @@ def create_prompt( if thread_emails: assert thread_preamble, "Thread preamble 
format must be provided if thread is provided." - thread_prompt = _create_threading_preamble( - thread_preamble, thread_emails - ).strip() + thread_prompt = _create_threading_preamble(thread_preamble, thread_emails).strip() else: thread_prompt = "" @@ -85,14 +83,15 @@ def _create_rag_context_from_emails(emails: List[Email]) -> Text: rag_context = "" for email in emails: - rag_context += f"SUBJECT: {email.subject}\n" f"E-MAIL CONTENT:\n{email.email}\n\n---\n\n" + rag_context += ( + # f"SUBJECT: {email.metadata['subject']}\n" # TODO(armand): Handle subject metadata + f"E-MAIL CONTENT:\n{email.page_content}\n\n---\n\n" + ) return rag_context -def _create_threading_preamble( - threading_preamble_format: Text, thread: List[Text] -) -> Text: +def _create_threading_preamble(threading_preamble_format: Text, thread: List[Text]) -> Text: threading_context = _create_threading_context(thread) return threading_preamble_format.format(threading_context=threading_context) diff --git a/src/panza/writer.py b/src/panza/writer.py new file mode 100644 index 0000000..312e4a3 --- /dev/null +++ b/src/panza/writer.py @@ -0,0 +1,43 @@ +from typing import Iterator, List, Tuple + +from .entities import Instruction +from .llm import LLM, MessageType +from .prompting import PromptBuilder + + +class PanzaWriter: + def __init__(self, prompt_builder: PromptBuilder, llm: LLM): + self.prompt_builder = prompt_builder + self.llm = llm + + def run( + self, instruction: Instruction, stream: bool = False, return_prompt: bool = False + ) -> str | Iterator[str] | Tuple[str, str] | Tuple[Iterator[str], str]: + prompt = self.prompt_builder.build_prompt(instruction) + messages = self._create_user_message(content=prompt) + + if stream: + response = self.llm.chat_stream(messages) + else: + response = self.llm.chat(messages)[0] + + if return_prompt: + return response, prompt + else: + return response + + def run_batch( + self, instructions: List[Instruction], return_prompt: bool = False + ) -> List[str] | Tuple[List[str], List[str]]: + prompts = [self.prompt_builder.build_prompt(instruction) for instruction in instructions] + messages = [self._create_user_message(content=prompt) for prompt in prompts] + + response = self.llm.chat(messages) + + if return_prompt: + return response, prompts + else: + return response + + def _create_user_message(self, content: str) -> MessageType: + return [{"role": "user", "content": content}] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1531ba2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,91 @@ +from datetime import datetime +from pathlib import Path + +import pytest + +from panza.entities import Email +from panza.retriever import FaissRetriever + + +@pytest.fixture +def embedding_model() -> str: + return "sentence-transformers/all-mpnet-base-v2" + + +@pytest.fixture +def generative_model() -> str: + return "microsoft/Phi-3-mini-4k-instruct" + + +@pytest.fixture +def peft_model() -> str: + return "microsoft/Phi-3-mini-4k-instruct" + + +@pytest.fixture +def index_name() -> str: + return "test-index" + + +@pytest.fixture(scope="function") +def faiss_db_path(tmp_path: Path, index_name: str, embedding_model: str) -> Path: + # Create a new temporary directory for each test + base_temp_dir = tmp_path / "data" + base_temp_dir.mkdir() # Ensure the data directory is created + + # Define the mock emails + emails = [ + Email(email=f"email{i}", subject=f"subject{i}", thread=[f"thread{i}"], date=datetime.now()) + for i in range(3) + ] + + # Initialize the FaissRetriever + 
retriever = FaissRetriever( + db_path=base_temp_dir, + index_name=index_name, + embedding_model=embedding_model, + device="cpu", + document_class=Email, + ) + + # Store the mock emails in the vector database + retriever.store(emails, chunk_size=1000, chunk_overlap=1000) + retriever.save_db_to_disk() + + # Return the path to the directory containing all mock data + return base_temp_dir + + +@pytest.fixture +def preambles_path(tmp_path: Path) -> Path: + preambles_path = tmp_path / "prompt_preambles" + preambles_path.mkdir(parents=True) + return preambles_path + + +@pytest.fixture +def system_preamble_path(preambles_path) -> Path: + system_preamble_path = preambles_path / "system_preamble.txt" + system_preamble_path.write_text("") + return system_preamble_path + + +@pytest.fixture +def user_preamble_path(preambles_path) -> Path: + user_preamble_path = preambles_path / "user_preamble.txt" + user_preamble_path.write_text("") + return user_preamble_path + + +@pytest.fixture +def rag_preamble_path(preambles_path) -> Path: + rag_preamble_path = preambles_path / "rag_preamble.txt" + rag_preamble_path.write_text("RAG PREAMBLE:\n\n{rag_context}") + return rag_preamble_path + + +@pytest.fixture +def thread_preamble_path(preambles_path) -> Path: + thread_preamble_path = preambles_path / "thread_preamble.txt" + thread_preamble_path.write_text("THREAD PREAMBLE:\n\n{threading_context}") + return thread_preamble_path diff --git a/tests/test_entities.py b/tests/test_entities.py new file mode 100644 index 0000000..18935c1 --- /dev/null +++ b/tests/test_entities.py @@ -0,0 +1,42 @@ +import json +from datetime import datetime + +import pytest + +from panza.entities import Email, EmailInstruction + + +def test_email_serialization_deserialization(): + email = Email(email="email", subject="subject", thread=["thread"], date=datetime.now()) + serialized = json.dumps(email.serialize()) + deserialized = Email.deserialize(serialized) + assert email == deserialized + + +def test_email_processing(): + email = Email(email="email", subject="subject", thread=["thread"], date=datetime.now()) + processed = Email.process([email], chunk_size=1000, chunk_overlap=1000) + assert processed[0].page_content == email.email + + deserialized = Email.deserialize(processed[0].metadata["serialized_document"]) + assert email == deserialized + + +def test_email_instruction_init(): + instruction = EmailInstruction(instruction="Write an email.") + assert instruction.instruction == "Write an email." + assert instruction.thread == [] + assert instruction.past_messages == [] + + instruction = EmailInstruction( + instruction="Write an email.", + thread=["thread"], + past_messages=[{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hi!"}], + ) + + assert instruction.instruction == "Write an email." 
+ assert instruction.thread == ["thread"] + assert instruction.past_messages == [ + {"role": "user", "content": "Hi!"}, + {"role": "assistant", "content": "Hi!"}, + ] diff --git a/tests/test_local_llm.py b/tests/test_local_llm.py new file mode 100644 index 0000000..3d21b16 --- /dev/null +++ b/tests/test_local_llm.py @@ -0,0 +1,116 @@ +from typing import Type + +import pytest +from torch import float32 as torch_float32 + +from panza.llm import PeftLLM, TransformersLLM +from panza.llm.local import _MISSING_LIBRARIES, LocalLLM + +skip_if_no_transformers = pytest.mark.skipif( + "transformers" in _MISSING_LIBRARIES, reason="transformers is not installed" +) +skip_if_no_peft = pytest.mark.skipif("peft" in _MISSING_LIBRARIES, reason="peft is not installed") +skip_if_no_bitsandbytes = pytest.mark.skipif( + "bitsandbytes" in _MISSING_LIBRARIES, reason="bitsandbytes is not installed" +) + + +@pytest.mark.parametrize( + "local_llm_class, checkpoint", + [ + pytest.param( + TransformersLLM, "microsoft/Phi-3-mini-4k-instruct", marks=skip_if_no_transformers + ), + # TODO: Replace local Peft checkpoint with fixture + pytest.param( + PeftLLM, + "/nfs/scistore19/alistgrp/Checkpoints/Panza/shared/armand/models/test_rosa_checkpoint", + marks=[skip_if_no_transformers, skip_if_no_peft], + ), + ], +) +def test_local_llm_init(local_llm_class: Type[LocalLLM], checkpoint: str): + model = local_llm_class( + name="local_llm", + checkpoint=checkpoint, + device="cpu", + sampling_parameters={"do_sample": False, "max_new_tokens": 50}, + dtype="fp32", + load_in_4bit=False, + ) + assert model is not None + assert model.name == "local_llm" + assert model.checkpoint == checkpoint + assert model.model is not None + assert model.tokenizer is not None + assert model.model.device.type == "cpu" + assert model.dtype == torch_float32 + assert model.model.dtype == model.dtype + + +@pytest.mark.parametrize( + "local_llm_class, checkpoint", + [ + pytest.param( + TransformersLLM, "microsoft/Phi-3-mini-4k-instruct", marks=skip_if_no_transformers + ), + # TODO: Replace local Peft checkpoint with fixture + pytest.param( + PeftLLM, + "/nfs/scistore19/alistgrp/Checkpoints/Panza/shared/armand/models/test_rosa_checkpoint", + marks=[skip_if_no_transformers, skip_if_no_peft], + ), + ], +) +def test_local_llm_generate(local_llm_class: Type[LocalLLM], checkpoint: str): + model = local_llm_class( + name="local_llm", + checkpoint=checkpoint, + device="cpu", + sampling_parameters={"do_sample": False, "max_new_tokens": 50}, + dtype="fp32", + load_in_4bit=False, + ) + + messages = [{"role": "user", "content": "Write something."}] + + outputs = model.chat(messages) + + assert outputs is not None + assert len(outputs) == 1 + + +@pytest.mark.parametrize( + "local_llm_class, checkpoint", + [ + pytest.param( + TransformersLLM, "microsoft/Phi-3-mini-4k-instruct", marks=skip_if_no_transformers + ), + # TODO: Replace local Peft checkpoint with fixture + pytest.param( + PeftLLM, + "/nfs/scistore19/alistgrp/Checkpoints/Panza/shared/armand/models/test_rosa_checkpoint", + marks=[skip_if_no_transformers, skip_if_no_peft], + ), + ], +) +def test_local_llm_generate_batch(local_llm_class: Type[LocalLLM], checkpoint: str): + model = local_llm_class( + name="local_llm", + checkpoint=checkpoint, + device="cpu", + sampling_parameters={"do_sample": False, "max_new_tokens": 50}, + dtype="fp32", + load_in_4bit=False, + ) + + messages = [ + [{"role": "user", "content": "Write something."}], + [{"role": "user", "content": "Write something else."}], + [{"role": "user", 
"content": "Write something different."}], + ] + + outputs = model.chat(messages) + + assert outputs is not None + assert len(outputs) == 3 diff --git a/tests/test_retriever.py b/tests/test_retriever.py new file mode 100644 index 0000000..af0ec74 --- /dev/null +++ b/tests/test_retriever.py @@ -0,0 +1,67 @@ +from datetime import datetime +from pathlib import Path + +import pytest + +from panza.entities import Email +from panza.retriever import FaissRetriever + + +def get_faiss_retriever( + db_path: Path, index_name: str, embedding_model: str, device: str +) -> FaissRetriever: + retriever = FaissRetriever( + db_path=db_path, + index_name=index_name, + embedding_model=embedding_model, + device=device, + ) + retriever.set_document_class(Email) + return retriever + + +def test_faiss_retriever_init_empty(tmp_path: Path, index_name: str, embedding_model: str): + retriever = get_faiss_retriever(tmp_path, index_name, embedding_model, "cpu") + assert retriever is not None + assert retriever.embedding_model is not None + assert retriever.db is None + + +def test_faiss_retriever_init_existing(faiss_db_path: Path, index_name: str, embedding_model: str): + retriever = get_faiss_retriever(faiss_db_path, index_name, embedding_model, "cpu") + assert retriever is not None + assert retriever.embedding_model is not None + assert retriever.db is not None + + +def test_faiss_retriever_store_over_empty(tmp_path: Path, index_name: str, embedding_model: str): + retriever = get_faiss_retriever(tmp_path, index_name, embedding_model, "cpu") + + emails = [ + Email(email=f"email{i}", subject=f"subject{i}", thread=[f"thread{i}"], date=datetime.now()) + for i in range(3) + ] + + retriever.store(emails, chunk_size=1000, chunk_overlap=1000) + assert retriever.db is not None + + +def test_faiss_retriever_store_over_existing( + faiss_db_path: Path, index_name: str, embedding_model: str +): + retriever = get_faiss_retriever(faiss_db_path, index_name, embedding_model, "cpu") + assert retriever.db is not None + + number_existing_documents = len(retriever.db.index_to_docstore_id) + assert number_existing_documents != 0 + + number_new_documents = 3 + emails = [ + Email(email=f"email{i}", subject=f"subject{i}", thread=[f"thread{i}"], date=datetime.now()) + for i in range(number_new_documents) + ] + + retriever.store(emails, chunk_size=1000, chunk_overlap=1000) + + number_total_documents = len(retriever.db.index_to_docstore_id) + assert number_total_documents == number_existing_documents + number_new_documents diff --git a/tests/test_writer.py b/tests/test_writer.py new file mode 100644 index 0000000..e8e6100 --- /dev/null +++ b/tests/test_writer.py @@ -0,0 +1,31 @@ +from unittest.mock import MagicMock + +import pytest + +from panza.entities import EmailInstruction +from panza.llm import LLM +from panza.prompting import EmailPromptBuilder +from panza.writer import PanzaWriter + + +def test_email_writer(): + # Create mock prompt builder + mock_builder = MagicMock(spec=EmailPromptBuilder) + mock_builder.build_prompt.side_effect = ( + lambda instruction: f"Instruction: {instruction.instruction}" + ) + + # Create mock LLM + mock_llm = MagicMock(spec=LLM) + mock_llm.chat.side_effect = lambda messages: [f"Received: {messages[0]['content']}"] + + panza_writer = PanzaWriter(mock_builder, mock_llm) + + instruction = EmailInstruction(instruction="Write an email.") + + output = panza_writer.run(instruction) + assert output == "Received: Instruction: Write an email." 
+ + output, prompt = panza_writer.run(instruction, return_prompt=True) + assert output == "Received: Instruction: Write an email." + assert prompt == "Instruction: Write an email."
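
For reference, the snippet below is a minimal sketch of how the new modules introduced in this diff fit together: a `FaissRetriever` feeds an `EmailPromptBuilder`, which together with a `TransformersLLM` backs a `PanzaWriter`. In the repository this wiring is normally done through the Hydra configs under `configs/` and exposed via the classes in `src/panza/interface/`; the index location, preamble paths, checkpoint name, RAG settings, and sampling parameters below are illustrative assumptions only, and the FAISS index is assumed to have been produced beforehand by the data preparation step.

``` python
# Minimal wiring sketch. Assumptions: the FAISS index under data/ already
# exists, and all paths/checkpoint names below are placeholders, not values
# taken from the repository configs.
from panza.entities import Email, EmailInstruction
from panza.llm import TransformersLLM
from panza.prompting import EmailPromptBuilder
from panza.retriever import FaissRetriever
from panza.writer import PanzaWriter

# Retriever over the e-mail vector store created during data preparation.
retriever = FaissRetriever(
    db_path="data",                  # assumed index location
    index_name="USERNAME",           # assumed index name
    embedding_model="sentence-transformers/all-mpnet-base-v2",
    device="cpu",
    document_class=Email,
)

# Load the four preambles that control the prompt layout.
system, user, rag, thread = EmailPromptBuilder.load_all_preambles(
    "prompt_preambles/system_preamble.txt",
    "prompt_preambles/user_preamble.txt",
    "prompt_preambles/rag_preamble.txt",
    "prompt_preambles/thread_preamble.txt",
)

prompt_builder = EmailPromptBuilder(
    retriever=retriever,
    system_preamble=system,
    user_preamble=user,
    rag_preamble=rag,
    thread_preamble=thread,
    number_rag_emails=3,             # assumed RAG settings
    rag_relevance_threshold=0.2,
    number_thread_emails=0,
)

# Any HF causal LM (or a Panza-finetuned checkpoint) can back the writer.
llm = TransformersLLM(
    name="panza-writer",
    checkpoint="microsoft/Phi-3-mini-4k-instruct",
    device="cpu",
    sampling_parameters={"do_sample": False, "max_new_tokens": 128},
    dtype="fp32",
    load_in_4bit=False,
)

writer = PanzaWriter(prompt_builder=prompt_builder, llm=llm)
print(writer.run(EmailInstruction("Write an email confirming the Friday meeting.")))
```

The same `PanzaWriter` object is what `PanzaCLI`, `PanzaGUI`, and `PanzaJSON` receive as their `writer` argument, so the interfaces stay agnostic of the concrete retriever, prompt builder, and LLM behind it.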