[API] feat: add reasoning recovery mechanism#1541
Conversation
Summary of ChangesHello @AlpinDale, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request implements a novel reasoning recovery mechanism designed to improve the robustness and user experience of models when they encounter uncertainty during their reasoning process. Instead of abruptly stopping, the system will now attempt to guide the model through self-correction by inserting specific phrases and allowing it to reconsider its output. This feature is highly configurable, enabling fine-tuning of recovery attempts and messaging, ultimately leading to more graceful handling of uncertain scenarios and potentially more helpful responses. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a reasoning recovery mechanism to handle model uncertainty. The implementation adds new sampling parameters, a state management class for recovery, and integrates the logic into the output processing pipeline. The changes are well-structured, but there are a few key issues. A critical issue was found where the recovery phrases are not correctly fed back into the model for generation, defeating the purpose of the feature. Additionally, a missing validation check could lead to silent failures, and there are opportunities to improve code clarity and style in the new parameter validation and state management class.
| recovery_tokens = tokenizer.encode( | ||
| recovery_phrase, | ||
| add_special_tokens=False | ||
| ) | ||
| req_state.detokenizer.token_ids.extend(recovery_tokens) |
There was a problem hiding this comment.
The recovery_tokens are appended to req_state.detokenizer.token_ids, which seems to only affect the final output text. For the reasoning recovery to work, these tokens must be appended to the model's input sequence for the next generation step. As it is, the model will continue generating from its state before the recovery phrase, which defeats the purpose of the recovery mechanism. You'll likely need a mechanism to communicate these tokens back to the scheduler to append them to the Sequence object for this request.
| if not self.enable_deepconf: | ||
| raise ValueError( | ||
| "enable_reasoning_recovery requires enable_deepconf to be " | ||
| "True") |
There was a problem hiding this comment.
enable_reasoning_recovery depends on the ability to tokenize the recovery phrases and inject them. This requires detokenize=True. You should add a validation check to ensure this to prevent silent failures.
if not self.enable_deepconf:
raise ValueError(
"enable_reasoning_recovery requires enable_deepconf to be "
"True")
if not self.detokenize:
raise ValueError(
"enable_reasoning_recovery requires detokenize to be True")| if ( | ||
| self.recovery_phrases is not None and not | ||
| isinstance(self.recovery_phrases, list) | ||
| ): | ||
| raise ValueError( | ||
| "recovery_phrases must be a list of strings, got " | ||
| f"{type(self.recovery_phrases)}.") | ||
| if ( | ||
| self.recovery_phrases is not None and not all( | ||
| isinstance(p, str) for p in self.recovery_phrases) | ||
| ): | ||
| raise ValueError( | ||
| "recovery_phrases must contain only strings.") |
There was a problem hiding this comment.
The validation for recovery_phrases can be made more concise and readable by nesting the checks inside a single if self.recovery_phrases is not None: block. This avoids repeating the is not None check and improves maintainability.
if self.recovery_phrases is not None:
if not isinstance(self.recovery_phrases, list):
raise ValueError(
"recovery_phrases must be a list of strings, got "
f"{type(self.recovery_phrases)}.")
if not all(isinstance(p, str) for p in self.recovery_phrases):
raise ValueError(
"recovery_phrases must contain only strings.")| def __post_init__(self): | ||
| """Initialize default values after dataclass creation.""" | ||
| if self.recovery_phrases is None: | ||
| self.recovery_phrases = DEFAULT_RECOVERY_PHRASES.copy() | ||
| if not self.final_admission: | ||
| self.final_admission = DEFAULT_FINAL_ADMISSION | ||
| if self.original_prompt_tokens is None: | ||
| self.original_prompt_tokens = [] | ||
| if self.original_output_tokens is None: | ||
| self.original_output_tokens = [] | ||
| if self.recovery_point_tokens is None: | ||
| self.recovery_point_tokens = [] |
There was a problem hiding this comment.
You can make this __post_init__ more concise and idiomatic by using dataclasses.field(default_factory=...) for mutable default values like lists. This avoids the is None checks in __post_init__.
For example, you can change the field definitions like this:
from dataclasses import field
...
recovery_phrases: list[str] = field(default_factory=lambda: DEFAULT_RECOVERY_PHRASES.copy())
original_prompt_tokens: list[int] = field(default_factory=list)
original_output_tokens: list[int] = field(default_factory=list)
recovery_point_tokens: list[int] = field(default_factory=list)Then, __post_init__ can be simplified as shown in the suggestion.
| def __post_init__(self): | |
| """Initialize default values after dataclass creation.""" | |
| if self.recovery_phrases is None: | |
| self.recovery_phrases = DEFAULT_RECOVERY_PHRASES.copy() | |
| if not self.final_admission: | |
| self.final_admission = DEFAULT_FINAL_ADMISSION | |
| if self.original_prompt_tokens is None: | |
| self.original_prompt_tokens = [] | |
| if self.original_output_tokens is None: | |
| self.original_output_tokens = [] | |
| if self.recovery_point_tokens is None: | |
| self.recovery_point_tokens = [] | |
| def __post_init__(self): | |
| """Initialize default values after dataclass creation.""" | |
| if not self.final_admission: | |
| self.final_admission = DEFAULT_FINAL_ADMISSION |
Implementation attempt for recovering from model uncertainty in reasoning trace but interrupting the model, inserting a recovery phrase (e.g. "–wait, I need to reconsider this...") and continuing. It'll attempt this
max_recovery_attemptstimes in the reasoning trace. If the confidence builds up above the threshold (as dictated by DeepConf), it will continue naturally into the final response. If confidence is not raised, it'll insert a final admission phrase, and the final output will be the model telling the user it doesn't know the answer.For now, testing is difficult because I can't think of a good question that would make the model really uncertain.
To test: