
[WIP][Fix] GLM 5 set apply_rotary_pos_emb to is_neox_style=False && remove F.relu()#45017

Open
JaredforReal wants to merge 8 commits into huggingface:main from JaredforReal:fix/glm5

Conversation

@JaredforReal
Contributor

@JaredforReal JaredforReal commented Mar 26, 2026

What does this PR do?

Get the RoPE operation right

Before: NeoX split-half style
After: GPT-J/interleaved style (interleaved=True, i.e. is_neox_style=False), which is the correct one for GLM 5
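For reference, a minimal sketch of the two RoPE conventions (function names and shapes are illustrative, not the transformers implementation): NeoX rotates pairs split across the two halves of the head dim, while GPT-J/interleaved rotates adjacent even/odd pairs.

```python
import torch

def rope_interleaved(x, inv_freq, pos):
    # GPT-J / interleaved style (is_neox_style=False):
    # rotate adjacent pairs (x0, x1), (x2, x3), ...
    freqs = pos[:, None].float() * inv_freq[None, :]  # (seq, dim/2)
    cos = freqs.cos().repeat_interleave(2, dim=-1)    # (seq, dim)
    sin = freqs.sin().repeat_interleave(2, dim=-1)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
    return x * cos + rotated * sin

def rope_neox(x, inv_freq, pos):
    # NeoX / split-half style (is_neox_style=True):
    # rotate pairs (x_i, x_{i + dim/2})
    freqs = pos[:, None].float() * inv_freq[None, :]
    cos = torch.cat((freqs.cos(), freqs.cos()), dim=-1)
    sin = torch.cat((freqs.sin(), freqs.sin()), dim=-1)
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin
```

Both are norm-preserving rotations and agree at position 0, but pair up different dimensions, so applying the wrong style silently corrupts attention for every position > 0.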

Get rid of F.relu

Reason:

  • F.relu works together with act_quant and rotate_activation in the BF16 path, making the index scores more accurate when activations are quantized to FP8
  • With GLM-5-FP8, no act_quant or rotate_activation is involved, so applying F.relu to the index scores makes the output unreasonable (see below)
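To illustrate why the clamp can hurt, here is a hypothetical sketch of top-k index selection (names and mechanism are illustrative only, not the actual DSA indexer code): relu collapses every negative score to 0, so once fewer than k scores are positive, the ranking among the clamped scores degenerates into ties.

```python
import torch

def select_topk(scores, k, clamp_negative=False):
    # Hypothetical sketch of top-k token selection from indexer scores.
    # clamp_negative=True mimics the removed F.relu: all negative scores
    # collapse to 0, so their relative order is lost before the top-k.
    if clamp_negative:
        scores = torch.relu(scores)
    return scores.topk(k, dim=-1).indices

scores = torch.tensor([[2.0, -0.5, -3.0, 1.0]])
# with raw scores, top-3 still ranks -0.5 above -3.0
print(select_topk(scores, 3))  # tensor([[0, 3, 1]])
# with clamping, both negatives become 0 and the third pick is an arbitrary tie
print(select_topk(scores, 3, clamp_negative=True))
```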

Fixes PPL test

Test Example

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.elastic.multiprocessing.errors import record

@record
def main():
    model_id = "/workspace/glm5-0210-fp8"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        attn_implementation="eager",
    )

    prompt = "Explain the MTP in LLM"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=1.0,
            top_p=0.95
        )

    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

if __name__ == "__main__":
    main()

Before:

(transformers) root@develop-20260325221219-my0s7:~/transformers# python test.py
Loading weights: 100%|███████████████████████████████████████████████████████| 2559/2559 [02:03<00:00, 20.70it/s]
The following generation flags are not valid and may be ignored: ['top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Failed to load CUTLASS quantization kernel: You need to specify `allow_all_kernels=True` to use kernels outside of the `kernels-community` repository. Falling back to Triton.
Explain the MTP in LLM 5.1
Question

Answers

Answer 

Related Questions

Answers

Answer :

MTP stands for MTP stands for "Minimum Efficient Training Program." It is a widely use a popular text-based method which is to convert any text into other n text and a proce

Answers:

Answers

Question:

Answers

Answers

Answer: I want to convertion't is the m text into an to used for

Answer :

Answers

Answers

Answer :
Answer:
Answers
Answers: "MTP in L
What is not text. 5 answer"

The MTP stands for Answer:
MTP is

Question: Answers

Answers

Answer : MTP.  to MTP stands for convert answear:

Question . This: M for in M

Answers

Answers for Answer MTP is MTP in MTP"

\

Answers

A for "M is a popular

Answers

Answer MTP, which MTP for to " MTP?

Answer" the the MTP is the process of  'MT"MTP in is answestion"

Answers

Answers

Answer. stand  to  for " Answers"

Answers"Answer :Answers"

Answers
Answers: to the in MTP 2 Answers

\begin{Answers" : M for Answers:  M for MTP in M" is an

Answers" is"

Answer. 1.  MTP
Answers

Answers

Question:
Answers

Answers in

Question :

\{ 'MTP

Answers 'MTP is a

Answers"  the following: MTP in  to
Answers

Answers" : M

Answers "MTP in

Answer "Answers", question"

\section

Answer"

Question: 5 \begin{TP in MTP in
\{MTA: MTP"

AnswerTP is
 is a widely used to
QuestionTP in : "Answers
 "Answer: TP stands for "What does MTP

Answers" stands for \textMT

Answers for "Answers"

:
Answers"

Answers is
 to Answers. 1. MTP. 9.
MTP is 'Answer"
Answers" stands for \ for a
Answer 1. 5. 9. **: MTP.

\ Answer: MTPL answers:
Answers

Answers1. M
Answers in
Answers, "T stands for Answer

Answers "Answers. M for" (in a popular. You'll

Answer:TP. A is for "Text-protocol. The "Answer"

Answers:MT stands for "Answers"

Question is Answers
. The full form of" 3, the of MTP: The expansion of : 'M is
: to MTP

Question \begin for 2.
I. M-Answers: a short for 7. MTP

\  a" is M to TP

Answers:TP "AnswerM

\section* " in a "text: M for \MT \text \{Question" which is an acronym
Answer" is the expansion for 'What does

Answers

:

Answer.9. M. " stands \{MT" is

Answers: a question"
Answer's "MTP Answers: stands for standsMTM

Answer: Answers

Answers? "T
Answers: T for\(\mathbf for 'MTP".

\section*{M is "MTP.

\Question.

Answer" is 'MTP"
\text \: M for \: \MT" to " stands for text

Answers
 the acronym MTP?Answers. \ is
Answers of \: M \MT

Answer the
 of.2.

: M is "MT for "the M for  \section
Answers stands for \{Answer"
\section"

Answer " of "Answers is 'Answer

Answer
MT P". The MTP is for \ is "Answers" isMT"9.

In: MTP stands "Text" is

Answers"MT"

: text\

AnswersMT PAnswersMT P" is a popular method to 'Answers in
Answers: to
Question is

\MT in
\Question 2. A 'Answers" to 'Answers to the
Answer:AnswersMT" is 'MT is a popular programming and
\section*

Question

Question 'Answers".

\ "MT"

Answers. \ in MTP in

Answer is popular protocol used to 'MTP, and is the
Question

Question:What does MT' 'MTP in protocol used to T.2

Question: M Answers for
: This is \ is Answers stands for "Answers in
 is: MTP".

The above is. T
Question for MTP acronym which stands for "Message Transfer Protocol

\section

Answers" stands \MT

Answers M is used in
\begin{section. For" protocol that's MT

Answer"MT\ is
.  MTP is a M
Answer:TP for
I

After:

(transformers) root@develop-20260325221219-my0s7:~/transformers# python test.py
Loading weights: 100%|███████████████████████████████████████████████████████| 2559/2559 [02:08<00:00, 19.91it/s]
The following generation flags are not valid and may be ignored: ['top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Explain the MTP in LLM?
In the field of deep learning, especially in autoregressive generation tasks (such as language model pre-training), MTP (Multi-Token Prediction) is a new training objective that aims to improve the model's reasoning ability and generation efficiency by predicting multiple future tokens simultaneously. The core idea of MTP is that in each training step, the model is not limited to predicting only the next token (the traditional Next Token Prediction), but expands to predict the next \(n\) tokens, thus enabling the model to acquire longer-term dependencies and deeper reasoning capabilities.
Technical Implementation of MTP
1. Shared Representation and Independent Output Heads:
MTP typically shares the backbone of the model (such as a Transformer) to extract intermediate features of the current context and then uses multiple independent output heads (linear layers + softmax) to predict the next 1st, 2nd, ... up to the nth token separately. This design ensures that the model does not significantly increase computational overhead while learning multi-step predictions.
2. Training Objective:
The loss function of MTP is the sum of cross-entropy losses for all prediction heads. For each training sample, the model needs to minimize the following objective:
\[
\mathcal{L}_{\text{MTP}} = \sum_{i=1}^{n} \text{CrossEntropy}(p_i, y_i)
\]
where \(p_i\) is the predicted probability distribution for the \(i\)-th future token, and \(y_i\) is the corresponding true label.
3. Parallel Decoding Support:
During inference, MTP can be used for parallel decoding. For example, by using methods like Blockwise Parallel Decoding, the model can predict multiple tokens at once and then verify them, thus accelerating the generation process.
Advantages of MTP
1. Improved Reasoning Capability:
MTP forces the model to plan multiple steps in advance during training, thus better capturing the structure of language and logical reasoning chains. Experiments show that MTP performs significantly better than traditional single-step prediction on tasks such as mathematical reasoning and code generation.
2. Accelerated Inference:
During inference, MTP can be used for parallel decoding, reducing the number of autoregressive steps. For example, by predicting 4 tokens at a time, the theoretical inference speed can be increased by 3 times.
3. Better Uncertainty Estimation:
By predicting multiple future tokens, the model can more accurately estimate the uncertainty of generation, which is important for tasks such as active learning and model calibration.
Challenges of MTP
1. Increased Training Complexity:
MTP requires the model to optimize multiple objectives simultaneously, which may lead to optimization difficulties, such as gradient conflicts or overfitting.
2. Higher Demand for Data Quality:
MTP relies on high-quality data to learn long-term dependencies, and noisy data may exacerbate error accumulation.
3. Limited Performance Improvement in Some Tasks:
For tasks that rely more on local patterns (such as simple text classification), the improvement of MTP may not be obvious.
Typical Applications
1. Large Language Models:
MTP has been applied to models such as DeepSeek-V3, achieving significant improvements on reasoning tasks.
2. Code Generation:
MTP can generate code structures more efficiently and improve the syntax correctness rate.
3. Mathematical Reasoning:
MTP performs particularly well on tasks such as GSM8K and MATH, proving its effectiveness in complex reasoning.
Summary
MTP is a training objective with great potential, which significantly improves the model's reasoning ability and inference efficiency by expanding the prediction range. However, its practical application still needs to address challenges such as training complexity and data quality. In the future, with the improvement of optimization algorithms and the enhancement of computing power, MTP is expected to become one of the standard techniques for large language model training. MTP (Multi-Token Prediction) has broad application prospects in scenarios such as mathematics, code, and logic, but in tasks like "writing novels," the improvement may not be significant, or may even backfire. Next, I will explain the reasons in detail, analyze the problems encountered when MTP is applied to novel writing, and propose improvement solutions.
I. Why MTP performs well in mathematics, code, and logic?
These tasks have a high degree of determinism, clear logic, and high frequency of local patterns. For example, mathematical derivation, code syntax, and logical inference can all be predicted in advance through "multi-step planning." MTP enables the model to learn the pattern of "thinking several steps ahead," making it more targeted during training, resulting in significant performance improvements.
II. Why MTP is not effective (or may backfire) in novel writing?
1. Novel writing emphasizes "creativity" rather than "determinism":
The core of a novel lies in the plot, emotional changes, and suspense, all of which are highly uncertain. MTP will force the model to "plan the next few sentences in advance," which may lead to:
- Stereotyped plots: The model tends to choose
(transformers) root@develop-20260325221219-my0s7:~/transformers#

Signed-off-by: JaredforReal <w13431838023@gmail.com>
Copilot AI review requested due to automatic review settings March 26, 2026 09:21
Contributor

Copilot AI left a comment


Pull request overview

This PR updates the GLM-MoE-DSA (GLM-5) rotary position embedding application to support interleaved (GPT‑J style) RoPE and removes a ReLU nonlinearity from the DSA indexer’s score computation.

Changes:

  • Extend apply_rotary_pos_emb with an is_neox_style switch and update GLM-MoE-DSA call sites to use interleaved RoPE (is_neox_style=False).
  • Remove F.relu() from the DSA indexer scoring path.
  • Adjust how the DSA index mask is combined with the attention/causal mask.
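A minimal sketch of one way the last point can work (hypothetical names and shapes; the actual plumbing lives in modular_glm_moe_dsa.py): a key position survives only if it is both causal and selected by the indexer's top-k.

```python
import torch

def combine_index_and_causal_mask(topk_idx, seq_len):
    # Keep key position j for query i only if j is causal (j <= i)
    # AND j was selected by the indexer's top-k for query i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    index_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    index_mask[torch.arange(seq_len).unsqueeze(1), topk_idx] = True
    return causal & index_mask

# each query selects one key; a selection beyond the causal frontier is dropped
topk_idx = torch.tensor([[0], [0], [3], [2]])
mask = combine_index_and_causal_mask(topk_idx, 4)
```

Query 2 selects key 3, which is non-causal, so that row ends up fully masked.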

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

File Description
src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py Implements interleaved-vs-NeoX RoPE logic, updates RoPE call sites, removes ReLU, and modifies mask combination logic.
src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py Regenerated modeling file reflecting the same RoPE/scoring/mask-combination changes from the modular source.

Signed-off-by: JaredforReal <w13431838023@gmail.com>
@JaredforReal JaredforReal changed the title [Fix] GLM 5 set apply_rotary_pos_emb to is_neox_style=False && remove F.relu() [WIP][Fix] GLM 5 set apply_rotary_pos_emb to is_neox_style=False && remove F.relu() Mar 26, 2026
Collaborator

@ArthurZucker ArthurZucker left a comment


Ty! We can probably add what you shared in the snippet as a super slow test

@JaredforReal
Contributor Author

@ArthurZucker Yeah, the test is really super slow!!
I'm doing more tests on PPL, not in a hurry to merge it

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker
Collaborator

Okay, waiting for your 🟢 to merge. Can you add in integration tests please? 🤗

@IlyasMoutawwakil IlyasMoutawwakil mentioned this pull request Mar 26, 2026
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: glm_moe_dsa



Development

Successfully merging this pull request may close these issues.

PPL异常 (Abnormal PPL)

4 participants