Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions src/deep_impact/models/original.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
import torch
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel
from nltk.stem import PorterStemmer

from src.utils.checkpoint import ModelCheckpoint


class DeepImpact(BertPreTrainedModel):
max_length = 512
tokenizer = tokenizers.Tokenizer.from_pretrained('bert-base-uncased')
tokenizer = tokenizers.Tokenizer.from_pretrained('mixedbread-ai/mxbai-embed-large-v1')
tokenizer.enable_truncation(max_length)
punctuation = set(string.punctuation)
stemmer = PorterStemmer()

stemmer_cache = {}

def __init__(self, config):
super(DeepImpact, self).__init__(config)
Expand Down Expand Up @@ -97,15 +101,20 @@ def get_query_document_token_mask(cls, query_terms: Set[str], term_to_token_inde
mask = np.zeros(max_length, dtype=bool)
token_indices_of_matching_terms = [v for k, v in term_to_token_index.items() if k in query_terms]
mask[token_indices_of_matching_terms] = True

return torch.from_numpy(mask)

@classmethod
def process_query(cls, query: str) -> Set[str]:
    """Normalize, pre-tokenize, filter and stem a query into a set of terms.

    Pipeline:
      1. Apply the tokenizer's normalizer to the raw query string.
      2. Pre-tokenize into surface terms (``pre_tokenize_str`` returns
         ``(term, offsets)`` pairs; only the term is kept).
      3. Drop pure-punctuation terms.
      4. Stem each remaining term with the class-level Porter stemmer,
         memoizing results in ``cls.stemmer_cache`` since stemming the
         same term repeatedly is pure overhead.

    :param query: Raw query text.
    :return: Set of unique stemmed query terms.
    """
    normalized = cls.tokenizer.normalizer.normalize_str(query)
    terms = (pair[0] for pair in cls.tokenizer.pre_tokenizer.pre_tokenize_str(normalized))

    stemmed_terms = set()
    for term in terms:
        if term in cls.punctuation:
            continue
        # Only invoke the stemmer on a cache miss; setdefault() would
        # defeat the cache by stemming eagerly on every lookup.
        cached = cls.stemmer_cache.get(term)
        if cached is None:
            cached = cls.stemmer.stem(term)
            cls.stemmer_cache[term] = cached
        stemmed_terms.add(cached)
    return stemmed_terms

@classmethod
def process_document(cls, document: str) -> Tuple[tokenizers.Encoding, Dict[str, int]]:
"""
Expand All @@ -116,7 +125,7 @@ def process_document(cls, document: str) -> Tuple[tokenizers.Encoding, Dict[str,

document = cls.tokenizer.normalizer.normalize_str(document)
document_terms = [x[0] for x in cls.tokenizer.pre_tokenizer.pre_tokenize_str(document)]

encoded = cls.tokenizer.encode(document_terms, is_pretokenized=True)

term_index_to_token_index = {}
Expand All @@ -135,12 +144,16 @@ def process_document(cls, document: str) -> Tuple[tokenizers.Encoding, Dict[str,
if term not in filtered_term_to_token_index \
and term not in cls.punctuation \
and i in term_index_to_token_index:
# check if the stem for this term is already cached
if term not in cls.stemmer_cache:
cls.stemmer_cache[term] = cls.stemmer.stem(term)
term = cls.stemmer_cache[term]
filtered_term_to_token_index[term] = term_index_to_token_index[i]
return encoded, filtered_term_to_token_index

@classmethod
def load(cls, checkpoint_path: Optional[Union[str, Path]] = None):
model = cls.from_pretrained('Luyu/co-condenser-marco')
model = cls.from_pretrained('mixedbread-ai/mxbai-embed-large-v1')
if checkpoint_path is not None:
if os.path.exists(checkpoint_path):
ModelCheckpoint.load(model=model, last_checkpoint_path=checkpoint_path)
Expand Down
10 changes: 8 additions & 2 deletions src/deep_impact/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from src.utils.datasets import MSMarcoTriples, DistillationScores
from src.deep_impact.evaluation.nano_beir_evaluator import NanoBEIREvaluator

import wandb

def collate_fn(batch, model_cls=DeepImpact, max_length=None):
encoded_list, masks = [], []
Expand Down Expand Up @@ -100,8 +101,12 @@ def run(
in_batch_negatives: bool = False,
start_with: Union[str, Path] = None,
qrels_path: Union[str, Path] = None,
eval_every: int = 500
eval_every: int = 500,
use_wandb: bool = False,
):
if use_wandb:
wandb.init(project="deep-impact")

# DeepImpact
model_cls = DeepImpact
trainer_cls = Trainer
Expand Down Expand Up @@ -172,6 +177,7 @@ def run(
gradient_accumulation_steps=gradient_accumulation_steps,
evaluator=evaluator,
eval_every=eval_every,
use_wandb=use_wandb,
)
trainer.train()
trainer_cls.ddp_cleanup()
Expand Down Expand Up @@ -199,7 +205,7 @@ def run(
parser.add_argument("--in_batch_negatives", action="store_true", help="Use in-batch negatives")
parser.add_argument("--start_with", type=Path, default=None, help="Start training with this checkpoint")
parser.add_argument("--eval_every", type=int, default=500, help="Evaluate every n steps")

parser.add_argument("--use_wandb", action="store_true", help="Use wandb")

# required for distillation loss with Margin MSE
parser.add_argument("--qrels_path", type=Path, default=None, help="Path to the qrels file")
Expand Down
20 changes: 17 additions & 3 deletions src/deep_impact/training/trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from pathlib import Path
from typing import Union
import wandb

import torch
import torch.distributed
Expand Down Expand Up @@ -30,6 +31,7 @@ def __init__(
gradient_accumulation_steps: int = 1,
eval_every: int = 500,
evaluator: BaseEvaluator = None,
use_wandb: bool = False,
) -> None:
self.seed = seed
self.gpu_id = torch.distributed.get_rank()
Expand All @@ -41,6 +43,7 @@ def __init__(
self.gradient_accumulation_steps = gradient_accumulation_steps
self.eval_every = eval_every
self.evaluator = evaluator
self.use_wandb = use_wandb

model_name = self.model.__class__.__name__
last_checkpoint_path = (checkpoint_dir /
Expand Down Expand Up @@ -70,6 +73,8 @@ def __init__(
self.checkpoint_dir = checkpoint_dir
self.checkpoint_callback.batch_size = self.batch_size * self.n_ranks
self.model = DDP(self.model, device_ids=[self.gpu_id], find_unused_parameters=True)
if self.use_wandb:
wandb.watch(self.model, log="all")
self.criterion = torch.nn.CrossEntropyLoss()

def train(self):
Expand All @@ -92,14 +97,18 @@ def train(self):

for i, batch in enumerate(self.train_data):
with torch.cuda.amp.autocast():
outputs = self.get_output_scores(batch)
outputs = self.get_output_scores(batch, i)
loss = self.evaluate_loss(outputs, batch)

loss /= self.gradient_accumulation_steps

scaler.scale(loss).backward()
current_loss = loss.detach().cpu().item()
train_loss += current_loss

if self.use_wandb:
wandb.log({"train_loss": current_loss}, step=i)
wandb.log({"avg_score": outputs.mean().item(), "min_score": outputs.min().item(), "max_score": outputs.max().item()}, step=i)

if i % self.gradient_accumulation_steps == 0:
scaler.unscale_(self.optimizer)
Expand All @@ -109,10 +118,13 @@ def train(self):
self.optimizer.zero_grad()

if self.gpu_id == 0:
if i % self.eval_every == 0 and self.evaluator is not None:
if i % self.eval_every == 0 and i > 0 and self.evaluator is not None:
self.logger.info(f"Evaluating NanoBEIR at iteration {i}")
metrics = self.evaluator.evaluate_all(self.model.module)
self.logger.info(f"Metrics: {metrics}")
if self.use_wandb:
for dataset, dataset_metrics in metrics.items():
wandb.log({f"{dataset}_ndcg@10": dataset_metrics[0]["NDCG@10"]}, step=i)
# write metrics to file as a single line, including the iteration number
with open(self.checkpoint_dir / "metrics.txt", "a") as f:
f.write(json.dumps({"iteration": i, "metrics": metrics}) + "\n")
Expand All @@ -134,11 +146,13 @@ def get_input_tensors(self, encoded_list):
type_ids = torch.tensor([x.type_ids for x in encoded_list], dtype=torch.long).to(self.gpu_id)
return input_ids, attention_mask, type_ids

def get_output_scores(self, batch, step=None):
    """Compute per-document impact scores for a batch.

    Runs the model over the encoded documents, then sums the term scores
    at the token positions selected by each query/document mask.

    :param batch: Dict with ``'encoded_list'`` (tokenizer encodings) and
        ``'masks'`` (bool/float tensor marking query-matching tokens).
    :param step: Optional global step used for wandb logging; defaults to
        ``None`` so pre-existing callers that pass only ``batch`` keep
        working (wandb.log accepts step=None).
    :return: Tensor of shape (batch_size, n_docs_per_query) of summed
        masked term scores.
    """
    input_ids, attention_mask, type_ids = self.get_input_tensors(batch['encoded_list'])
    document_term_scores = self.model(input_ids, attention_mask, type_ids)

    masks = batch['masks'].to(self.gpu_id)
    if self.use_wandb:
        # Fraction of token positions that matched a query term — a cheap
        # sanity signal that masks are not degenerate (all zero / all one).
        wandb.log({"avg_matches": (masks != 0).float().mean().item()}, step=step)
    return (masks * document_term_scores).sum(dim=1).squeeze(-1).view(self.batch_size, -1)

def evaluate_loss(self, outputs, batch):
Expand Down