Fix GraniteConfig type hints to accept int for multiplier fields #45019
cc @tomaarsen @zucchini-nlp from #44877
This strikes me as a good fix, although I haven't tried it locally. I'm personally a fan of expanding the config type hints a bit with float/int rather than requiring model authors to update their ints to floats.
logits_scaling: float = 1.0
residual_multiplier: float = 1.0
attention_multiplier: float = 1.0
embedding_multiplier: float | int = 1.0
There are more Granite models with xxx_multiplier fields; can you update all of them?
Updated GraniteMoe and GraniteMoeShared. GraniteMoeHybrid already supports int | float | None.
Let me know if you spot any other fixes or anything else to add!
[For maintainers] Suggested jobs to run (before merge): run-slow: granite, granitemoe, granitemoeshared
zucchini-nlp left a comment
Thanks, I will update other similar fields in a separate PR.
For context, we decided with Lucain that we don't want to force-cast as pydantic does, because we want the validation to stay a lightweight dependency. Instead, I will check which other fields could plausibly be either int or float and update their annotations (I already did that for all dropouts).
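The difference between force-casting and validate-only checking can be illustrated roughly as follows. This is a sketch with hypothetical helper names (`coerce_to_float`, `validate_number`), not pydantic's or the library's actual code; the exact-type check is an assumption made for the illustration.

```python
def coerce_to_float(value):
    """Hypothetical force-cast: an int such as 12 silently becomes 12.0."""
    return float(value)


def validate_number(value):
    """Hypothetical validate-only check for a `float | int` annotation.

    Accepts exactly int or float and returns the value unchanged;
    anything else is rejected instead of being converted.
    """
    if type(value) not in (int, float):
        raise TypeError(f"expected float | int, got {type(value).__name__}")
    return value


cast_value = coerce_to_float(12)   # the type changes to float
kept_value = validate_number(12)   # the type stays int
```

With validate-only checking, the value read from config.json keeps its original type, which is why the annotation itself has to admit both `int` and `float`.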
…gingface#45019) * Fix GraniteConfig type hints to accept int for multiplier fields Fixes huggingface#44877 * Also update granitemoe and granitemoeshared multiplier type hints
What does this PR do?
Fixes #44877
Loading `ibm-granite/granite-4.0-1b-speech` fails with `StrictDataclassFieldValidationError` because its config.json stores `embedding_multiplier` and `logits_scaling` as integers (e.g. `12`, `8`), but `GraniteConfig` declares them as `float`. The fix updates the type hints for all four multiplier fields to `float | int`, matching the pattern already used for `attention_dropout` in the same class.