
Adding support for the granite multilingual embeddings R2 (ibm-granite/granite-embedding-{97,311}m-multilingual-r2 models) #22716

Open
hansolosan wants to merge 4 commits into ggml-org:master from primeqa:granite-embedding-r2-pr

Conversation

@hansolosan

  • modern-bert: support SwiGLU FFN for Granite Embedding R2
  • Update: Add support for "granite-embed-r2" in hash matching, vocab pre-types, and tokenizer configurations

Overview

This PR adds support for two just-released Granite multilingual embedding models, both based on the ModernBERT architecture. Support is added to link the tokenizers properly and, for the 97m model, to use a different FFN activation function (SiLU/SwiGLU) instead of the regular GeGLU.
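
For context, GeGLU and SwiGLU are the same gated-FFN construction with a different activation on the gate path; a minimal scalar sketch (using the tanh GELU approximation for brevity, not llama.cpp's actual kernels):

#include <cmath>

// both variants compute act(gate) * up elementwise before the down projection
static float silu(float x) { // SiLU: x * sigmoid(x)
    return x / (1.0f + std::exp(-x));
}

static float gelu(float x) { // tanh-approximated GELU
    return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

static float ffn_swiglu(float gate, float up) { return silu(gate) * up; } // Granite Embedding R2 (97m)
static float ffn_geglu (float gate, float up) { return gelu(gate) * up; } // stock ModernBERT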

Additional information

The models are available here: https://huggingface.co/ibm-granite/granite-embedding-97m-multilingual-r2 and https://huggingface.co/ibm-granite/granite-embedding-311m-multilingual-r2. On retrieval scores, the 97m model is 8 points ahead of the next-best model on the MMTEB leaderboard in the under-100M-parameter category, and the 311m model ranks second in the under-500M-parameter category.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: Yes - I used Claude Code to identify where to make the changes, then examined the code, made sure it fit the policy, and verified that it works as expected.

I am not an AI agent :).

ModernBert derivatives such as IBM Granite Embedding multilingual R2
(97m / 311m) use SiLU/SwiGLU in the FFN instead of the original
GELU/GeGLU. Persist the hidden activation in the GGUF and select
LLM_FFN_SWIGLU vs LLM_FFN_GEGLU at graph build time. Also register
the Granite R2 tokenizers so the converter recognizes them as
modern-bert.
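
A self-contained sketch of that selection (the enum and struct here are stand-ins for the real declarations in llama-graph.h and llama-hparams.h):

// sketch only -- in llama.cpp the real types live in llama-graph.h / llama-hparams.h
enum llm_ffn_op_type { LLM_FFN_GEGLU, LLM_FFN_SWIGLU };

struct sketch_hparams {
    bool ffn_act_swiglu = false; // read back from the GGUF at model load
};

// at graph build time, choose the gated-FFN op from the persisted hparam
static llm_ffn_op_type pick_ffn_op(const sketch_hparams & hp) {
    return hp.ffn_act_swiglu ? LLM_FFN_SWIGLU : LLM_FFN_GEGLU;
}
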
@hansolosan requested review from CISC and ggerganov as code owners May 5, 2026 13:48
@hansolosan marked this pull request as draft May 5, 2026 13:49
@github-actions bot added the model (Model specific) and python (python script changes) labels May 5, 2026
@hansolosan force-pushed the granite-embedding-r2-pr branch from 34541a7 to 4f283cf May 5, 2026 15:16
@hansolosan marked this pull request as ready for review May 5, 2026 15:22
@hansolosan
Author

@gabe-l-hart here is the PR.

@gabe-l-hart self-requested a review May 5, 2026 15:35
@gabe-l-hart
Collaborator

Thanks @hansolosan! I'll take a first-pass review in the next day or two and notify maintainers once we're ready for final review.

Collaborator

@gabe-l-hart left a comment


I think it would be good to make the hparam more flexible for future models that need it.

Comment thread: src/llama-hparams.h (Outdated)
// FFN gated activation flavor (used by ModernBert/derivatives that may use
// SwiGLU instead of the default GeGLU). The graph for those archs reads
// this to pick LLM_FFN_SWIGLU vs LLM_FFN_GEGLU.
bool ffn_act_swiglu = false;
Collaborator


NIT: Most model-specific hparams live towards the bottom of the field declarations. This affects how structs are initialized, and while this repo doesn't ever use direct initialization for hparams, other tools that use this header (yes, that violates encapsulation, but it's the internet) might.
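
For illustration, the concern is positional aggregate initialization, which silently re-binds values when a field is inserted mid-struct (toy names, not the real hparams):

struct hp_v1 { int n_embd; int n_layer; };
struct hp_v2 { int n_embd; int n_ff; int n_layer; }; // new field added mid-struct

// the same positional initializer now assigns 12 to n_ff instead of
// n_layer, and n_layer is silently value-initialized to 0
hp_v1 a = { 768, 12 }; // n_layer == 12
hp_v2 b = { 768, 12 }; // n_ff == 12, n_layer == 0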

Collaborator


Less-NIT: In the GGUF this looks like it's represented as a string, but here it's a bool, which limits future usability. I think it would be cleaner to use llm_ffn_op_type (declared in llama-graph.h, so available here). This would also avoid the need for the ternary above.

If we go that route, we could also align the name as ffn_op. Further even, we could add a helper in llama-graph.* to do the enum <-> string mapping so it's centralized and reusable.
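
A sketch of the suggested shape (the helper name and the GGUF strings are hypothetical; llm_ffn_op_type itself is the existing enum):

#include <string>

enum llm_ffn_op_type { LLM_FFN_GEGLU, LLM_FFN_SWIGLU }; // stand-in for llama-graph.h

// hypothetical centralized mapping for the GGUF-persisted string, so the
// converter and the loader agree on one spelling
static llm_ffn_op_type llm_ffn_op_from_string(const std::string & s) {
    return s == "swiglu" ? LLM_FFN_SWIGLU : LLM_FFN_GEGLU; // default matches stock ModernBERT
}

struct sketch_hparams {
    llm_ffn_op_type ffn_op = LLM_FFN_GEGLU; // enum instead of a bool
};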

Author


Hi Gabe - it looks like llama-graph.h includes this file (llama-hparams.h), so we could move the llm_ffn_op_type enum to llama-arch.h and have it included from here.

…pt-4o tokenizer, changed that.

- the pretokenizer regex for gpt-4o had a bug (exposed in Arabic) - added the combining-marks class \p{M} in the lookup
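
To illustrate the class of bug (the fragments below are illustrative, not the exact gpt-4o pattern): Arabic diacritics are combining marks (Unicode category M), so a letters-only class splits them away from their base consonants unless \p{M} is included:

#include <string>

// illustrative word-matching fragments, not the actual gpt-4o regex
const std::string letters_only = "[\\p{L}]+";       // drops e.g. Arabic fatha/kasra from the base letter
const std::string with_marks   = "[\\p{L}\\p{M}]+"; // keeps combining marks attached
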
@gabe-l-hart
Collaborator

I've confirmed that the inference is working as intended. Here was my process:

Conversion

(cd ~/models && hf download ibm-granite/granite-embedding-97m-multilingual-r2 --local-dir ibm-granite/granite-embedding-97m-multilingual-r2)
python convert_hf_to_gguf.py ~/models/ibm-granite/granite-embedding-97m-multilingual-r2/

Baseline w/ Sentence Transformers

I used this script to compare llama.cpp's output against the results of running with sentence-transformers:

granite_embed.py
from sentence_transformers import SentenceTransformer
import numpy as np
import subprocess
import shlex
import sys

model_path = "/Users/ghart/models/ibm-granite/granite-embedding-97m-multilingual-r2"
lcpp_model = f"{model_path}/granite-embedding-97M-multilingual-r2-BF16.gguf"
lcpp_exe = "./build/bin/llama-embedding"

if len(sys.argv) > 1:
    model_path = sys.argv[1]
if len(sys.argv) > 2:
    lcpp_model = sys.argv[2]
if len(sys.argv) > 3:
    lcpp_exe = sys.argv[3]

model = SentenceTransformer(model_path)

input_queries = [
    "hello world",
    "tell me a story about a developer and their dog",
    "123sfg this is a r@nd0m t35t",
]


def cosine_similarity(vector_a: np.ndarray, vector_b: np.ndarray) -> float:
    # cosine similarity is scale-invariant, so raw (unnormalized) embeddings
    # from llama.cpp can be compared directly against the baseline
    vector_a = np.asarray(vector_a)
    vector_b = np.asarray(vector_b)
    numerator = np.dot(vector_a, vector_b)
    denominator_a = np.linalg.norm(vector_a)
    denominator_b = np.linalg.norm(vector_b)
    if denominator_a == 0 or denominator_b == 0:
        return 0.0
    return numerator / (denominator_a * denominator_b)


for query in input_queries:
    print("### BASELINE ###")
    embedding = model.encode([query])
    print("Embedding shape:", embedding.shape)
    print("Embedding vector:", embedding[:, :8])

    print("### llama.cpp ###")
    cmd = f"{lcpp_exe} -m {lcpp_model} -p \"{query}\" --temp 0 --embd-normalize -1"
    print(f"llama.cpp command: {cmd}")
    proc = subprocess.Popen(
        shlex.split(cmd),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out, _ = proc.communicate()
    # take everything after the last ':' in the output and parse the
    # whitespace-separated floats into the embedding vector
    vals = out.decode("utf-8").split(":")[-1]
    vals = [float(v) for v in vals.split() if v.strip()]
    lcpp_emb = np.array(vals)
    print("llama.cpp Embedding shape:", lcpp_emb.shape)
    print("llama.cpp Embedding vector:", lcpp_emb[:8])
    print()
    cos_sim = cosine_similarity(embedding, lcpp_emb)
    print(f"COSINE SIMILARITY: {cos_sim}")
    print("--------------------------------")
    print()

Results w/out branch

### BASELINE ###
Embedding shape: (1, 384)
Embedding vector: [[-6.8227053e-03  3.0905887e-02  2.8057294e-02  1.5359418e-02
  -3.0930677e-02 -8.4969969e-03 -2.0503732e-05 -9.3344050e-03]]
### llama.cpp ###
llama.cpp command: /Users/ghart/Projects/github/ggml-org/llama.cpp/build/bin/llama-embedding -m /Users/ghart/models/ibm-granite/granite-embedding-97m-multilingual-r2/granite-embedding-97M-multilingual-r2-BF16.gguf -p "hello world" --temp 0 --embd-normalize -1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [ 1.551675 -1.943748  1.574718  3.077239 -2.425059  1.19634   0.595976
  4.560798]

COSINE SIMILARITY: [0.86851926]
--------------------------------

### BASELINE ###
Embedding shape: (1, 384)
Embedding vector: [[-0.02954087 -0.00377955  0.02121077  0.00609067  0.02561346 -0.05363345
  -0.01472757  0.01316541]]
### llama.cpp ###
llama.cpp command: /Users/ghart/Projects/github/ggml-org/llama.cpp/build/bin/llama-embedding -m /Users/ghart/models/ibm-granite/granite-embedding-97m-multilingual-r2/granite-embedding-97M-multilingual-r2-BF16.gguf -p "tell me a story about a developer and their dog" --temp 0 --embd-normalize -1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [-1.42778   1.695506  3.480772  1.437287  2.130465 -7.187495 -0.347658
  1.865673]

COSINE SIMILARITY: [0.88419644]
--------------------------------

### BASELINE ###
Embedding shape: (1, 384)
Embedding vector: [[ 0.00780747 -0.05262671 -0.01048512 -0.01660965  0.05743939 -0.07795463
   0.02905622  0.06451879]]
### llama.cpp ###
llama.cpp command: /Users/ghart/Projects/github/ggml-org/llama.cpp/build/bin/llama-embedding -m /Users/ghart/models/ibm-granite/granite-embedding-97m-multilingual-r2/granite-embedding-97M-multilingual-r2-BF16.gguf -p "123sfg this is a r@nd0m t35t" --temp 0 --embd-normalize -1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [-1.698663 -3.29584   3.333025  0.741073  1.758265 -1.212622  2.644687
  3.909934]

COSINE SIMILARITY: [0.78144405]
--------------------------------

Results w/ branch

### BASELINE ###
Embedding shape: (1, 384)
Embedding vector: [[-6.8227053e-03  3.0905887e-02  2.8057294e-02  1.5359418e-02
  -3.0930677e-02 -8.4969969e-03 -2.0503732e-05 -9.3344050e-03]]
### llama.cpp ###
llama.cpp command: ./build/bin/llama-embedding -m /Users/ghart/models/ibm-granite/granite-embedding-97m-multilingual-r2/granite-embedding-97M-multilingual-r2-BF16.gguf -p "hello world" --temp 0 --embd-normalize -1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [-5.428760e-01  2.462277e+00  2.234345e+00  1.223225e+00 -2.462488e+00
 -6.793040e-01 -2.037000e-03 -7.409680e-01]

COSINE SIMILARITY: [0.99999996]
--------------------------------

### BASELINE ###
Embedding shape: (1, 384)
Embedding vector: [[-0.02954087 -0.00377955  0.02121077  0.00609067  0.02561346 -0.05363345
  -0.01472757  0.01316541]]
### llama.cpp ###
llama.cpp command: ./build/bin/llama-embedding -m /Users/ghart/models/ibm-granite/granite-embedding-97m-multilingual-r2/granite-embedding-97M-multilingual-r2-BF16.gguf -p "tell me a story about a developer and their dog" --temp 0 --embd-normalize -1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [-2.253743 -0.283001  1.613361  0.461304  1.966426 -4.101535 -1.122563
  0.998517]

COSINE SIMILARITY: [0.99999785]
--------------------------------

### BASELINE ###
Embedding shape: (1, 384)
Embedding vector: [[ 0.00780747 -0.05262671 -0.01048512 -0.01660965  0.05743939 -0.07795463
   0.02905622  0.06451879]]
### llama.cpp ###
llama.cpp command: ./build/bin/llama-embedding -m /Users/ghart/models/ibm-granite/granite-embedding-97m-multilingual-r2/granite-embedding-97M-multilingual-r2-BF16.gguf -p "123sfg this is a r@nd0m t35t" --temp 0 --embd-normalize -1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [ 0.624031 -4.109614 -0.825582 -1.296852  4.491943 -6.104836  2.267448
  5.04938 ]

COSINE SIMILARITY: [0.99999723]
--------------------------------
