Skip to content

Commit 5243e02

Browse files
Authored: Merge pull request #157 from ChEB-AI/fix/load-electra-checkpoints
fix checkpoint loading for electra, return attentions
2 parents bcf3839 + 3fb8442 commit 5243e02

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

chebai/models/electra.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any
203203
)
204204
* CLS_TOKEN
205205
)
206+
model_kwargs["output_attentions"] = True
206207
return dict(
207208
features=torch.cat((cls_tokens, batch.x), dim=1),
208209
labels=batch.y,

chebai/result/prediction.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,18 @@ def __init__(
5858
self._model_hparams = ckpt_file["hyper_parameters"]
5959
self._model_hparams.pop("_instantiator", None)
6060
self._model_hparams.pop("classes_txt_file_path", None)
61-
self._model = ChebaiBaseNet.load_from_checkpoint(
62-
checkpoint_path, map_location=self.device
63-
)
61+
try:
62+
self._model = ChebaiBaseNet.load_from_checkpoint(
63+
checkpoint_path,
64+
map_location=self.device,
65+
)
66+
except Exception:
67+
# models trained on a pretrained checkpoint have an additional path argument that we need to set to None
68+
self._model = ChebaiBaseNet.load_from_checkpoint(
69+
checkpoint_path,
70+
map_location=self.device,
71+
pretrained_checkpoint=None,
72+
)
6473
assert (
6574
isinstance(self._model, ChebaiBaseNet)
6675
and type(self._model) is not ChebaiBaseNet

0 commit comments

Comments (0)